288 lines
14 KiB
Python
288 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
|
|
#
|
|
# bfloat16nncore.py
|
|
#
|
|
# bfloat16 neural network processing core logic
|
|
#
|
|
# History:
|
|
# --------
|
|
# 21.04.21/KQ Initial version
|
|
#
|
|
|
|
from migen import *
|
|
from migen.fhdl.specials import Memory
|
|
from litex.soc.interconnect.csr import AutoCSR, CSRStatus, CSRStorage, CSRField, CSRAccess
|
|
from litex.soc.integration.doc import AutoDoc, ModuleDoc
|
|
|
|
from litex.soc.interconnect.csr import *
|
|
|
|
from libmodules.dramtransfer import DRAM2FPGA, FPGA2DRAM
|
|
from libmodules.bfloat16processor import bfloat16Processor
|
|
|
|
class bfloat16NeuralNetworkCore(Module, AutoCSR, AutoDoc, ModuleDoc):
    """
    bfloat16NN core class:

    Loads a block of 32-bit words from DRAM into the load unit's local cache and
    computes a scalar (inner) product of bfloat16 values on two FPUs.

    Usage:
    ######

    1. Freeze operations by setting ``bEnable`` to false (0).

    2. Load ``b32DRAMLoadAddress`` with a 32-bit DRAM memory pointer, ``b32Sentinel``
       with the control word expected in the last cache word and ``b9ArrayWordLen``
       with the number of words to be processed per FPU.

    3. Finally, enable processing by setting ``bEnable`` to true (1).

    4. bfloat16NN processing will now run asynchronously to the CPU until reset or
       until ``bEnable`` is cleared again (see 1.).
       From now on the CPU only writes to DRAM (yet, the L2 cache will have to be flushed);
       all data will be picked up by the FPGA automatically ...

    Data layout:
    ############

    Each 32-bit cache word holds two bfloat16 values (low half and high half) which are
    passed to the FPU as its two multiplication operands. FPU#1 processes the cache words
    ``0 .. b9ArrayWordLen-1``, FPU#2 the same number of words starting at cache offset
    ``LUCacheSize/2``. Each FPU accumulates its products (fmul for the first word, fmadd
    afterwards); finally both sums are added on FPU#1.

    Inputs:
    #######

    :b32DRAMLoadAddress: DRAM address from where to load into local memory

    :b32Sentinel: Control word expected in the last loaded word (cache offset ``LUCacheSize-1``);
                  processing starts only once this word has been read back

    :bEnable: To enable running (after data preparation)

    :b9ArrayWordLen: Number of words (per FPU) used for calculation of the scalar (inner) product

    Outputs:
    ########

    :b16Status: Processing status bits (bits 0..5: loader stages reached, bit 15: result ready)

    :b16FPUStates: FPU states (low byte FPU#1, high byte FPU#2)

    :b16Result1: Processing result (bfloat16)

    :b16Result2: Last FPU#2 result (control only)

    :bReady: Ready indication (wire to LED ... ;)
    """

    def __init__(self, RAMWaitTime=128, LUCacheSize=8, LoadUnit=None, StoreUnit=None):
        """
        Create the CSRs, the load unit FSM and the scalar product (loader) FSM.

        :param RAMWaitTime: Wait cycles in ``Loader_LOAD2`` before the first data words are
                            used (the counter must exceed this value; later loop passes
                            do not wait again).
        :param LUCacheSize: Number of 32-bit words in the load unit cache. The sentinel is
                            read at offset ``LUCacheSize - 1``; FPU#2 starts at
                            ``LUCacheSize >> 1``.
        :param LoadUnit: Load unit module (required). Its ``b32Address``, ``bEnable``,
                         ``bValid``, ``b9Offset1``/``b9Offset2`` and ``b32Data1``/``b32Data2``
                         CSRs are driven/read here.
        :param StoreUnit: Currently unused.
        :raises ValueError: If no ``LoadUnit`` is supplied.
        """
        if LoadUnit is None:
            # Fail early with a clear message instead of an AttributeError deep inside the FSM setup
            raise ValueError("bfloat16NeuralNetworkCore requires a LoadUnit")

        # Inputs
        self.b32DRAMLoadAddress = CSRStorage(32, reset_less=False,
            fields=[CSRField("LoadAddress", size=32, description="*Field*: 32-Bit value")],
            description="""
            Load value (32-bit DRAM address).
            """)
        self.b32Sentinel = CSRStorage(32, reset_less=False,
            fields=[CSRField("Sentinel", size=32, description="*Field*: 32-Bit value")],
            description="""
            Control value
            """)
        self.bEnable = CSRStorage(1, reset_less=False,
            fields=[CSRField("Enable", size=1, description="*Field*: bit", values=[
                ("0", "DISABLED", "bfloat16NN not active"),
                ("1", "ENABLED", "bfloat16NN active"),
            ])
            ],
            description="""
            Enable free run
            """)
        self.b9ArrayWordLen = CSRStorage(9, reset_less=False,
            fields=[CSRField("ArrayWordLen", size=9, description="*Field*: 9-Bit value")],
            description="""
            Word length of array used for calculation
            """)

        # Outputs
        # NOTE(review): the outputs below are CSRStorage registers that are also driven by
        # this logic; CSRStatus may have been intended - confirm before changing the CSR map.
        self.b16Status = CSRStorage(16, reset_less=False,
            fields=[CSRField("Status", size=16, description="*Field*: 16-Bit value")],
            description="""
            Processing stati
            """)
        self.b16FPUStates = CSRStorage(16, reset_less=False,
            fields=[CSRField("FPUStates", size=16, description="*Field*: 16-Bit value")],
            description="""
            FPU states: Low FPU#1, High FPU#2
            """)
        self.b16Result1 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result1", size=16, description="*Field*: 16-Bit value")],
            description="""
            FPU#1 Processing result
            """)
        self.b16Result2 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result2", size=16, description="*Field*: 16-Bit value")],
            description="""
            FPU#2 Processing result
            """)
        self.bReady = Signal()  # To be wired to data pin ... ;)

        def bf16_operand(bf16):
            # Build a 32-bit FPU operand: 16 zero bits at the LSBs, the bfloat16 value in bits 16..31
            return Cat(*([0] * 16), bf16)

        #---------------- Load unit (LU) -------------------------------------------------------------
        LU_fsm = FSM(reset_state="LU_IDLE")  # FSM starts idling ...
        self.submodules += LU_fsm

        self.LU_CacheOffset = Signal(9, reset_less=True)  # Cache reading offset (0..511); only reset in this class
        self.LU_CacheValid = Signal()                     # Set once the load unit reports valid data
        self.LU_CacheDelay = Signal(11, reset_less=True)  # Load duration in cycles (saturates at 2047)

        LU_fsm.act("LU_IDLE",  # If cache not valid: fill it!
            If(~self.LU_CacheValid & self.bEnable.storage,  # Invalid cache & run enabled ...
                NextValue(LoadUnit.b32Address.storage, self.b32DRAMLoadAddress.storage),
                NextValue(self.LU_CacheOffset, 0),  # Reset local read pointer (32-bit words)
                NextValue(self.LU_CacheDelay, 2),   # Start at 2: count includes 1st & last cycle
                NextState("LU_LOAD1")
            ).Elif(~self.bEnable.storage,  # Cleared enable?
                NextValue(self.LU_CacheValid, False),  # Enforce cache invalidation!
            )
        )
        LU_fsm.act("LU_LOAD1",  # Engage!
            NextValue(LoadUnit.bEnable.storage, 1),  # Trigger DRAM transfer to cache
            NextState("LU_LOAD2")
        )
        LU_fsm.act("LU_LOAD2",  # Wait for termination of transfer ...
            If(LoadUnit.bValid.storage,  # Data avail.?
                NextValue(self.LU_CacheValid, 1),         # Declare cache valid
                NextValue(LoadUnit.bEnable.storage, 0),   # Stop DRAM transfer to cache
                NextState("LU_IDLE")                      # Yap!
            ).Else(
                If(self.LU_CacheDelay < 2047,  # Saturate at the 11-bit maximum
                    NextValue(self.LU_CacheDelay, self.LU_CacheDelay + 1),
                )
                # TODO: Permit timeout indication ...
            )
        )

        #---------------- bfloat16 FPUs -------------------------------------------------------------
        self.submodules.fpu1 = fpu1 = bfloat16Processor()  # Integrate bfloat16 FPU
        self.submodules.fpu2 = fpu2 = bfloat16Processor()  # Integrate another one!

        #---------------- Loaded data processing (scalar product) -----------------------------------
        Loader_fsm = FSM(reset_state="Loader_IDLE")  # FSM starts idling ...
        self.submodules += Loader_fsm

        self.Loader_Delay = Signal(32, reset_less=True)  # Memory wait counter (first loop pass only)
        self.Loader_Active = Signal()                    # Held until bEnable is cleared: one run per enable

        Loader_fsm.act("Loader_IDLE",
            If(self.LU_CacheValid & ~self.Loader_Active,  # Enter if not active already
                NextValue(self.Loader_Active, True),  # Loader up & running
                NextValue(self.Loader_Delay, 0),      # Reset read delay timer
                NextValue(LoadUnit.b9Offset1.storage, LUCacheSize - 1),   # Point at sentinel word
                NextValue(LoadUnit.b9Offset2.storage, LUCacheSize >> 1),  # Start of 2nd array (FPU#2)
                NextValue(self.b16Result1.storage, 0),  # Clear previous result
                NextValue(self.b16Result2.storage, 0),  # Clear previous result
                NextValue(self.bReady, False),          # LED off!
                NextState("Loader_LOAD1")
            ).Elif(~self.bEnable.storage,  # Externally aborted?
                NextValue(self.b16Status.storage, 0),   # Current status: inactive
                NextValue(self.Loader_Active, False),   # Reset in sync w/ global activation
            )
        )
        Loader_fsm.act("Loader_LOAD1",  # Wait until the sentinel word has been read back
            NextValue(self.b16Status.storage[0], True),  # Current status added
            If(LoadUnit.b32Data1.storage == self.b32Sentinel.storage,  # Valid last entry?
                NextValue(LoadUnit.b9Offset1.storage, 0),  # 1st value offset preparation
                NextState("Loader_LOAD2")
            ).Elif(~self.bEnable.storage,  # Enable withdrawn?
                NextState("Loader_IDLE")  # Abort!
            )
        )

        #-----> LOOP ENTRY ! (2nd loop onward: fs3 already prepared!)
        Loader_fsm.act("Loader_LOAD2",
            NextValue(self.b16Status.storage[1], True),  # Current status added
            If(self.Loader_Delay > RAMWaitTime,  # Required only for 1st entry ...
                # FPU#1: low half -> fs1, high half -> fs2
                NextValue(fpu1.fs1, bf16_operand(LoadUnit.b32Data1.storage[0:16])),
                NextValue(fpu1.fs2, bf16_operand(LoadUnit.b32Data1.storage[16:32])),
                NextValue(LoadUnit.b9Offset1.storage, LoadUnit.b9Offset1.storage + 1),  # Move on to next entry
                # FPU#2: low half -> fs1, high half -> fs2
                NextValue(fpu2.fs1, bf16_operand(LoadUnit.b32Data2.storage[0:16])),
                NextValue(fpu2.fs2, bf16_operand(LoadUnit.b32Data2.storage[16:32])),
                NextValue(LoadUnit.b9Offset2.storage, LoadUnit.b9Offset2.storage + 1),  # Move on to next entry

                NextState("Loader_EXEC1")
            ).Else(  # MEM wait cycles
                NextValue(self.Loader_Delay, self.Loader_Delay + 1),  # Increment
            )
        )
        Loader_fsm.act("Loader_EXEC1",  # Issue the FPU operation
            NextValue(self.b16Status.storage[2], True),  # Current status added
            If(LoadUnit.b9Offset1.storage == 1,  # Pointer already moved ahead 1: entry #0
                NextValue(fpu1.fmul, True),  # 1st entry: plain multiply
                NextValue(fpu2.fmul, True),
            ).Else(
                NextValue(fpu1.fmadd, True),  # 2nd ... last entry: multiply/add
                NextValue(fpu2.fmadd, True),
            ),
            NextValue(fpu1.fready, False),  # Engage trigger FPU#1
            NextValue(fpu2.fready, False),  # Engage trigger FPU#2
            NextState("Loader_EXEC2")
        )
        Loader_fsm.act("Loader_EXEC2",  # Wait for both FPUs, then loop or combine
            NextValue(self.b16Status.storage[3], True),  # Current status added
            NextValue(self.b16Status.storage[8], fpu1.fready),  # TODO: Remove!
            NextValue(self.b16Status.storage[9], fpu2.fready),  # TODO: Remove!
            If(fpu1.fready & fpu2.fready,
                If(LoadUnit.b9Offset1.storage == 1,  # Pointer already moved ahead 1! (Actually: Entry #0)
                    NextValue(fpu1.fmul, False),  # Clear command request FPU#1
                    NextValue(fpu2.fmul, False),  # Clear command request FPU#2
                ).Else(  # Entries 1 .. (maxlen-1)
                    NextValue(fpu1.fmadd, False),  # Clear command request FPU#1
                    NextValue(fpu2.fmadd, False),  # Clear command request FPU#2
                ),
                NextValue(fpu1.fs3, bf16_operand(fpu1.fresult[16:32])),  # Sum will be used for fmadd.s
                NextValue(fpu2.fs3, bf16_operand(fpu2.fresult[16:32])),  # Sum will be used for fmadd.s
                If(LoadUnit.b9Offset1.storage < self.b9ArrayWordLen.storage,  # More words left?
                    NextState("Loader_LOAD2")
                ).Else(  # Finally prepare ADD of both result sums (on FPU#1 only!)
                    NextValue(fpu1.fs1, fpu1.fresult),
                    NextValue(fpu1.fs2, fpu2.fresult),
                    NextState("Loader_EXEC3")
                )
            )
        )
        Loader_fsm.act("Loader_EXEC3",
            NextValue(self.b16Status.storage[4], True),  # Current status added
            NextValue(fpu1.fadd, True),     # Final ADD requested
            NextValue(fpu1.fready, False),  # Engage trigger FPU#1 (only!)
            NextState("Loader_EXEC4")
        )
        Loader_fsm.act("Loader_EXEC4",
            NextValue(self.b16Status.storage[5], True),  # Current status added
            If(fpu1.fready,
                NextValue(fpu1.fadd, False),  # Clear command request FPU#1
                NextValue(self.b16Result1.storage, fpu1.fresult[16:32]),  # Pick result (high word!)
                NextValue(self.b16Result2.storage, fpu2.fresult[16:32]),  # Useless (control only ...)
                NextValue(self.b16Status.storage[15], True),  # Indicate readyness ...
                NextValue(self.bReady, True),  # Indicate readyness (LED on!) (TODO: Remove!)
                NextState("Loader_IDLE")
            )
        )

        self.sync += [  # Show individual FPU states
            self.b16FPUStates.storage[0:8].eq(fpu1.FPU_state[0:8]),
            self.b16FPUStates.storage[8:16].eq(fpu2.FPU_state[0:8]),
        ]