Example #1
    def forward(self, x):
        """Bottleneck forward pass with block-floating-point (BFP) quantization.

        Each conv output is re-quantized online (exponent list computed on the
        fly, hence the ``-1`` sentinel).  Before the shortcut addition, both
        ``out`` and ``residual`` are re-quantized offline to the element-wise
        max of their exponent lists so the add happens in one shared BFP
        format and no realignment is needed.

        NOTE(review): the bn1/bn2/bn3 batch-norm calls are intentionally
        disabled in this quantized variant — confirm against the fp32 model.
        """
        def _max_exp_list(t):
            # Reduce per-element exponents to a single max-exponent list by
            # folding over the last two (spatial) dims and the batch dim.
            exp = Utils.find_exponent(t, self.exp_bit)
            exp = Utils.find_max_exponent(exp, quant_dim=len(t.shape) - 1)
            exp = Utils.find_max_exponent(exp, quant_dim=len(t.shape) - 2)
            exp = Utils.find_max_exponent(exp, quant_dim=0)
            return exp.int().cpu().data.tolist()

        residual = x

        out = self.conv1(x)
        out = BFPActivation.transform_activation_online(
            out, self.exp_bit, self.mantisa_bit, -1)
        out = self.relu(out)

        out = self.conv2(out)
        out = BFPActivation.transform_activation_online(
            out, self.exp_bit, self.mantisa_bit, -1)
        out = self.relu(out)

        out = self.conv3(out)
        out = BFPActivation.transform_activation_online(
            out, self.exp_bit, self.mantisa_bit, -1)

        if self.downsample is not None:
            # Project the input so its shape matches the block output.
            residual = self.downsample(x)

        # Quantize both operands of the shortcut add to the element-wise max
        # of their exponent lists so they share one BFP format.
        out_exp = _max_exp_list(out)
        in_exp = _max_exp_list(residual)
        max_exp = np.maximum.reduce([out_exp, in_exp]).tolist()
        residual = BFPActivation.transform_activation_offline(
            residual, self.exp_bit, self.mantisa_bit, max_exp)
        out = BFPActivation.transform_activation_offline(
            out, self.exp_bit, self.mantisa_bit, max_exp)
        out += residual
        out = self.relu(out)

        return out
Example #2
    def forward(self, x):
        """Inverted-residual forward pass with BFP-aligned shortcut addition.

        Without a residual connection this is just the conv stack.  With one,
        the block output and the input are re-quantized offline to the
        element-wise max of their exponent lists so the addition happens in a
        single shared block-floating-point format.
        """
        if not self.use_res_connect:
            return self.conv(x)

        def _max_exp_list(t):
            # Reduce per-element exponents to a single max-exponent list by
            # folding over the last two (spatial) dims and the batch dim.
            exp = Utils.find_exponent(t, self.exp_bit)
            exp = Utils.find_max_exponent(exp, quant_dim=len(t.shape) - 1)
            exp = Utils.find_max_exponent(exp, quant_dim=len(t.shape) - 2)
            exp = Utils.find_max_exponent(exp, quant_dim=0)
            return exp.int().cpu().data.tolist()

        out = self.conv(x)
        # Quantize both operands of the shortcut add to the element-wise max
        # of their exponent lists so they share one BFP format.
        max_exp = np.maximum.reduce(
            [_max_exp_list(out), _max_exp_list(x)]).tolist()
        x = BFPActivation.transform_activation_offline(
            x, self.exp_bit, self.mantisa_bit, max_exp)
        out = BFPActivation.transform_activation_offline(
            out, self.exp_bit, self.mantisa_bit, max_exp)
        return x + out