#!/usr/bin/env python3
#
# bfloat16processor.py
#
# bfloat16 processing (1 bit sign, 8 bit exponent, 7 bit mantissa)
#
# History:
# --------
# 22.04.21/KQ Initial version
#

from migen import *
from migen.fhdl.specials import Memory

from litex.soc.interconnect.csr import *
from litex.soc.integration.doc import AutoDoc, ModuleDoc


# Bit pattern written to fresult for every invalid operation (NaN result)
_NAN = 0x7FFFFFFF


def _is_nan(f):
    """Return a 1-bit expression: exponent all ones and non-zero fraction bits."""
    return (f[23:31] == 0xFF) & (f[0:23] != 0)


def _is_inf(f):
    """Return a 1-bit expression: exponent all ones and zero fraction bits."""
    return (f[23:31] == 0xFF) & (f[0:23] == 0)


def _is_zero(f):
    """Return a 1-bit expression: +0 or -0 (every bit except the sign is zero)."""
    return f[0:31] == 0


def _signed_inf(sign):
    """Return the 32-bit infinity pattern carrying `sign`."""
    return Mux(sign, 0xFF800000, 0x7F800000)


def _signed_zero(sign):
    """Return the 32-bit zero pattern carrying `sign`."""
    return Mux(sign, 0x80000000, 0x00000000)


def _pack(mantissa7, exponent8, sign):
    """Assemble a 32-bit register image of a bfloat16 value (bits 0..15 cleared).

    `exponent8` must be exactly 8 bits wide; a wider expression would shift the
    sign out of the 32-bit result.
    """
    return Cat(Replicate(0, 16), mantissa7, exponent8, sign)


class bfloat16Processor(Module):
    """bfloat16 FPU logic, implemented as one sequential FSM.

    Operands ``fs1``/``fs2``/``fs3`` are 32-bit register images: sign in bit 31,
    biased exponent in bits 23..30, the 7 bfloat16 mantissa bits in bits 16..22.
    Results are written to ``fresult`` in the same layout with bits 0..15
    cleared (fsqrt excepted: it maps its 32-bit bit-manipulation result
    straight through).

    Handshake: in FPU_IDLE an operation starts when one trigger (fadd, fsub,
    fmul, ...) is high and ``fready`` is low. The FSM sets ``fready`` once
    ``fresult`` is valid; the triggers and ``fready`` have to be reset
    externally.

    Rounding: fadd/fsub/fmul round to nearest, ties to even; fdiv truncates;
    fsqrt is an approximation. The fused variants truncate the product to
    bfloat16 and round only in the final addition.

    NOTE(review): as implemented, fnmadd computes -(fs1*fs2)+fs3 and fnmsub
    computes -(fs1*fs2)-fs3 (the reverse of the RISC-V F-extension naming) -
    confirm the caller's trigger mapping.

    Limitations (TODO): subnormal inputs are treated like normal numbers, and
    exponent overflow/underflow is not saturated to infinity/zero (the 8-bit
    signed exponent registers wrap).
    """

    def __init__(self):
        # Inputs
        self.fs1 = Signal(32, reset_less=True)  # Float register #1
        self.fs2 = Signal(32, reset_less=True)  # Float register #2
        self.fs3 = Signal(32, reset_less=True)  # Float register #3 (addend of the fused ops)

        # Output
        self.fresult = Signal(32, reset_less=True)  # Float result

        # F-Extension: Job triggers
        self.fadd = Signal()
        self.fsub = Signal()
        self.fmul = Signal()
        self.fdiv = Signal()
        self.fsqrt = Signal()
        self.fmadd = Signal()
        self.fmsub = Signal()
        self.fnmadd = Signal()
        self.fnmsub = Signal()
        self.fmin = Signal()
        self.fmax = Signal()
        self.fready = Signal()  # Indicate ready (set here, reset externally)

        # Calculation support variables
        self.sign1 = Signal()  # Sign of floats
        self.sign2 = Signal()
        self.sign3 = Signal()
        self.e1 = Signal((8, True), reset_less=True)  # Unbiased, signed exponents
        self.e2 = Signal((8, True), reset_less=True)
        self.e3 = Signal((8, True), reset_less=True)
        # Unsigned significands. Adder layout: [2:0] G/R/S, [9:3] mantissa,
        # [10] hidden 1, [11] carry. Multiplier/divider layout: [6:0] mantissa, [7] hidden 1.
        self.m1 = Signal((7+2+3, False), reset_less=True)
        self.m2 = Signal((7+2+3, False), reset_less=True)
        self.m3 = Signal((8+2+3, True), reset_less=True)  # Signed adder result / result mantissa
        self.lm3 = Signal((64, True), reset_less=True)  # MUL long result
        self.s32 = Signal((32, True), reset_less=True)  # Signed 32-bit (fsqrt bit manipulation)
        self.s_bit = Signal()  # Sticky bit: OR of all bits shifted out during alignment
        self.branch1 = Signal()  # Branch helpers
        self.branch2 = Signal()
        self.i = Signal(4)  # Loop counter, range 0..15

        FPU_fsm = FSM(reset_state="FPU_IDLE")  # FSM starts idling ...
        self.submodules += FPU_fsm
        self.FPU_state = Signal(9, reset_less=True)  # Debugging support

        # Register the FSM states, one operation family per helper
        self._build_idle(FPU_fsm)
        self._build_fadd(FPU_fsm)
        self._build_result(FPU_fsm)
        self._build_fmul(FPU_fsm)
        self._build_fdiv(FPU_fsm)
        self._build_fsqrt(FPU_fsm)
        self._build_fminmax(FPU_fsm)

    def _complete(self, value):
        """Statements that publish `value`, raise fready and return to FPU_IDLE."""
        return [
            NextValue(self.fresult, value),
            NextValue(self.fready, 1),  # Indicate ready to main decoder
            NextState("FPU_IDLE"),
        ]

    def _build_idle(self, fsm):
        """FPU_IDLE: latch sign/exponent/significand of the operands and dispatch."""
        fsm.act("FPU_IDLE",
            NextValue(self.FPU_state, 0),
            If((self.fadd | self.fsub) & ~self.fready,  # Triggers set & ready flag reset externally!
                NextValue(self.sign1, self.fs1[31]),
                NextValue(self.sign2, self.fs2[31] ^ self.fsub),  # Invert sign for subtraction!
                NextValue(self.e1, self.fs1[23:31] - 127),
                NextValue(self.e2, self.fs2[23:31] - 127),
                NextValue(self.m1, Cat(0, 0, 0, self.fs1[16:23], 1, 0)),  # Hidden 1 + G/R/S bits
                NextValue(self.m2, Cat(0, 0, 0, self.fs2[16:23], 1, 0)),
                NextState("FADD1")
            ).Elif((self.fmin | self.fmax | self.fmadd | self.fmsub | self.fnmadd | self.fnmsub
                    | self.fmul | self.fdiv) & ~self.fready,  # Triggers set & ready flag reset externally!
                NextValue(self.sign1, self.fs1[31]),
                NextValue(self.sign2, self.fs2[31]),
                NextValue(self.e1, self.fs1[23:31] - 127),
                NextValue(self.e2, self.fs2[23:31] - 127),
                NextValue(self.m1, Cat(self.fs1[16:23], 1, 0, 0, 0, 0)),  # Hidden 1 at bit 7
                NextValue(self.m2, Cat(self.fs2[16:23], 1, 0, 0, 0, 0)),
                If(self.fdiv,  # Division
                    NextState("FDIV1"),
                ).Elif(self.fmin,  # Minimum
                    NextState("FMIN1"),
                ).Elif(self.fmax,  # Maximum
                    NextState("FMAX1")
                ).Else(  # Multiplication variants (incl. fused multiply-add)
                    NextState("FMUL1"),
                )
            ).Elif(self.fsqrt & ~self.fready,  # Trigger set & ready flag reset externally!
                NextValue(self.sign1, self.fs1[31]),
                NextValue(self.e1, self.fs1[23:31] - 127),
                NextValue(self.m1, Cat(self.fs1[16:23], 1, 0, 0, 0, 0)),
                NextState("FSQRT1"),
            )
        )

    def _build_fadd(self, fsm):
        """FADD1..FADD8: add/subtract (also the final step of the fused ops)."""
        fsm.act("FADD1",
            NextValue(self.FPU_state, 1),
            # 1. Special operands first. sign2 already includes the subtraction
            #    (or fused-op) negation, so it is used instead of fs2[31].
            If(_is_nan(self.fs1) | _is_nan(self.fs2)
               | (_is_inf(self.fs1) & _is_inf(self.fs2) & (self.sign1 ^ self.sign2)),
                *self._complete(_NAN)  # NaN input or inf - inf
            ).Elif(_is_inf(self.fs1),
                *self._complete(Cat(self.fs1[0:31], self.sign1))
            ).Elif(_is_inf(self.fs2),
                *self._complete(Cat(self.fs2[0:31], self.sign2))
            ).Elif(_is_zero(self.fs1),  # 0 +/- fs2 (fused: zero product +/- fs3)
                *self._complete(Cat(self.fs2[0:31], self.sign2))
            ).Elif(_is_zero(self.fs2),  # Nothing to add (fused: just the product)
                *self._complete(Cat(self.fs1[0:31], self.sign1))
            ).Else(  # Ok, valid floats supplied ...
                NextValue(self.s_bit, 0),
                NextValue(self.branch1, 0),  # Reset helpers
                NextValue(self.branch2, 0),
                NextState("FADD2")
            )
        )
        fsm.act("FADD2",
            # 2. Align exponents: shift the operand with the smaller exponent right
            #    one bit per cycle (up to ~254 cycles for very different exponents),
            #    ORing every bit shifted out into the sticky bit.
            If(self.e1 < self.e2,
                NextValue(self.FPU_state, 21),
                If(self.m1[0], NextValue(self.s_bit, 1)),
                NextValue(self.m1, self.m1 >> 1),
                NextValue(self.e1, self.e1 + 1),
                NextValue(self.branch1, 1),
            ).Elif(self.e1 > self.e2,
                NextValue(self.FPU_state, 22),
                If(self.m2[0], NextValue(self.s_bit, 1)),
                NextValue(self.m2, self.m2 >> 1),
                NextValue(self.e2, self.e2 + 1),
                NextValue(self.branch2, 1),
            ).Else(
                NextValue(self.FPU_state, 23),
                If(self.branch1, NextValue(self.m1, self.m1 | self.s_bit)),  # Fold sticky into bit 0
                If(self.branch2, NextValue(self.m2, self.m2 | self.s_bit)),
                NextState("FADD3")
            )
        )
        fsm.act("FADD3",
            NextValue(self.FPU_state, 3),
            # 3. Signed add of the aligned significands
            If(~self.sign1 & ~self.sign2,
                NextValue(self.m3, self.m1 + self.m2)
            ).Elif(self.sign1 & ~self.sign2,
                NextValue(self.m3, self.m2 - self.m1)
            ).Elif(~self.sign1 & self.sign2,
                NextValue(self.m3, self.m1 - self.m2)
            ).Else(
                NextValue(self.m3, -(self.m1 + self.m2))
            ),
            NextState("FADD4")
        )
        fsm.act("FADD4",
            NextValue(self.FPU_state, 4),
            # 4. Retrieve sign & unsigned absolute value
            If(self.m3 < 0,
                NextValue(self.sign3, 1),
                NextValue(self.m3, -self.m3)
            ).Else(
                NextValue(self.sign3, 0),
            ),
            NextValue(self.e3, self.e1),  # e1 == e2 after alignment
            NextState("FADD5")
        )
        fsm.act("FADD5",
            NextValue(self.FPU_state, 5),
            If(self.m3 == 0,  # Exact cancellation (x - x): result is +0
                *self._complete(0)
            ).Else(
                # 5. Carry past the hidden bit (bit 10)? Shift right once and keep
                #    the shifted-out bit as sticky in bit 0.
                If(self.m3[11],
                    NextValue(self.m3, (self.m3 >> 1) | self.m3[0]),
                    NextValue(self.e3, self.e3 + 1),
                ),
                NextValue(self.i, 0),  # Reset for normalization restraining
                NextState("FADD6")
            )
        )
        fsm.act("FADD6",
            NextValue(self.FPU_state, 6),
            # 6. Left-normalize (after subtraction) until the hidden bit 10 is set.
            #    G/R/S bits are still attached, so no precision is shifted away.
            If(~self.m3[10] & (self.i < 10),
                NextValue(self.m3, self.m3 << 1),
                NextValue(self.e3, self.e3 - 1),
                NextValue(self.i, self.i + 1),
            ).Else(
                NextState("FADD7")
            )
        )
        fsm.act("FADD7",
            NextValue(self.FPU_state, 7),
            # 7. Round to nearest, ties to even: bit 2 = guard, bit 1 = round,
            #    bit 0 = sticky, bit 3 = LSB of the 7-bit mantissa.
            NextValue(self.m3, (self.m3 >> 3)
                      + (self.m3[2] & (self.m3[1] | self.m3[0] | self.m3[3]))),
            NextState("FADD8")
        )
        fsm.act("FADD8",
            NextValue(self.FPU_state, 8),
            # 8. Rounding may carry past the hidden bit (now bit 7)
            If(self.m3[8],
                NextValue(self.m3, self.m3 >> 1),
                NextValue(self.e3, self.e3 + 1)
            ),
            NextState("FRESULT")
        )
        # End of fadd.s processing

    def _build_result(self, fsm):
        """FRESULT: pack sign3/e3/m3 into fresult and signal ready."""
        fsm.act("FRESULT",
            NextValue(self.FPU_state, 9),
            # e3 + 127 is 9 bits wide: slice it so the sign stays in bit 31
            *self._complete(_pack(self.m3[0:7], (self.e3 + 127)[0:8], self.sign3))
        )

    def _build_fmul(self, fsm):
        """FMUL1..FMUL5 and FMADD1: multiply, optionally followed by the fused add."""
        # Product sign as handed to the adder (fnmadd/fnmsub negate the product)
        fused_sign = self.sign1 ^ self.sign2 ^ (self.fnmadd | self.fnmsub)

        fsm.act("FMUL1",
            NextValue(self.FPU_state, 1),
            # 0. Special operands first
            If(_is_nan(self.fs1) | _is_nan(self.fs2)
               | (_is_inf(self.fs1) & _is_zero(self.fs2))
               | (_is_zero(self.fs1) & _is_inf(self.fs2)),
                *self._complete(_NAN)  # NaN input or inf * 0
            ).Elif(_is_inf(self.fs1) | _is_inf(self.fs2),
                If(self.fmul,
                    *self._complete(_signed_inf(self.sign1 ^ self.sign2))
                ).Else(  # Fused: hand an infinite product to the adder
                    NextValue(self.sign3, fused_sign),
                    NextValue(self.e3, -128),  # (e3 + 127)[0:8] == 0xFF in FMADD1
                    NextValue(self.m3, 0),
                    NextState("FMADD1")
                )
            ).Elif(_is_zero(self.fs1) | _is_zero(self.fs2),
                If(self.fmul,
                    *self._complete(_signed_zero(self.sign1 ^ self.sign2))
                ).Else(  # Fused: a zero product still has to add fs3
                    NextValue(self.sign3, fused_sign),
                    NextValue(self.e3, -127),  # (e3 + 127)[0:8] == 0 in FMADD1
                    NextValue(self.m3, 0),
                    NextState("FMADD1")
                )
            ).Else(  # Ok, valid floats supplied ...
                NextValue(self.sign3, self.sign1 ^ self.sign2),  # 1. Result sign
                NextValue(self.e3, self.e1 + self.e2),  # 2. Resulting exponent (add!)
                NextValue(self.lm3, self.m1 * self.m2),  # 3. Significand product
                NextState("FMUL2")
            )
        )
        fsm.act("FMUL2",
            NextValue(self.FPU_state, 2),
            # 4. Two 8-bit significands (1.0 == 0x80) give a product in
            #    [2^14, 2^16); 1.0 corresponds to bit 14. Bit 15 set => renormalize.
            If(self.lm3[15],
                NextValue(self.lm3, (self.lm3 >> 1) | self.lm3[0]),  # Keep shifted-out bit as sticky
                NextValue(self.e3, self.e3 + 1),
            ),
            If(self.fmul,  # Regular multiplication
                NextState("FMUL3")  # Do the rounding!
            ).Else(  # Fused multiply/add? W/O rounding!
                NextState("FMUL5")
            )
        )
        fsm.act("FMUL3",
            # 5. Round to nearest, ties to even: LSB = bit 7, round bit = bit 6,
            #    sticky = bits 0..5
            If(self.lm3[6] & (self.lm3[7] | (self.lm3[0:6] != 0)),
                NextValue(self.lm3, (self.lm3 & 0x7F80) + 0x80),  # Drop remainder, add 1 ULP
                NextState("FMUL4")
            ).Else(
                NextState("FMUL5")
            )
        )
        fsm.act("FMUL4",
            # Rounding may carry into bit 15: renormalize once more
            If(self.lm3[15],
                NextValue(self.lm3, self.lm3 >> 1),
                NextValue(self.e3, self.e3 + 1)
            ),
            NextState("FMUL5")
        )
        fsm.act("FMUL5",
            # 6. Extract the 7 mantissa bits below the hidden bit (bit 14)
            NextValue(self.m3, self.lm3[7:14]),
            If(self.fmul,  # Simple multiplication
                NextState("FRESULT")
            ).Else(  # Fused multiply-add: negate product for fnmadd/fnmsub
                NextValue(self.sign3, self.sign3 ^ (self.fnmadd | self.fnmsub)),
                NextState("FMADD1")
            )
        )
        # End of fmul.s processing

        fsm.act("FMADD1",
            # Product (sign3/e3/m3) becomes operand 1 (sign1/e1/m1 & fs1),
            # fs3 becomes operand 2 (sign2/e2/m2 & fs2); then reuse the adder.
            NextValue(self.sign1, self.sign3),
            NextValue(self.sign2, self.fs3[31] ^ (self.fmsub | self.fnmsub)),  # Invert sign for subtraction!
            NextValue(self.e1, self.e3),
            NextValue(self.e2, self.fs3[23:31] - 127),
            NextValue(self.m1, Cat(0, 0, 0, self.m3[0:7], 1, 0)),  # Hidden 1 + G/R/S bits
            NextValue(self.m2, Cat(0, 0, 0, self.fs3[16:23], 1, 0)),
            NextValue(self.fs1, _pack(self.m3[0:7], (self.e3 + 127)[0:8], self.sign3)),
            NextValue(self.fs2, self.fs3),
            NextState("FADD1")  # Add fs1 & fs2!
        )

    def _build_fdiv(self, fsm):
        """FDIV1..FDIV2: restoring division of the significands (truncating)."""
        fsm.act("FDIV1",
            NextValue(self.FPU_state, 1),
            # 0. Special operands first
            If(_is_nan(self.fs1) | _is_nan(self.fs2)
               | (_is_zero(self.fs1) & _is_zero(self.fs2))
               | (_is_inf(self.fs1) & _is_inf(self.fs2)),
                *self._complete(_NAN)  # NaN input, 0/0 or inf/inf
            ).Elif(_is_inf(self.fs1) | _is_zero(self.fs2),
                *self._complete(_signed_inf(self.sign1 ^ self.sign2))  # inf/x, x/0
            ).Elif(_is_zero(self.fs1) | _is_inf(self.fs2),
                *self._complete(_signed_zero(self.sign1 ^ self.sign2))  # 0/x, x/inf
            ).Else(  # Ok, valid floats supplied ...
                NextValue(self.sign3, self.sign1 ^ self.sign2),  # 1. Result sign
                # 2. Exponent (subtract!). Pre-normalize the dividend so that
                #    m1 >= m2: the quotient is then in [1, 2) and all 8 quotient
                #    bits computed in FDIV2 are significant.
                If(self.m1 < self.m2,
                    NextValue(self.m1, self.m1 << 1),
                    NextValue(self.e3, self.e1 - self.e2 - 1),
                ).Else(
                    NextValue(self.e3, self.e1 - self.e2),
                ),
                NextValue(self.m3, 0),  # 3. Quotient accumulator
                NextValue(self.i, 0),  # Loop counter
                NextState("FDIV2")
            )
        )
        fsm.act("FDIV2",
            If(self.i < 8,  # One quotient bit per cycle
                NextValue(self.FPU_state, 2),
                If(self.m1 < self.m2,
                    NextValue(self.m3, self.m3 << 1),  # Append a zero
                    NextValue(self.m1, self.m1 << 1),
                ).Else(  # Append a one
                    NextValue(self.m3, (self.m3 << 1) | 1),
                    NextValue(self.m1, (self.m1 - self.m2) << 1),
                ),
                NextValue(self.i, self.i + 1)
            ).Else(  # Loop exceeded
                NextValue(self.FPU_state, 3),
                # 4. Normalization (safety net: pre-normalization already sets bit 7)
                If(~self.m3[7],
                    NextValue(self.m3, self.m3 << 1),
                    NextValue(self.e3, self.e3 - 1),
                ).Else(
                    NextState("FRESULT")
                )
            )
        )
        # End of fdiv.s processing

    def _build_fsqrt(self, fsm):
        """FSQRT1..FSQRT3: square root approximated on the 32-bit IEEE bit pattern."""
        fsm.act("FSQRT1",
            NextValue(self.FPU_state, 1),
            # 1. Special operands first
            If(_is_zero(self.fs1),
                *self._complete(self.fs1)  # sqrt(+/-0) = +/-0
            ).Elif(_is_nan(self.fs1) | self.sign1,
                *self._complete(_NAN)  # NaN input or negative operand
            ).Elif(_is_inf(self.fs1),
                *self._complete(self.fs1)  # sqrt(+inf) = +inf
            ).Else(
                # Better fast, than accurate! Use Newton-Raphson in S/W for better accuracy!
                If((self.m1[0:7] != 0) | (self.e1 == 1),  # Not 2^x (m==0!) and x!=1
                    NextValue(self.branch1, 1),  # Use 0x0004B0D2 correction (max. error ~3.5%)
                ).Else(
                    NextValue(self.branch1, 0),  # No correction: 2^x exact, others up to ~6% error
                ),
                NextValue(self.s32, self.fs1),  # Pick up float value for manipulation
                NextState("FSQRT2")
            )
        )
        fsm.act("FSQRT2",
            NextValue(self.FPU_state, 2),
            # Subtract 2^23 (one exponent step), halve the bit pattern (halves
            # the exponent), then add back 2^29 (re-bias), optionally minus the
            # error-minimizing term.
            If(self.branch1,
                NextValue(self.s32, ((self.s32 - 0x00800000) >> 1) + (0x20000000 - 0x0004B0D2)),
            ).Else(
                NextValue(self.s32, ((self.s32 - 0x00800000) >> 1) + 0x20000000),
            ),
            NextState("FSQRT3")
        )
        fsm.act("FSQRT3",
            NextValue(self.FPU_state, 3),
            *self._complete(self.s32)  # Just map value straight ...
        )
        # End of fsqrt.s processing

    def _build_fminmax(self, fsm):
        """FMIN1/FMAX1: select the numerically smaller/larger operand."""
        # fs1 < fs2: with differing signs fs1 is smaller iff it is negative;
        # with equal signs the 31-bit magnitudes order like unsigned integers
        # (IEEE format), reversed for negative numbers.
        mag1 = self.fs1[0:31]
        mag2 = self.fs2[0:31]
        fs1_less = Mux(self.sign1 ^ self.sign2,
                       self.sign1,
                       Mux(self.sign1, mag1 > mag2, mag1 < mag2))
        nan1 = _is_nan(self.fs1)
        nan2 = _is_nan(self.fs2)

        fsm.act("FMIN1",
            If(nan1 & nan2,
                NextValue(self.fresult, _NAN)
            ).Elif(nan1,  # A single NaN operand is ignored
                NextValue(self.fresult, self.fs2)
            ).Elif(nan2,
                NextValue(self.fresult, self.fs1)
            ).Else(
                NextValue(self.fresult, Mux(fs1_less, self.fs1, self.fs2))
            ),
            NextValue(self.fready, 1),  # Indicate ready to main decoder
            NextState("FPU_IDLE")
        )
        # End of fmin.s processing

        fsm.act("FMAX1",
            If(nan1 & nan2,
                NextValue(self.fresult, _NAN)
            ).Elif(nan1,  # A single NaN operand is ignored
                NextValue(self.fresult, self.fs2)
            ).Elif(nan2,
                NextValue(self.fresult, self.fs1)
            ).Else(
                NextValue(self.fresult, Mux(fs1_less, self.fs2, self.fs1))
            ),
            NextValue(self.fready, 1),  # Indicate ready to main decoder
            NextState("FPU_IDLE")
        )
        # End of fmax.s processing