@ -3,7 +3,7 @@
#
# bfloat16processor.py
#
# bfloat16 processing
# bfloat16 processing (1 bit sign, 8 bit exponent, 7 bit mantissa)
#
# History:
# --------
@ -50,16 +50,20 @@ class bfloat16Processor(Module):
# Working registers of the FPU state machine: unbiased exponents, mantissas and helpers.
self . e1 = Signal ( ( 8 , True ) , reset_less = True ) # Signed exponents!
self . e2 = Signal ( ( 8 , True ) , reset_less = True )
self . e3 = Signal ( ( 8 , True ) , reset_less = True )
self . m1 = Signal ( ( 23 + 1 + 3 , False ) , reset_less = True ) # Unsigned mantissas! TODO: Verify sign!
self . m2 = Signal ( ( 24 + 1 + 3 , False ) , reset_less = True ) # 23 bits + 1bit (1.xx = 0x800000)
self . m3 = Signal ( ( 25 + 1 + 3 , True ) , reset_less = True ) # + Sign + R(0)/Guard & Sticky bits
#self.m1 = Signal((23+1+3,False), reset_less=True) # Unsigned mantissas! TODO: Verify sign!
self . m1 = Signal ( ( 7 + 2 + 3 , False ) , reset_less = True ) # 12-bit unsigned mantissa. TODO: Verify sign!
#self.m2 = Signal((24+1+3,False), reset_less=True) # 23 bits + 1bit (1.xx = 0x800000)
self . m2 = Signal ( ( 7 + 2 + 3 , False ) , reset_less = True ) # 12 bits: 7 fraction bits + 2 (hidden 1.xx bit = 0x80 plus a spare, presumably) + 3 R/G/S bits
#self.m3 = Signal((25+1+3,True), reset_less=True) # + Sign + R(0)/Guard & Sticky bits
self . m3 = Signal ( ( 8 + 2 + 3 , True ) , reset_less = True ) # 13-bit signed result mantissa: + Sign + R(0)/Guard & Sticky bits
# NOTE(review): lm3 and s32 keep their 64/32-bit widths; confirm they are still sized as intended for the 7-bit fraction.
self . lm3 = Signal ( ( 64 , True ) , reset_less = True ) # MUL long result
self . s32 = Signal ( ( 32 , True ) , reset_less = True ) # Signed 32-bit
self . s_bit = Signal ( ) # Sticky bit (for rounding control)
self . branch1 = Signal ( ) # Branch helpers
self . branch2 = Signal ( )
self . i = Signal ( 5 ) # Loop counter, range 0..31
#self.i = Signal(5) # Loop counter, range 0..31
self . i = Signal ( 4 ) # Loop counter, range 0..15
FPU_fsm = FSM ( reset_state = " FPU_IDLE " ) # FSM starts idling ...
self . submodules + = FPU_fsm
@ -72,16 +76,20 @@ class bfloat16Processor(Module):
NextValue ( self . sign2 , self . fs2 [ 31 ] ^ self . fsub ) , # Invert sign for subtraction!
NextValue ( self . e1 , self . fs1 [ 23 : 31 ] - 127 ) , # Remove the exponent bias
NextValue ( self . e2 , self . fs2 [ 23 : 31 ] - 127 ) ,
NextValue ( self . m1 , Cat ( 0 , 0 , 0 , self . fs1 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
NextValue ( self . m2 , Cat ( 0 , 0 , 0 , self . fs2 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
# NOTE(review): sign and exponent are read from bits 31 and 23..30, but the 7-bit
# fraction is read from bits 0..6 - confirm the operand layout this unit expects.
#NextValue(self.m1, Cat(0,0,0, self.fs1[0:23], 1, 0)), # | 0x00800000 + R/G/S bits
NextValue ( self . m1 , Cat ( 0 , 0 , 0 , self . fs1 [ 0 : 7 ] , 1 , 0 ) ) , # 3 R/G/S bits, 7 fraction bits, hidden 1 (| 0x80 before the shift), 1 spare
#NextValue(self.m2, Cat(0,0,0, self.fs2[0:23], 1, 0)), # | 0x00800000 + R/G/S bits
NextValue ( self . m2 , Cat ( 0 , 0 , 0 , self . fs2 [ 0 : 7 ] , 1 , 0 ) ) , # 3 R/G/S bits, 7 fraction bits, hidden 1 (| 0x80 before the shift), 1 spare
NextState ( " FADD1 " )
) . Elif ( ( self . fmin | self . fmax | self . fmadd | self . fmsub | self . fnmadd | self . fnmsub | self . fmul | self . fdiv ) & ~ self . fready , # Triggers set & ready flag reset externally!
NextValue ( self . sign1 , self . fs1 [ 31 ] ) ,
NextValue ( self . sign2 , self . fs2 [ 31 ] ) ,
NextValue ( self . e1 , self . fs1 [ 23 : 31 ] - 127 ) , # Remove the exponent bias
NextValue ( self . e2 , self . fs2 [ 23 : 31 ] - 127 ) ,
NextValue ( self . m1 , Cat ( self . fs1 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000
NextValue ( self . m2 , Cat ( self . fs2 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000
#NextValue(self.m1, Cat(self.fs1[0:23], 1, 0)), # | 0x00800000
NextValue ( self . m1 , Cat ( self . fs1 [ 0 : 7 ] , 1 , 0 , 0 , 0 , 0 ) ) , # | 0x80 (hidden bit), no R/G/S bits, upper 4 bits zero
#NextValue(self.m2, Cat(self.fs2[0:23], 1, 0)), # | 0x00800000
NextValue ( self . m2 , Cat ( self . fs2 [ 0 : 7 ] , 1 , 0 , 0 , 0 , 0 ) ) , # | 0x80 (hidden bit), no R/G/S bits, upper 4 bits zero
# Dispatch to the state that implements the requested operation
If ( self . fdiv , # Division
NextState ( " FDIV1 " ) ,
) . Elif ( self . fmin , # Minimum
@ -94,7 +102,8 @@ class bfloat16Processor(Module):
) . Elif ( self . fsqrt & ~ self . fready , # Trigger set & ready flag reset externally!
# Square root takes a single operand: latch its sign, unbiased exponent and mantissa
NextValue ( self . sign1 , self . fs1 [ 31 ] ) ,
NextValue ( self . e1 , self . fs1 [ 23 : 31 ] - 127 ) ,
NextValue ( self . m1 , Cat ( self . fs1 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000
#NextValue(self.m1, Cat(self.fs1[0:23], 1, 0)), # | 0x00800000
NextValue ( self . m1 , Cat ( self . fs1 [ 0 : 7 ] , 1 , 0 , 0 , 0 , 0 ) ) , # | 0x80 (hidden bit above the 7 fraction bits)
NextState ( " FSQRT1 " ) ,
)
)
@ -205,7 +214,8 @@ class bfloat16Processor(Module):
FPU_fsm . act ( " FADD6 " ,
NextValue ( self . FPU_state , 6 ) ,
# 6. Normalization of result: Overflow
If ( self . m3 [ 24 ] , # & 0x01000000,
#If(self.m3[24], # & 0x01000000,
If ( self . m3 [ 7 + 1 ] , # & 0x100: carry into the bit above the hidden bit
NextValue ( self . m3 , self . m3 >> 1 ) , # Adjust mantissa & increment exponent
NextValue ( self . e3 , self . e3 + 1 )
) . Else (
@ -216,7 +226,8 @@ class bfloat16Processor(Module):
FPU_fsm . act ( " FADD7 " ,
# 7. Normalization: Result
NextValue ( self . FPU_state , 7 ) ,
If ( ~ self . m3 [ 23 ] & ( self . i < 23 ) , # & 0x00800000 (limit to max. loops)
#If(~self.m3[23] & (self.i < 23), # & 0x00800000 (limit to max. loops)
If ( ~ self . m3 [ 7 ] & ( self . i < 7 ) , # & 0x80: hidden bit still clear (limit to max. 7 loops)
NextValue ( self . m3 , self . m3 << 1 ) , # Subtraction normalization
NextValue ( self . e3 , self . e3 - 1 ) ,
NextValue ( self . i , self . i + 1 ) , # Count loops ...
@ -231,7 +242,8 @@ class bfloat16Processor(Module):
)
FPU_fsm . act ( " FADD8 " ,
NextValue ( self . FPU_state , 8 ) ,
If ( self . m3 [ 24 ] , # & 0x01000000, # Overflow?
#If(self.m3[24], # & 0x01000000, # Overflow?
If ( self . m3 [ 7 + 1 ] , # & 0x100: bit above the hidden bit set - overflow?
NextValue ( self . m3 , self . m3 >> 1 ) , # Adjust mantissa & increment exponent
NextValue ( self . e3 , self . e3 + 1 )
) ,
@ -241,7 +253,8 @@ class bfloat16Processor(Module):
FPU_fsm . act ( " self.fresult " , # Result construction & possible rounding
NextValue ( self . FPU_state , 9 ) ,
# 6. Build the actual resulting float
NextValue ( self . fresult , Cat ( self . m3 [ 0 : 23 ] , self . e3 + 127 , self . sign3 ) ) ,
#NextValue(self.fresult, Cat(self.m3[0:23], self.e3+127, self.sign3)),
# 7-bit fraction variant: 16 zero low bits, then the fraction (bits 16..22),
# the re-biased exponent and the sign on top.
NextValue ( self . fresult , Cat ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , self . m3 [ 0 : 7 ] , self . e3 + 127 , self . sign3 ) ) ,
NextValue ( self . fready , 1 ) , # Indicate ready to main decoder
NextState ( " FPU_IDLE " )
)
@ -326,8 +339,10 @@ class bfloat16Processor(Module):
NextValue ( self . sign2 , self . fs3 [ 31 ] ^ ( self . fmsub | self . fnmsub ) ) , # Invert sign for subtraction!
NextValue ( self . e1 , self . e3 ) ,
NextValue ( self . e2 , self . fs3 [ 23 : 31 ] - 127 ) ,
NextValue ( self . m1 , Cat ( 0 , 0 , 0 , self . m3 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
NextValue ( self . m2 , Cat ( 0 , 0 , 0 , self . fs3 [ 0 : 23 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
#NextValue(self.m1, Cat(0,0,0, self.m3[0:23], 1, 0)), # | 0x00800000 + R/G/S bits
NextValue ( self . m1 , Cat ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , self . m3 [ 0 : 7 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
#NextValue(self.m2, Cat(0,0,0, self.fs3[0:23], 1, 0)), # | 0x00800000 + R/G/S bits
NextValue ( self . m2 , Cat ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , self . fs3 [ 0 : 7 ] , 1 , 0 ) ) , # | 0x00800000 + R/G/S bits
NextState ( " FADD1 " )
)
@ -363,7 +378,8 @@ class bfloat16Processor(Module):
)
)
FPU_fsm . act ( " FDIV2 " ,
If ( self . i < 24 ,
#If(self.i < 24,
If ( self . i < 8 , # Loop bound reduced from 24 to 8 iterations
NextValue ( self . FPU_state , 2 ) ,
If ( self . m1 < self . m2 ,
NextValue ( self . m3 , self . m3 << 1 ) , # Append a zero
@ -376,7 +392,8 @@ class bfloat16Processor(Module):
) . Else ( # Loop exceeded
# 4. Normalization
NextValue ( self . FPU_state , 3 ) ,
If ( ~ self . m3 [ 23 ] , # & 0x00800000
#If(~self.m3[23], # & 0x00800000
If ( ~ self . m3 [ 7 ] , # & 0x80: hidden bit not yet set
NextValue ( self . m3 , self . m3 << 1 ) , # Subtraction normalization
NextValue ( self . e3 , self . e3 - 1 ) ,
) . Else (
@ -398,7 +415,8 @@ class bfloat16Processor(Module):
NextState ( " FPU_IDLE " )
) . Else ( # Better fast, than accurate! Use Newton-Raphson in S/W for better accuracy!
# Goldschmidt's algorithm (only 1 digit after decimal point ok, error varies, see below)
If ( ( self . m1 [ 0 : 23 ] != 0 ) | ( self . e1 == 1 ) , # Not 2^x (m==0!) and x!=1
#If((self.m1[0:23] != 0) | (self.e1 == 1), # Not 2^x (m==0!) and x!=1
If ( ( self . m1 [ 0 : 7 ] != 0 ) | ( self . e1 == 1 ) , # Fraction bits non-zero (not 2^x), or exponent == 1
#return sqrt_approx(f, 0x0004B0D2); // Minimized error (max. 3.5%)
# NOTE(review): 0x0004B0D2 is quoted from the 23-bit-fraction version; confirm it still fits the 7-bit fraction.
NextValue ( self . branch1 , 1 ) , # Use 0x0004B0D2 for minimized error (<= 3.5%)
) . Else (