Exemplo n.º 1
0
        # Conv case (fragment — the enclosing `if isinstance(m0, nn.Conv2d)` is
        # above this view): keep input channels selected by the previous layer
        # (idx0), then output channels selected by this layer (idx1).
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        # Fully-connected layer: only input features are pruned; keep the
        # weight columns whose index survives start_mask.
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            # np.squeeze on a single surviving index yields a 0-d array;
            # force it back to 1-d so fancy indexing below still works.
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

# conv-bn merge
# Fold BatchNorm parameters into the preceding conv weights so the pruned
# model runs without separate BN layers at inference time.
fuse_model_vgg(newmodel)
prune_model_parameters = sum([param.nelement() for param in newmodel.parameters()])

print("save pruned merged model...")
# Persist the pruning config together with the weights so the architecture
# can be rebuilt when the checkpoint is loaded.
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, prune_model_save_path)
#torch.save(newmodel.state_dict(), prune_model_save_path)

real_prune_acc = test(newmodel)


# ##########################
# time and flops
# ##########################
# Benchmark on CPU so the forward-time comparison with the original net is fair.
newmodel.to(cpu_device)
pruned_forward_time, pruned_flops, pruned_params = calc_time_and_flops(random_input, newmodel)
print("origin net forward time {}   vs   prune net forward time {}".format(
    origin_forward_time, pruned_forward_time
))
print("origin net GFLOPS {}   vs prune net GFLOPS {}".format(
    origin_flops, pruned_flops
Exemplo n.º 2
0
        # Standard supervised step: forward, loss, backward, parameter update.
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Running statistics for the periodic log line below.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % args.print_frequence == args.print_frequence - 1 or args.print_frequence == trainloader.__len__() - 1:
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        # NOTE(review): scheduler is stepped once per batch, not per epoch —
        # verify this matches how lr_scheduler's milestones were configured.
        lr_scheduler.step()


if __name__ == '__main__':
    # Train/test each epoch, checkpoint, then run the grafting step at the
    # configured granularity (per-filter via BN, or per-layer).
    for epoch in range(start_epoch, args.epochs):
        train(epoch)
        test(epoch)
        state = {
            'net': net.state_dict(),
            'acc': accuracy
        }
        # args.i % args.num keeps checkpoint names within a fixed pool of
        # model slots (one per parallel grafting worker, presumably — verify).
        torch.save(state, '%s/ckpt%d_%d.t7' % (args.s, args.i % args.num, epoch))
        if args.level == 'filter':
            BN(epoch)
        elif args.level == 'layer':
            grafting(epoch)
    # Persist the final accuracy record for this worker.
    torch.save(accuracy, '%s/accuracy%d.t7' % (args.s, args.i))
Exemplo n.º 3
0
def main():
    """Train a VGG19 ANN on CIFAR-10, then transfer and evaluate it as an SNN.

    Parses command-line options, builds (heavily augmented) CIFAR-10
    loaders, trains the ANN for ``--epochs`` epochs, converts its weights
    into a spiking ``CatVGG`` model, and evaluates the SNN on a
    spike-encoded copy of the test set.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Cifar10 LeNet Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    # Fix: help text said "default: 1" while the actual default is 1e-5.
    parser.add_argument('--lr', type=float, default=1e-5, metavar='LR',
                        help='learning rate (default: 1e-5)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='Resume model from checkpoint')
    parser.add_argument('--T', type=int, default=60, metavar='N',
                        help='SNN time window')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE(review): these loader kwargs are assembled but the DataLoaders
    # below use hard-coded batch sizes / shuffle flags instead — confirm
    # which is intended.
    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True},
                     )
    mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
    std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
        ])
    im_aug = transforms.Compose([
        #transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding = 6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
        #transforms.Normalize(mean, std)
        ])

    trainset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)

    # NOTE(review): concatenates 100 extra copies of the train set (~5.1M
    # samples per epoch). The augmentations are re-randomized on every read,
    # so the copies are not identical — but confirm this blow-up is intended.
    for i in range(100):
        trainset = trainset + datasets.CIFAR10(root='./data', train=True, download=True, transform=im_aug)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True)

    testset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False)

    # Spike-encoded version of the test set for evaluating the SNN.
    snn_dataset = SpikeDataset(testset, T = args.T)
    snn_loader = torch.utils.data.DataLoader(snn_dataset, batch_size=10, shuffle=False)

    from models.vgg import VGG, CatVGG

    model = VGG('VGG19', clamp_max=1, quantize_bit=32).to(device)
    snn_model = CatVGG('VGG19', args.T).to(device)
    if args.resume is not None:
        # Load the checkpoint once and reuse it for both models (the
        # original re-read the same file from disk three times).
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint, strict=False)
        load_model(checkpoint, model)
        load_model(checkpoint, snn_model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        #test(model, device, train_loader)
        test(model, device, test_loader)

        #transfer_model(model, snn_model)
        #test(snn_model, device, snn_loader)
        if args.save_model:
            torch.save(model.state_dict(), "cifar_cnn_19.pt")

        scheduler.step()
    # Final ANN evaluation, then ANN -> SNN weight transfer and SNN evaluation.
    #test(model, device, train_loader)
    test(model, device, test_loader)
    transfer_model(model, snn_model)
    with torch.no_grad():
        normalize_weight(snn_model.features, quantize_bit=8)
    test(snn_model, device, snn_loader)
    if args.save_model:
        torch.save(model.state_dict(), "cifar_cnn_19.pt")
Exemplo n.º 4
0
class DPP(object):
    """Training/evaluation harness for VGG16 (plain or DPP-integrated).

    Builds the model and SGD optimizer from ``args``, optionally restores a
    saved checkpoint, and exposes ``train``/``test`` loops that log metrics
    to a graph object and checkpoint the best-performing weights.
    """

    def __init__(self, args):
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.lr = args.lr
        self.epochs = args.epochs
        self.save_dir = './' + args.save_dir  # later change
        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

        if args.model == 'vgg16':
            self.model = VGG('VGG16', 0)
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=self.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)
            self.model = torch.nn.DataParallel(self.model)
            self.model.cuda()
        elif args.model == 'dpp_vgg16':
            # NOTE(review): unlike the vgg16 branch, this one neither wraps
            # the model in DataParallel nor moves it to CUDA — confirm.
            self.model = integrated_kernel(args)
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=self.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)

        num_params = sum(p.numel() for p in self.model.parameters()
                         if p.requires_grad)
        print('The number of parametrs of models is', num_params)

        if args.save_load:
            location = args.save_location
            print("locaton", location)
            checkpoint = torch.load(location)
            self.model.load_state_dict(checkpoint['state_dict'])

    def train(self, train_loader, test_loader, graph):
        """Run the full training loop, checkpointing whenever test accuracy improves."""
        self.model.train()
        best_prec = 0
        losses = AverageMeter()
        top1 = AverageMeter()
        for epoch in range(self.epochs):
            #self.adjust_learning_rate(epoch)
            for k, (inputs, target) in enumerate(train_loader):
                # Fix: `async=True` is a SyntaxError on Python >= 3.7 (async
                # became a keyword); PyTorch renamed it to `non_blocking`.
                target = target.cuda(non_blocking=True)
                input_var = inputs.cuda()
                target_var = target
                output = self.model(input_var)
                loss = self.criterion(output, target_var)
                # Compute gradient and do SGD step.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Measure accuracy and record loss.
                prec1 = self.accuracy(output.data, target)[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))

            graph.train_loss(losses.avg, epoch, 'train_loss')
            graph.train_acc(top1.avg, epoch, 'train_acc')
            prec = self.test(test_loader, epoch, graph)
            if prec > best_prec:
                print("Acc", prec)
                best_prec = prec
                self.save_checkpoint(
                    {
                        'best_prec1': best_prec,
                        'state_dict': self.model.state_dict(),
                    },
                    filename=os.path.join(self.save_dir,
                                          'checkpoint_{}.tar'.format(epoch)))

    def test(self, test_loader, epoch, test_graph):
        """Evaluate on ``test_loader``, log to ``test_graph``, return top-1 accuracy."""
        self.model.eval()
        losses = AverageMeter()
        top1 = AverageMeter()
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for k, (inputs, target) in enumerate(test_loader):
                target = target.cuda()
                inputs = inputs.cuda()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                # Measure accuracy and record loss.
                prec1 = self.accuracy(output.data, target)[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
        test_graph.test_loss(losses.avg, epoch, 'test_loss')
        test_graph.test_acc(top1.avg, epoch, 'test_acc')
        return top1.avg

    def accuracy(self, output, target, topk=(1, )):
        """Return a list with the top-k precision (percentage) for each k in ``topk``."""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape(-1) instead of view(-1): a slice of a tensor derived
            # from a transpose may be non-contiguous, which view() rejects.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def adjust_learning_rate(self, epoch):
        """Step-decay the learning rate by 10x every 90 epochs."""
        self.lr = self.lr * (0.1**(epoch // 90))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Serialize the checkpoint ``state`` dict to ``filename``."""
        torch.save(state, filename)
Exemplo n.º 5
0
        # Refresh the loss and accuracy curves for this epoch.
        plot_line(plt_x,
                  loss_rec["train"],
                  plt_x,
                  loss_rec["valid"],
                  mode="loss",
                  out_dir=log_dir)
        plot_line(plt_x,
                  acc_rec["train"],
                  plt_x,
                  acc_rec["valid"],
                  mode="acc",
                  out_dir=log_dir)

        # Only start tracking the best checkpoint after the first half of
        # training, once accuracy has stabilized.
        if epoch > (args.max_epoch / 2) and best_acc < acc_valid:
            best_acc = acc_valid
            best_epoch = epoch

            checkpoint = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "best_acc": best_acc
            }

            path_checkpoint = os.path.join(log_dir, "checkpoint_best.pkl")
            torch.save(checkpoint, path_checkpoint)

    print(" done ~~~~ {}, best acc: {} in :{}".format(
        datetime.strftime(datetime.now(), '%m-%d_%H-%M'), best_acc,
        best_epoch))
Exemplo n.º 6
0
            _, pred = output.max(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
#            correct += pred.eq(target).sum().item()

    # Average the summed per-batch loss over the whole test set.
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    # Return accuracy as a fraction in [0, 1].
    return correct / float(len(test_loader.dataset))


best_prec1 = 0.
for epoch in range(epochs):
    # Decay the learning rate by 10x at 40% and 70% of training.
    # Fix: the original compared the int epoch against float milestones
    # (`epoch in [epochs * 0.4, epochs * 0.7]`), so the decay silently
    # never fired unless epochs * 0.4 happened to be a whole number.
    if epoch in (int(epochs * 0.4), int(epochs * 0.7)):
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    train(epoch)

    prec1 = test()
    # Keep only the best-performing checkpoint (by test precision).
    if prec1 > best_prec1:
        best_prec1 = prec1
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, model_save_path)
        print("Best accuracy: " + str(best_prec1))