Example #1
 def forward(self, input):
     # Binarize both the activations and the weights, then run a
     # standard convolution on the binarized tensors.
     binput = binaryfunction.BinaryFunc().apply(input)
     bweight = binaryfunction.BinaryFunc().apply(self.weight)
     output = F.conv2d(binput, bweight, self.bias,
                       self.stride, self.padding,
                       self.dilation, self.groups)
     return output
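
Every example relies on binaryfunction.BinaryFunc, which is not shown in
these excerpts. In binary networks this is typically a sign function whose
backward pass uses a straight-through estimator (STE). Below is a minimal
sketch, assuming the clipped-STE variant and a "safe" sign that maps 0 to
+1 (as the safesign comments in Examples #4 and #6 below suggest); the
actual implementation may differ.

    import torch

    class BinaryFunc(torch.autograd.Function):
        """Sign binarization with a straight-through estimator (sketch)."""

        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            # "Safe" sign: map 0 to +1 so every output is exactly -1 or +1.
            return torch.where(input >= 0,
                               torch.ones_like(input),
                               -torch.ones_like(input))

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            # STE: pass the gradient straight through, but zero it where
            # |input| > 1 so saturated values stop receiving updates.
            grad_input = grad_output.clone()
            grad_input[input.abs() > 1] = 0
            return grad_input

With this definition, BinaryFunc().apply(x) returns a ±1 tensor of the same
shape as x while still letting gradients flow wherever |x| <= 1.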
Example #2
 def forward(self, input):
     # w_b = a * sign(w)
     # F.linear weight w has shape (DimOut, DimIn); a has shape (1, DimIn)
     bw = binaryfunction.BinaryFunc().apply(self.weight)
     scale_b = self.weight.abs().mean(0).view(1,
                                              self.weight.size(1)).detach()
     scale_bw = bw * scale_b
     # input_b = sign(input)
     # dim(input) = N*DimIn
     binput = binaryfunction.BinaryFunc().apply(input)
     # dim(a_input) = Nx1
     si = torch.mean(torch.abs(input), dim=1, keepdim=True).detach()
     scale_binput = binput * si
     output = F.linear(input=scale_binput, weight=scale_bw, bias=self.bias)
     return output
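
Note that F.linear expects a weight of shape (DimOut, DimIn), so the
mean(0) above averages over the output rows and yields one scale per input
feature, shape (1, DimIn), which then broadcasts against bw. A quick shape
check with hypothetical sizes:

    import torch

    w = torch.randn(8, 16)                        # (DimOut, DimIn)
    scale_b = w.abs().mean(0).view(1, w.size(1))  # one scale per input feature
    print(scale_b.shape)                          # torch.Size([1, 16])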
Example #3
 def forward(self, input):
     # w_b = a * sign(w); the input activations stay full-precision here
     # F.linear weight w has shape (DimOut, DimIn); a has shape (1, DimIn)
     bw = binaryfunction.BinaryFunc().apply(self.weight)
     scale_b = self.weight.abs().mean(0).view(1,
                                              self.weight.size(1)).detach()
     scale_bw = bw * scale_b
     output = F.linear(input=input, weight=scale_bw, bias=self.bias)
     return output
Example #4
 def forward(self, input):
     # multiple binary activations
     # cinput = clip(input + v, 0, 1)
     # binput = safesign(cinput - 0.5)  (equivalent to rounding cinput to
     # {0, 1}, then mapping via *2 - 1 to {-1, 1})
     if self.shiftalphas.device != input.device:
         self.shiftalphas = self.shiftalphas.to(input.device)
         self.betas = self.betas.to(input.device)
     bweight = binaryfunction.BinaryFunc().apply(self.weight)
     outputs = []
     for index in range(self.binarynum):
         sinput = input + self.shiftalphas[index]
         cinput = torch.clamp(sinput,min=0,max=1)
         binput = binaryfunction.BinaryFunc().apply(cinput-0.5)
         boutput = F.conv2d(binput, bweight, self.bias,
                           self.stride, self.padding,
                           self.dilation, self.groups)
         boutput = boutput * self.betas[index]
         outputs.append(boutput)
     # stacked outputs: (binarynum, N, C, H', W'); sum over the base dim
     output = torch.sum(torch.stack(outputs, 0), dim=0, keepdim=False)
     return output
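
The snippet above assumes self.binarynum, self.shiftalphas, and self.betas
already exist; their construction is not shown. A hypothetical constructor
consistent with this forward pass (the class name, base count, and initial
values are all assumptions): the shifts and scales are kept as plain
tensors rather than nn.Parameter, which is why forward() moves them to the
input's device by hand.

    import torch
    import torch.nn as nn

    class MultiBinaryConv2d(nn.Conv2d):
        # Hypothetical constructor matching the forward() above.
        def __init__(self, in_ch, out_ch, kernel_size, binarynum=3, **kwargs):
            super().__init__(in_ch, out_ch, kernel_size, **kwargs)
            self.binarynum = binarynum
            # One shift v and one output scale beta per binarization base.
            self.shiftalphas = torch.linspace(-0.5, 0.5, binarynum)
            self.betas = torch.ones(binarynum) / binarynum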
Example #5

    def forward(self, input):

        # w_b = a * sign(w)
        bw = binaryfunction.BinaryFunc().apply(self.weight)
        scale_b = torch.mean(torch.abs(self.weight),
                             dim=[1, 2, 3],
                             keepdim=True).detach()
        #scale_b = self.weight.abs().view(self.weight.size(0), -1).mean(-1).view(self.weight.size(0),1,1,1).detach()
        scale_bw = bw * scale_b
        # input_b = sign(input)
        binput = binaryfunction.BinaryFunc().apply(input)
        boutput = F.conv2d(binput,
                           scale_bw,
                           bias=self.bias,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           groups=self.groups)
        # Compute the output scale feature map: a per-position scaling
        # factor for each activation value of the convolution output.
        os = self.getScaleFeatureMap(input).detach()
        output = boutput * os
        return output
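
getScaleFeatureMap is not included in the excerpt. In XNOR-Net, the
activation scale map K is obtained by averaging |input| across channels
and convolving the result with a uniform kernel that matches the main
convolution's geometry. A plausible sketch along those lines (an
assumption, not the author's code):

    import torch
    import torch.nn.functional as F

    def getScaleFeatureMap(self, input):
        # Channel-wise mean of absolute activations: N x 1 x H x W.
        A = input.abs().mean(dim=1, keepdim=True)
        # Uniform averaging kernel with the same spatial size as the
        # convolution weights, so each output position receives the mean
        # absolute activation of its receptive field.
        kh, kw = self.weight.shape[2], self.weight.shape[3]
        k = torch.full((1, 1, kh, kw), 1.0 / (kh * kw),
                       device=input.device, dtype=input.dtype)
        # Match the main convolution's geometry so K aligns with boutput.
        return F.conv2d(A, k, stride=self.stride, padding=self.padding,
                        dilation=self.dilation)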
Example #6
 def forward(self, input):
     # multiple binary activations
     # cinput = clip(input + v, 0, 1)
     # binput = safesign(cinput - 0.5)  (equivalent to rounding cinput to
     # {0, 1}, then mapping via *2 - 1 to {-1, 1})
     if self.shiftalphas.device != input.device:
         self.shiftalphas = self.shiftalphas.to(input.device)
         self.betas = self.betas.to(input.device)
     bweight = binaryfunction.BinaryFunc().apply(self.weight)
     outputs = []
     for index in range(self.binarynum):
         # input: (N, C, H, W); shiftalphas[index]: (C,) -> (1, C, 1, 1),
         # then broadcast; shiftalphas overall: (binary_num, C)
         sinput = input + self.shiftalphas[index].unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(input)
         cinput = torch.clamp(sinput,min=0,max=1)
         binput = binaryfunction.BinaryFunc().apply(cinput-0.5)
         boutput = F.conv2d(binput, bweight, self.bias,
                           self.stride, self.padding,
                           self.dilation, self.groups)
         # (N, C, H', W') * beta[index]; Beta: binary_num
         # Note: this should be changed to shape (binary_num,) to reduce
         # computation, i.e. one factor per binarization base is enough
         boutput = boutput * self.betas[index]
         outputs.append(boutput)
     # stacked outputs: (binarynum, N, C, H', W'); sum over the base dim
     output = torch.sum(torch.stack(outputs, 0), dim=0, keepdim=False)
     return output
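
Unlike Example #4, each base's shift here is per input channel, so
shiftalphas must have shape (binary_num, C) rather than (binary_num,). A
quick check of the broadcasting used above, with hypothetical sizes:

    import torch

    shiftalphas = torch.zeros(3, 16)      # (binary_num, C)
    x = torch.randn(2, 16, 8, 8)          # (N, C, H, W)
    # (C,) -> (1, C, 1, 1), then expand over batch and spatial dims.
    s = shiftalphas[0].unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x)
    print((x + s).shape)                  # torch.Size([2, 16, 8, 8])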
Example #7
    def forward(self, input):

        # w_b = a * sign(w); activations stay full-precision in this variant
        bw = binaryfunction.BinaryFunc().apply(self.weight)
        scale_b = torch.mean(torch.abs(self.weight),
                             dim=[1, 2, 3],
                             keepdim=True).detach()
        #scale_b = self.weight.abs().view(self.weight.size(0), -1).mean(-1).view(self.weight.size(0),1,1,1).detach()
        scale_bw = bw * scale_b
        boutput = F.conv2d(input,
                           scale_bw,
                           bias=self.bias,
                           stride=self.stride,
                           padding=self.padding,
                           dilation=self.dilation,
                           groups=self.groups)

        return boutput
Example #8
 def forward(self, input):
     # Only the weights are binarized; the input passes through
     # unchanged despite the binput name.
     binput = input
     bweight = binaryfunction.BinaryFunc().apply(self.weight)
     output = F.linear(input=binput, weight=bweight, bias=self.bias)
     return output
Example #9
def validate(data_loader, model, criterion, epoch, monitors, args, logger):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    batch_time = AverageMeter()

    total_sample = len(data_loader.sampler)
    batch_size = data_loader.batch_size
    steps_per_epoch = math.ceil(total_sample / batch_size)

    logger.info('Validation: %d samples (%d per mini-batch)', total_sample,
                batch_size)

    model.eval()
    end_time = time.time()
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        with torch.no_grad():
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            outputs = model(inputs)
            closs = criterion(outputs, targets)

            # weight regularization loss: (1 - |w|)^2 pulls each weight toward ±1
            regularization_loss = 0
            for name, param in model.named_parameters():
                regularization_loss += torch.sum(
                    torch.pow(1.0 - torch.abs(param), 2))

            # SQNR (signal-to-quantization-noise ratio) loss of the conv weights
            sqnr_loss = 0
            for name, param in model.named_parameters():
                if "conv" in name.lower() and "weight" in name.lower():
                    bparam = binaryfunction.BinaryFunc().apply(param)
                    sqnr_loss += 10 * torch.log10(
                        torch.sum(torch.pow(param, 2)) /
                        (torch.sum(torch.pow(param - bparam, 2)) + 1e-5))
            # Negate so that minimizing the total loss maximizes SQNR
            sqnr_loss = sqnr_loss * (-1)

            loss = (closs + args.weight_regloss * regularization_loss +
                    args.weight_sqnrloss * sqnr_loss)
            # loss = closs

            acc1, acc5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            batch_time.update(time.time() - end_time)
            end_time = time.time()

            if (batch_idx + 1) % args.print_freq == 0:
                for m in monitors:
                    m.update(
                        epoch, batch_idx + 1, steps_per_epoch, 'Validation', {
                            'Loss': losses,
                            "Closs": closs,
                            "Regloss": regularzation_loss,
                            "Sqnrloss": sqnr_loss,
                            'Top1': top1,
                            'Top5': top5,
                            'BatchTime': batch_time
                        })

    logger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n', top1.avg,
                top5.avg, losses.avg)
    return top1.avg, top5.avg, losses.avg
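
Two auxiliary losses appear in the loop above: the regularizer
(1 - |w|)^2, which is zero exactly at w = ±1 and so pulls full-precision
weights toward the binary values, and the SQNR term
10 * log10(sum(w^2) / sum((w - sign(w))^2)), negated so that minimizing
the loss maximizes the signal-to-quantization-noise ratio. A toy check of
the regularizer:

    import torch

    w = torch.tensor([-1.0, -0.3, 0.0, 0.7, 1.0])
    # Only the entries away from ±1 contribute to the penalty.
    print(torch.sum((1.0 - w.abs()) ** 2))  # tensor(1.5800)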