def forward(self, x):
    """Bottleneck forward pass with block floating-point (BFP) quantization.

    Each convolution output is quantized online; before the shortcut
    addition, both branches are re-quantized offline with a shared
    (element-wise maximum) exponent list so the add is performed on
    identically scaled BFP tensors.

    Args:
        x: Input activation tensor (assumed NCHW -- TODO confirm with caller).

    Returns:
        Activation tensor after the conv stack, aligned shortcut add,
        and final ReLU.
    """
    residual = x

    # conv1 -> online BFP quantization -> ReLU (batch norm deliberately
    # disabled in this quantized variant).
    out = self.conv1(x)
    # out = self.bn1(out)
    out = BFPActivation.transform_activation_online(
        out, self.exp_bit, self.mantisa_bit, -1)
    out = self.relu(out)

    out = self.conv2(out)
    # out = self.bn2(out)
    out = BFPActivation.transform_activation_online(
        out, self.exp_bit, self.mantisa_bit, -1)
    out = self.relu(out)

    out = self.conv3(out)
    # out = self.bn3(out)
    out = BFPActivation.transform_activation_online(
        out, self.exp_bit, self.mantisa_bit, -1)

    if self.downsample is not None:
        residual = self.downsample(x)

    # Align both branches on the element-wise maximum exponent list so the
    # shortcut addition mixes tensors with matching BFP scaling.
    out_exp = self._branch_max_exponent(out)
    in_exp = self._branch_max_exponent(residual)
    # Quantize according to the max of the two exponent lists.
    max_exp = np.maximum.reduce([out_exp, in_exp]).tolist()
    residual = BFPActivation.transform_activation_offline(
        residual, self.exp_bit, self.mantisa_bit, max_exp)
    out = BFPActivation.transform_activation_offline(
        out, self.exp_bit, self.mantisa_bit, max_exp)

    out += residual
    out = self.relu(out)
    return out

def _branch_max_exponent(self, tensor):
    """Return *tensor*'s max-reduced exponents as a Python int list.

    Extracts per-element exponents, then reduces them to a maximum over
    the last two dimensions and dimension 0, leaving one exponent per
    remaining group (presumably per channel -- verify against Utils).
    """
    exp = Utils.find_exponent(tensor, self.exp_bit)
    exp = Utils.find_max_exponent(exp, quant_dim=len(tensor.shape) - 1)
    exp = Utils.find_max_exponent(exp, quant_dim=len(tensor.shape) - 2)
    exp = Utils.find_max_exponent(exp, quant_dim=0)
    return exp.int().cpu().data.tolist()
def forward(self, x):
    """Inverted-residual forward pass with BFP-aligned skip connection.

    Without a residual connection this is a plain ``self.conv`` call.
    With one, the input and the conv output are re-quantized offline
    using a shared (element-wise maximum) exponent list before being
    summed, so the addition happens on identically scaled BFP tensors.

    Args:
        x: Input activation tensor (assumed NCHW -- TODO confirm with caller).

    Returns:
        ``self.conv(x)`` or the BFP-aligned residual sum ``x + conv(x)``.
    """
    # Guard clause: no residual path, nothing to align.
    if not self.use_res_connect:
        return self.conv(x)

    out = self.conv(x)

    # Reduce both tensors to their max exponent lists, then quantize
    # both according to the element-wise maximum of the two.
    out_exp = self._branch_max_exponent(out)
    in_exp = self._branch_max_exponent(x)
    max_exp = np.maximum.reduce([out_exp, in_exp]).tolist()
    x = BFPActivation.transform_activation_offline(
        x, self.exp_bit, self.mantisa_bit, max_exp)
    out = BFPActivation.transform_activation_offline(
        out, self.exp_bit, self.mantisa_bit, max_exp)
    return x + out

def _branch_max_exponent(self, tensor):
    """Return *tensor*'s max-reduced exponents as a Python int list.

    Extracts per-element exponents, then reduces them to a maximum over
    the last two dimensions and dimension 0, leaving one exponent per
    remaining group (presumably per channel -- verify against Utils).
    """
    exp = Utils.find_exponent(tensor, self.exp_bit)
    exp = Utils.find_max_exponent(exp, quant_dim=len(tensor.shape) - 1)
    exp = Utils.find_max_exponent(exp, quant_dim=len(tensor.shape) - 2)
    exp = Utils.find_max_exponent(exp, quant_dim=0)
    return exp.int().cpu().data.tolist()