예제 #1
0
    # ready to go
    for epoch in range(args.epochs):
        model.train()

        if epoch in decreasing_lr:
            grad_scale = grad_scale / 8.0

        logger("training phase")
        for batch_idx, (data, target) in enumerate(train_loader):
            indx_target = target.clone()
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = wage_util.SSE(output, target)

            loss.backward()

            for name, param in list(model.named_parameters())[::-1]:
                param.grad.data = wage_quantizer.QG(param.grad.data,
                                                    args.wl_grad, grad_scale)

            optimizer.step()

            for name, param in list(model.named_parameters())[::-1]:
                param.data = wage_quantizer.C(param.data, args.wl_weight)

            if batch_idx % args.log_interval == 0 and batch_idx > 0:
                pred = output.data.max(1)[
                    1]  # get the index of the max log-probability
예제 #2
0
    train_loader, test_loader = dataset.get_cifar10(batch_size=args.batch_size,
                                                    num_workers=1)
elif args.dataset == 'cifar100':
    train_loader, test_loader = dataset.get_cifar100(
        batch_size=args.batch_size, num_workers=1)
elif args.dataset == 'imagenet':
    train_loader, test_loader = dataset.get_imagenet(
        batch_size=args.batch_size, num_workers=1)
else:
    raise ValueError("Unknown dataset type")

# Build the requested network and its training criterion from args.model.
# Each architecture's module is imported lazily, so only the selected model's
# code is loaded. VGG8 and DenseNet40 are paired with wage_util.SSE();
# ResNet18 is paired with torch.nn.CrossEntropyLoss().
# Validation is done by the explicit `else: raise` below rather than an
# `assert`. An assert is stripped under `python -O`, and it would also make
# the else branch unreachable. ValueError matches the unknown-dataset error above.
if args.model == 'VGG8':
    from models import VGG
    model = VGG.vgg8(args=args, logger=logger)
    criterion = wage_util.SSE()
elif args.model == 'DenseNet40':
    from models import DenseNet
    model = DenseNet.densenet40(args=args, logger=logger)
    criterion = wage_util.SSE()
elif args.model == 'ResNet18':
    from models import ResNet
    model = ResNet.resnet18(args=args, logger=logger)
    criterion = torch.nn.CrossEntropyLoss()
else:
    raise ValueError(
        "Unknown model type {!r}; expected one of 'VGG8', 'DenseNet40', "
        "'ResNet18'".format(args.model))

# Move the model to the GPU when CUDA was requested.
# NOTE(review): assumes the factories above return a torch.nn.Module, whose
# .cuda() moves parameters in place — confirm in models/.
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(),