bfloat16nn/libmodules/bfloat16nncore.py

288 lines
14 KiB
Python

#!/usr/bin/env python3
#
# bfloat16nncore.py
#
# bfloat16 neural network processing core logic
#
# History:
# --------
# 21.04.21/KQ Initial version
#
from migen import *
from migen.fhdl.specials import Memory
from litex.soc.interconnect.csr import AutoCSR, CSRStatus, CSRStorage, CSRField, CSRAccess
from litex.soc.integration.doc import AutoDoc, ModuleDoc
from litex.soc.interconnect.csr import *
from libmodules.dramtransfer import DRAM2FPGA, FPGA2DRAM
from libmodules.bfloat16processor import bfloat16Processor
class bfloat16NeuralNetworkCore(Module, AutoCSR, AutoDoc, ModuleDoc):
    """
    bfloat16NN core class:

    Usage:
    ######

    1. Freeze operations by setting ``bEnable`` to false (0).
    2. Load ``b32DRAMLoadAddress`` with a 32-bit DRAM memory pointer, ``b32Sentinel`` with the
       control word stored in the last cache word and ``b9ArrayWordLen`` with the array length.
    3. Finally, enable processing by setting ``bEnable`` to true (1).
    4. bfloat16NN processing now runs asynchronously to the CPU: the load unit fills the local
       cache from DRAM, the sentinel word is checked, both FPUs compute their scalar products
       and the sum of both is published in ``b16Result1``.
       Exactly one computation is performed per enable: clear ``bEnable`` (see 1.) and set it
       again to process new data (clearing ``bEnable`` also invalidates the local cache).
       The CPU only has to write the data to DRAM (yet, the L2 cache has to be flushed).

    Inputs:
    #######

    :b32DRAMLoadAddress: DRAM address from where to load into local memory
    :b32Sentinel: Control word expected in the last cache word (offset ``LUCacheSize - 1``)
    :bEnable: To enable running (after data preparation); clearing it aborts the run
    :b9ArrayWordLen: Number of 32-bit words per FPU used for the scalar (inner) product;
                     each word holds two bfloat16 factors (low & high half-word)

    Outputs:
    ########

    :b16Status: Progress bits (0..5: FSM stages reached, 8/9: FPU ready flags, 15: result ready)
    :b16FPUStates: FPU states: low byte FPU#1, high byte FPU#2
    :b16Result1: Final result (bfloat16 sum of both FPU scalar products)
    :b16Result2: FPU#2 partial result (control only)
    :bReady: Ready indication (wire to LED ... ;)

    Parameters:
    ###########

    :RAMWaitTime: Wait cycles before the first cache entry is consumed
    :LUCacheSize: Local cache size in 32-bit words; FPU#2's array starts at ``LUCacheSize // 2``
    :LoadUnit: DRAM -> FPGA load unit (required)
    :StoreUnit: Currently unused
    """
    def __init__(self, RAMWaitTime=128, LUCacheSize=8, LoadUnit=None, StoreUnit=None):
        # All FSMs below read/drive the load unit's registers: fail early with a clear
        # message instead of an AttributeError deep inside the FSM description.
        if LoadUnit is None:
            raise ValueError("bfloat16NeuralNetworkCore requires a LoadUnit")

        # Inputs
        self.b32DRAMLoadAddress = CSRStorage(32, reset_less=False,
            fields=[CSRField("LoadAddress", size=32, description="*Field*: 32-Bit value")],
            description="""
                Load value (32-bit DRAM address).
            """)
        self.b32Sentinel = CSRStorage(32, reset_less=False,
            fields=[CSRField("Sentinel", size=32, description="*Field*: 32-Bit value")],
            description="""
                Control value
            """)
        self.bEnable = CSRStorage(1, reset_less=False,
            fields=[CSRField("Enable", size=1, description="*Field*: bit", values=[
                ("0", "DISABLED", "bfloat16NN not active"),
                ("1", "ENABLED", "bfloat16NN active"),
            ])],
            description="""
                Enable free run
            """)
        self.b9ArrayWordLen = CSRStorage(9, reset_less=False,
            fields=[CSRField("ArrayWordLen", size=9, description="*Field*: 9-Bit value")],
            description="""
                Word length of array used for calculation
            """)

        # Outputs (written by the logic below, read by the CPU)
        self.b16Status = CSRStorage(16, reset_less=False,
            fields=[CSRField("Status", size=16, description="*Field*: 16-Bit value")],
            description="""
                Processing stati
            """)
        self.b16FPUStates = CSRStorage(16, reset_less=False,
            fields=[CSRField("FPUStates", size=16, description="*Field*: 16-Bit value")],
            description="""
                FPU states: Low FPU#1, High FPU#2
            """)
        self.b16Result1 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result1", size=16, description="*Field*: 16-Bit value")],
            description="""
                FPU#1 Processing result
            """)
        self.b16Result2 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result2", size=16, description="*Field*: 16-Bit value")],
            description="""
                FPU#2 Processing result
            """)
        self.bReady = Signal()  # To be wired to data pin ... ;)

        def bf16_to_fs(value):
            """Place a 16-bit bfloat16 value into the upper half of a 32-bit FPU register.

            Migen's Cat() puts its first argument into the LSBs, so 16 single-bit zeros
            followed by ``value`` yield ``value << 16``.
            """
            return Cat(*([0] * 16), value)

        #---------------- Load unit (LU) -------------------------------------------------------------
        # Fills the local cache from DRAM once per enable: idle -> trigger transfer -> wait for valid.
        LU_fsm = FSM(reset_state="LU_IDLE")  # FSM starts idling ...
        self.submodules += LU_fsm
        self.LU_CacheOffset = Signal(9, reset_less=True)  # Local reader offset (0..511); only reset here
        self.LU_CacheValid = Signal()                     # Indicate loaded LU cache
        self.LU_CacheDelay = Signal(11, reset_less=True)  # Load length in cycles (saturates at 2047)
        LU_fsm.act("LU_IDLE",  # If cache not valid fill it!
            If(~self.LU_CacheValid & self.bEnable.storage,  # Invalid cache & run enabled ...
                NextValue(LoadUnit.b32Address.storage, self.b32DRAMLoadAddress.storage),
                NextValue(self.LU_CacheOffset, 0),  # Reset local reader pointer
                NextValue(self.LU_CacheDelay, 2),   # Reset load delay counter (incl. 1st & last cycle)
                NextState("LU_LOAD1")
            ).Elif(~self.bEnable.storage,  # Cleared enable?
                NextValue(self.LU_CacheValid, False),  # Enforce cache invalidation!
            )
        )
        LU_fsm.act("LU_LOAD1",  # Engage!
            NextValue(LoadUnit.bEnable.storage, 1),  # Trigger DRAM transfer to cache
            NextState("LU_LOAD2")
        )
        LU_fsm.act("LU_LOAD2",  # Wait for termination of transfer ...
            If(LoadUnit.bValid.storage,  # Data avail.?
                NextValue(self.LU_CacheValid, 1),        # Declare cache valid
                NextValue(LoadUnit.bEnable.storage, 0),  # Stop DRAM transfer to cache
                NextState("LU_IDLE")  # Yap!
            ).Else(
                If(self.LU_CacheDelay < 2047,  # Saturate at the 11-bit maximum
                    NextValue(self.LU_CacheDelay, self.LU_CacheDelay + 1),
                )
                # TODO: Permit timeout indication ...
            )
        )

        #---------------- bfloat16 FPUs -------------------------------------------------------------
        # Two FPUs run in parallel: FPU#1 on the array at offset 0, FPU#2 on the array at
        # offset LUCacheSize >> 1.
        self.submodules.fpu1 = fpu1 = bfloat16Processor()  # Integrate bfloat16 FPU
        self.submodules.fpu2 = fpu2 = bfloat16Processor()  # Integrate another one!

        #---------------- Loaded data processing -------------------------------------------------------
        # Sequencer: wait for valid cache -> check sentinel -> loop MUL/MADD over the arrays on
        # both FPUs -> ADD both partial sums on FPU#1 -> publish result.
        Loader_fsm = FSM(reset_state="Loader_IDLE")  # FSM starts idling ...
        self.submodules += Loader_fsm
        self.Loader_Delay = Signal(32, reset_less=True)  # Wait counter of Loader_LOAD2 (effective for 1st entry only)
        self.Loader_Active = Signal()  # Set on run start; cleared only when bEnable is withdrawn
        Loader_fsm.act("Loader_IDLE",
            If(self.LU_CacheValid & ~self.Loader_Active,  # Enter if not active already
                NextValue(self.Loader_Active, True),  # Loader up & running
                NextValue(self.Loader_Delay, 0),      # Reset read delay timer
                NextValue(LoadUnit.b9Offset1.storage, LUCacheSize - 1),  # Adjust offset to read sentinel
                NextValue(LoadUnit.b9Offset2.storage, LUCacheSize >> 1),  # Adjust offset to start of 2nd array
                NextValue(self.b16Result1.storage, 0),  # Clear previous result
                NextValue(self.b16Result2.storage, 0),  # Clear previous result
                NextValue(self.bReady, False),  # LED off!
                NextState("Loader_LOAD1")
            ).Elif(~self.bEnable.storage,  # Externally aborted?
                NextValue(self.b16Status.storage, 0),  # Current status: inactive
                NextValue(self.Loader_Active, False),  # Reset in sync w/ global activation
            )
        )
        Loader_fsm.act("Loader_LOAD1",
            NextValue(self.b16Status.storage[0], True),  # Current status added
            If(LoadUnit.b32Data1.storage == self.b32Sentinel.storage,  # Valid last entry?
                NextValue(LoadUnit.b9Offset1.storage, 0),  # 1st value offset preparation
                NextState("Loader_LOAD2")
            ).Elif(~self.bEnable.storage,  # Enable withdrawn?
                NextState("Loader_IDLE")  # Abort!
            )
        )
        #-----> LOOP ENTRY ! (2nd loop onward: fs3 already prepared!)
        Loader_fsm.act("Loader_LOAD2",
            NextValue(self.b16Status.storage[1], True),  # Current status added
            If(self.Loader_Delay > RAMWaitTime,  # Required only for 1st entry (counter is not reset in the loop)
                # FPU#1: low half-word -> fs1, high half-word -> fs2
                NextValue(fpu1.fs1, bf16_to_fs(LoadUnit.b32Data1.storage[0:16])),
                NextValue(fpu1.fs2, bf16_to_fs(LoadUnit.b32Data1.storage[16:32])),
                NextValue(LoadUnit.b9Offset1.storage, LoadUnit.b9Offset1.storage + 1),  # Move on to next entry
                # FPU#2: low half-word -> fs1, high half-word -> fs2
                NextValue(fpu2.fs1, bf16_to_fs(LoadUnit.b32Data2.storage[0:16])),
                NextValue(fpu2.fs2, bf16_to_fs(LoadUnit.b32Data2.storage[16:32])),
                NextValue(LoadUnit.b9Offset2.storage, LoadUnit.b9Offset2.storage + 1),  # Move on to next entry
                NextState("Loader_EXEC1")
            ).Else(  # MEM wait cycles
                NextValue(self.Loader_Delay, self.Loader_Delay + 1),  # Increment
            )
        )
        Loader_fsm.act("Loader_EXEC1",
            NextValue(self.b16Status.storage[2], True),  # Current status added
            If(LoadUnit.b9Offset1.storage == 1,  # Entry #0 (pointer already moved ahead 1!)
                NextValue(fpu1.fmul, True),  # 1st entry: plain MUL requested
                NextValue(fpu2.fmul, True),
            ).Else(
                NextValue(fpu1.fmadd, True),  # 2nd ... last entry: MUL/ADD requested
                NextValue(fpu2.fmadd, True),
            ),
            NextValue(fpu1.fready, False),  # Engage trigger FPU#1
            NextValue(fpu2.fready, False),  # Engage trigger FPU#2
            NextState("Loader_EXEC2")
        )
        Loader_fsm.act("Loader_EXEC2",
            NextValue(self.b16Status.storage[3], True),  # Current status added
            NextValue(self.b16Status.storage[8], fpu1.fready),  # TODO: Remove!
            NextValue(self.b16Status.storage[9], fpu2.fready),  # TODO: Remove!
            If(fpu1.fready & fpu2.fready,
                If(LoadUnit.b9Offset1.storage == 1,  # Entry #0 (pointer already moved ahead 1!)
                    NextValue(fpu1.fmul, False),  # Clear command request FPU#1
                    NextValue(fpu2.fmul, False),  # Clear command request FPU#2
                ).Else(  # Entries 1 .. (maxlen-1)
                    NextValue(fpu1.fmadd, False),  # Clear command request FPU#1
                    NextValue(fpu2.fmadd, False),  # Clear command request FPU#2
                ),
                # Running sums (bfloat16 upper half only) feed the next fmadd
                NextValue(fpu1.fs3, bf16_to_fs(fpu1.fresult[16:32])),
                NextValue(fpu2.fs3, bf16_to_fs(fpu2.fresult[16:32])),
                If(LoadUnit.b9Offset1.storage < self.b9ArrayWordLen.storage,  # More entries to process?
                    NextState("Loader_LOAD2")
                ).Else(  # Finally prepare ADD both result sums (on FPU#1 only!)
                    # NOTE(review): unlike fs3 above, the full 32-bit results are used here
                    # (low half-word not cleared) - confirm this is intended.
                    NextValue(fpu1.fs1, fpu1.fresult),
                    NextValue(fpu1.fs2, fpu2.fresult),
                    NextState("Loader_EXEC3")
                )
            )
        )
        Loader_fsm.act("Loader_EXEC3",
            NextValue(self.b16Status.storage[4], True),  # Current status added
            NextValue(fpu1.fadd, True),     # Final ADD requested
            NextValue(fpu1.fready, False),  # Engage trigger FPU#1 (only!)
            NextState("Loader_EXEC4")
        )
        Loader_fsm.act("Loader_EXEC4",
            NextValue(self.b16Status.storage[5], True),  # Current status added
            If(fpu1.fready,
                NextValue(fpu1.fadd, False),  # Clear command request FPU#1
                NextValue(self.b16Result1.storage, fpu1.fresult[16:32]),  # Pick result (little endian, high word!)
                NextValue(self.b16Result2.storage, fpu2.fresult[16:32]),  # Useless (control only ...)
                NextValue(self.b16Status.storage[15], True),  # Indicate readiness ...
                NextValue(self.bReady, True),  # Indicate readiness (LED on!) (TODO: Remove!)
                NextState("Loader_IDLE")  # Stays idle until bEnable is withdrawn (Loader_Active still set)
            )
        )

        self.sync += [  # Show individual FPU states
            self.b16FPUStates.storage[0:8].eq(fpu1.FPU_state[0:8]),
            self.b16FPUStates.storage[8:16].eq(fpu2.FPU_state[0:8]),
        ]