Half-precision floats handling
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
bfloat16nn/libmodules/bfloat16nncore.py

261 lines
13 KiB

#!/usr/bin/env python3
#
# bfloat16nncore.py
#
# bfloat16 neural network processing core logic
#
# History:
# --------
# 21.04.21/KQ Initial version
#
from migen import *
from migen.fhdl.specials import Memory
from litex.soc.interconnect.csr import AutoCSR, CSRStatus, CSRStorage, CSRField, CSRAccess
from litex.soc.integration.doc import AutoDoc, ModuleDoc
from litex.soc.interconnect.csr import *
from libmodules.dramtransfer import DRAM2FPGA, FPGA2DRAM
from libmodules.bfloat16processor import bfloat16Processor
class bfloat16NeuralNetworkCore(Module, AutoCSR, AutoDoc, ModuleDoc):
    """
    bfloat16NN core class:

    Usage:
    ######

    1. Freeze operations by setting ``bEnable`` to false (0).
    2. Load ``b32DRAMLoadAddress`` with a 32-bit DRAM memory pointer and ``b32Sentinel`` with the control word.
    3. Finally, enable processing by setting ``bEnable`` to true (1).
    4. bfloat16NN processing now runs asynchronously to the CPU until reset or until ``bEnable`` is cleared (see step 1).
       From then on the CPU only writes to DRAM (the L2 cache still has to be flushed);
       all data is picked up by the FPGA automatically ...

    Inputs:
    #######

    :b32DRAMLoadAddress: DRAM address from where to load into local memory
    :b32Sentinel: Control word expected in the last cache entry (load-unit cache offset ``LUCacheSize - 1``)
    :bEnable: To enable running (after data preparation)

    Outputs:
    ########

    :b16Status: Processing status bits (bits 0..6 = loader stages reached, bit 15 = results ready)
    :b16Value1..6: bfloat16 operands picked up from the load-unit cache
    :b16Result1: Result of FPU #1 (``fadd``)
    :b16Result2: Result of FPU #2 (``fnmsub``)
    :bReady: Ready indication (wire to LED ... ;)
    """
    def __init__(self, RAMWaitTime=128, LUCacheSize=8, LoadUnit=None, StoreUnit=None):
        """
        Build the CSRs, the load-unit FSM, two bfloat16 FPUs and the loader FSM.

        :param RAMWaitTime: After changing ``LoadUnit.b9Offset`` the loader waits until its delay
                            counter exceeds this value before sampling ``LoadUnit.b32Data``.
        :param LUCacheSize: Size of the load-unit cache; the sentinel is read at offset ``LUCacheSize - 1``.
        :param LoadUnit: Load unit whose ``b32Address``, ``bEnable``, ``bValid``, ``b9Offset`` and
                         ``b32Data`` members expose a ``.storage`` signal (required).
        :param StoreUnit: Not used by this class (reserved).
        :raises ValueError: If ``LoadUnit`` is not given.
        """
        if LoadUnit is None:
            # Fail early and clearly instead of an AttributeError on None deep inside the FSM setup
            raise ValueError("bfloat16NeuralNetworkCore: LoadUnit must be provided")

        # Inputs
        self.b32DRAMLoadAddress = CSRStorage(32, reset_less=False,
            fields=[CSRField("LoadAddress", size=32, description="*Field*: 32-Bit value")],
            description="""
            Load value (32-bit DRAM address).
            """)
        self.b32Sentinel = CSRStorage(32, reset_less=False,
            fields=[CSRField("Sentinel", size=32, description="*Field*: 32-Bit value")],
            description="""
            Control value
            """)
        self.bEnable = CSRStorage(1, reset_less=False,
            fields=[CSRField("Enable", size=1, description="*Field*: bit", values=[
                        ("0", "DISABLED", "bfloat16NN not active"),
                        ("1", "ENABLED", "bfloat16NN active"),
                    ])
            ],
            description="""
            Enable free run
            """)

        # Outputs
        self.b16Status = CSRStorage(16, reset_less=False,
            fields=[CSRField("Status", size=16, description="*Field*: 16-Bit value")],
            description="""
            Processing stati
            """)
        self.b16Value1 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 1
            """)
        self.b16Value2 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 2
            """)
        self.b16Value3 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 3
            """)
        self.b16Value4 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 4
            """)
        self.b16Value5 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 5
            """)
        self.b16Value6 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Value", size=16, description="*Field*: 16-Bit value")],
            description="""
            Float register 6
            """)
        self.b16Result1 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result1", size=16, description="*Field*: 16-Bit value")],
            description="""
            Processing result 1
            """)
        self.b16Result2 = CSRStorage(16, reset_less=False,
            fields=[CSRField("Result2", size=16, description="*Field*: 16-Bit value")],
            description="""
            Processing result 2
            """)
        self.bReady = Signal()  # To be wired to data pin ... ;)

        #---------------- Load unit (LU) -------------------------------------------------------------
        # Fills the load-unit cache from DRAM whenever the cache is invalid and running is enabled;
        # clearing bEnable invalidates the cache so the next enable triggers a fresh load.
        LU_fsm = FSM(reset_state="LU_IDLE")  # FSM starts idling ...
        self.submodules += LU_fsm
        self.LU_CacheOffset = Signal(9, reset_less=True)  # Cache reading offset 0..511 (only reset here, not read in this class)
        self.LU_CacheValid = Signal()  # Indicates a loaded LU cache
        self.LU_CacheDelay = Signal(11, reset_less=True)  # Load length in cycles (saturates at 2047)
        LU_fsm.act("LU_IDLE",  # If cache not valid fill it!
            If(~self.LU_CacheValid & self.bEnable.storage,  # Invalid cache & run enabled ...
                NextValue(LoadUnit.b32Address.storage, self.b32DRAMLoadAddress.storage),  # DRAM source address
                NextValue(self.LU_CacheOffset, 0),  # Reset local reader offset
                NextValue(self.LU_CacheDelay, 2),  # Reset load delay counter (incl. 1st & last cycle)
                NextState("LU_LOAD1")
            ).Elif(~self.bEnable.storage,  # Cleared enable?
                NextValue(self.LU_CacheValid, False),  # Enforce cache invalidation!
            )
        )
        LU_fsm.act("LU_LOAD1",  # Engage!
            NextValue(LoadUnit.bEnable.storage, 1),  # Trigger DRAM transfer to cache
            NextState("LU_LOAD2")
        )
        LU_fsm.act("LU_LOAD2",  # Wait for termination of transfer ...
            If(LoadUnit.bValid.storage,  # Data avail.?
                NextValue(self.LU_CacheValid, 1),  # Declare cache valid
                NextValue(LoadUnit.bEnable.storage, 0),  # Stop DRAM transfer to cache
                NextState("LU_IDLE")  # Yap!
            ).Else(
                If(self.LU_CacheDelay < 2047,  # MAX-1! (saturate instead of wrapping)
                    NextValue(self.LU_CacheDelay, self.LU_CacheDelay + 1),
                )
                # TODO: Permit timeout indication ...
            )
        )

        #---------------- bfloat16 FPUs -------------------------------------------------------------
        self.submodules.fpu1 = fpu1 = bfloat16Processor()  # Integrate bfloat16 FPU
        self.submodules.fpu2 = fpu2 = bfloat16Processor()  # Integrate another one!

        def bf16_to_fp32(bf16):
            # bfloat16 is the upper half of a float32: zero-fill the low 16 bits (Cat is LSB first)
            return Cat(Constant(0, 16), bf16)

        #---------------- Loaded data testing --------------------------------------------------
        # Once the cache is valid: check the sentinel, pick six bfloat16 operands (two per cache word),
        # run fadd on FPU #1 and fnmsub on FPU #2, and publish the results.
        Loader_fsm = FSM(reset_state="Loader_IDLE")  # FSM starts idling ...
        self.submodules += Loader_fsm
        self.Loader_Delay = Signal(32, reset_less=True)  # RAM read wait counter
        self.Loader_Active = Signal()  # Set for one run; cleared only when bEnable is withdrawn
        Loader_fsm.act("Loader_IDLE",
            If(self.LU_CacheValid & ~self.Loader_Active,  # Enter if not active already
                NextValue(self.Loader_Active, True),  # Loader up & running
                NextValue(self.Loader_Delay, 0),  # Reset read delay timer
                NextValue(LoadUnit.b9Offset.storage, LUCacheSize - 1),  # Adjust offset to read sentinel
                NextValue(self.b16Result1.storage, 0),  # No result yet
                NextValue(self.b16Result2.storage, 0),  # No result yet
                NextValue(self.b16Status.storage[0], True),  # Current status
                NextValue(self.b16Value1.storage, 0),  # Nothing loaded so far ...
                NextValue(self.b16Value2.storage, 0),  # Nothing loaded so far ...
                NextValue(self.b16Value3.storage, 0),  # Nothing loaded so far ...
                NextValue(self.b16Value4.storage, 0),  # Nothing loaded so far ... (was left stale)
                NextValue(self.b16Value5.storage, 0),  # Nothing loaded so far ... (was left stale)
                NextValue(self.b16Value6.storage, 0),  # Nothing loaded so far ... (was left stale)
                NextValue(self.bReady, False),  # LED off!
                NextState("Loader_LOAD1")
            ).Elif(~self.bEnable.storage,  # Externally aborted?
                NextValue(self.b16Status.storage, 0),  # Current status: inactive
                NextValue(self.Loader_Active, False),  # Reset in sync w/ global activation
            )
        )
        Loader_fsm.act("Loader_LOAD1",
            NextValue(self.b16Status.storage[1], True),  # Current status added
            # NOTE(review): b32Data is compared without a RAMWaitTime delay after the offset change;
            # the loop retries until it matches, but stale data equal to the sentinel would pass early.
            If(LoadUnit.b32Data.storage == self.b32Sentinel.storage,  # Valid last entry?
                NextValue(LoadUnit.b9Offset.storage, 0),  # 1st value offset preparation
                NextState("Loader_LOAD2")
            ).Elif(~self.bEnable.storage,  # Enable withdrawn?
                NextState("Loader_IDLE")  # Abort!
            )
        )
        Loader_fsm.act("Loader_LOAD2",
            NextValue(self.b16Status.storage[2], True),  # Current status added
            If(self.Loader_Delay > RAMWaitTime,
                NextValue(self.b16Value1.storage, LoadUnit.b32Data.storage & 0xFFFF),  # Pick 1st value
                NextValue(self.b16Value2.storage, LoadUnit.b32Data.storage >> 16),  # Pick 2nd value
                NextValue(fpu1.fs1, bf16_to_fp32(LoadUnit.b32Data.storage[0:16])),
                NextValue(fpu1.fs2, bf16_to_fp32(LoadUnit.b32Data.storage[16:32])),
                NextValue(LoadUnit.b9Offset.storage, 1),  # 2nd value offset preparation
                NextValue(self.Loader_Delay, 0),  # Reset delay
                NextState("Loader_LOAD3")
            ).Else(  # MEM wait cycles
                NextValue(self.Loader_Delay, self.Loader_Delay + 1),  # Increment
            )
        )
        Loader_fsm.act("Loader_LOAD3",
            NextValue(self.b16Status.storage[3], True),  # Current status added
            If(self.Loader_Delay > RAMWaitTime,
                NextValue(self.b16Value3.storage, LoadUnit.b32Data.storage & 0xFFFF),  # Pick 3rd value
                NextValue(self.b16Value4.storage, LoadUnit.b32Data.storage >> 16),  # Pick 4th value
                NextValue(fpu1.fs3, bf16_to_fp32(LoadUnit.b32Data.storage[0:16])),
                NextValue(fpu2.fs1, bf16_to_fp32(LoadUnit.b32Data.storage[16:32])),
                NextValue(LoadUnit.b9Offset.storage, 2),  # 3rd value offset preparation
                NextValue(self.Loader_Delay, 0),  # Reset delay
                NextState("Loader_LOAD4")
            ).Else(  # MEM wait cycles
                NextValue(self.Loader_Delay, self.Loader_Delay + 1),  # Increment
            )
        )
        Loader_fsm.act("Loader_LOAD4",
            NextValue(self.b16Status.storage[4], True),  # Current status added
            If(self.Loader_Delay > RAMWaitTime,
                NextValue(self.b16Value5.storage, LoadUnit.b32Data.storage & 0xFFFF),  # Pick 5th value
                NextValue(self.b16Value6.storage, LoadUnit.b32Data.storage >> 16),  # Pick 6th value
                NextValue(fpu2.fs2, bf16_to_fp32(LoadUnit.b32Data.storage[0:16])),
                NextValue(fpu2.fs3, bf16_to_fp32(LoadUnit.b32Data.storage[16:32])),
                NextState("Loader_EXEC1")
            ).Else(  # MEM wait cycles
                NextValue(self.Loader_Delay, self.Loader_Delay + 1),  # Increment
            )
        )
        Loader_fsm.act("Loader_EXEC1",
            NextValue(self.b16Status.storage[5], True),  # Current status added
            NextValue(fpu1.fadd, True),  # This command requested
            NextValue(fpu2.fnmsub, True),  # This command requested
            # NOTE(review): fready is written here; if bfloat16Processor also drives it, this is a
            # second driver -- confirm against the FPU implementation.
            NextValue(fpu1.fready, False),  # Engage trigger FPU#1
            NextValue(fpu2.fready, False),  # Engage trigger FPU#2
            NextState("Loader_EXEC2")
        )
        Loader_fsm.act("Loader_EXEC2",
            NextValue(self.b16Status.storage[6], True),  # Current status added
            # NOTE(review): no abort path here -- if an FPU never raises fready, only a reset leaves this state.
            If(fpu1.fready & fpu2.fready,
                NextValue(fpu1.fadd, False),  # Clear command request FPU#1
                NextValue(fpu2.fnmsub, False),  # Clear command request FPU#2
                NextValue(self.b16Result1.storage, fpu1.fresult[16:32]),  # Pick result (little endian, high word!)
                NextValue(self.b16Result2.storage, fpu2.fresult[16:32]),  # Pick result (little endian, high word!)
                NextValue(self.b16Status.storage[15], True),  # Indicate readiness ...
                NextValue(self.bReady, True),  # Indicate readiness (LED on!) (TODO: Remove!)
                NextState("Loader_IDLE")
            )
        )