Example #1

# Standard-library and PyTorch imports assumed by this excerpt;
# AddGaussianNoise, SpikeDataset, train, test, fuse_bn_recursively,
# transfer_model and normalize_weight come from the example's own repository.
import argparse

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms


def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch Cifar10 LeNet Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        metavar='RESUME',
                        help='Resume model from checkpoint')
    parser.add_argument('--T',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='SNN time window')
    parser.add_argument('--k',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of augmented copies of the training set')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({
            'num_workers': 1,
            'pin_memory': True,
            'shuffle': True
        })

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01)
    ])

    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = datasets.CIFAR10(root='./data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader_ = torch.utils.data.DataLoader(trainset,
                                                batch_size=512,
                                                shuffle=True)

    for i in range(args.k):
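        # Each iteration appends one more randomly augmented copy of CIFAR-10;
        # Dataset.__add__ on torch Datasets returns a ConcatDataset.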

        im_aug = transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomCrop(32, padding=6),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            AddGaussianNoise(std=0.01)
        ])
        trainset = trainset + datasets.CIFAR10(
            root='./data', train=True, download=True, transform=im_aug)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=256 + 512,
                                               shuffle=True)

    testset = datasets.CIFAR10(root='./data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=100,
                                              shuffle=False)

    snn_dataset = SpikeDataset(testset, T=args.T)
    snn_loader = torch.utils.data.DataLoader(snn_dataset,
                                             batch_size=10,
                                             shuffle=False)
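    # SpikeDataset (from this example's repository) wraps the test set so that
    # each image is expanded into a spike train of length T, letting the
    # converted SNN be evaluated over the full time window.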

    from models.vgg import VGG, VGG_, CatVGG, CatVGG_

    model = VGG('VGG16', clamp_max=1, quantize_bit=32, bias=False).to(device)
    snn_model = CatVGG('VGG16', args.T, bias=True).to(device)

    # Trainable pooling:
    # model = VGG_('VGG19_', clamp_max=1, quantize_bit=32, bias=True).to(device)
    # snn_model = CatVGG_('VGG19_', args.T, bias=True).to(device)

    if args.resume is not None:
        model.load_state_dict(torch.load(args.resume), strict=False)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    correct_ = 0
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, train_loader_)
        correct = test(model, device, test_loader)
        if correct > correct_:
            correct_ = correct
        scheduler.step()

    model = fuse_bn_recursively(model)
    transfer_model(model, snn_model)
    with torch.no_grad():
        normalize_weight(snn_model.features, quantize_bit=32)
    test(snn_model, device, snn_loader)
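
fuse_bn_recursively, transfer_model and normalize_weight are utilities from
this example's repository. As a rough sketch of what conv-BN fusion computes
(assuming consecutive nn.Conv2d/nn.BatchNorm2d pairs, not the repository's
actual implementation), the BN statistics are folded into the preceding
convolution:

import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold y = gamma * (conv(x) - mu) / sqrt(var + eps) + beta into one conv.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std
    fused.weight.data.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None \
        else conv.weight.new_zeros(conv.out_channels)
    fused.bias.data.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused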
Example #2
    parser.add_argument('--min_max_ratio', type=float, default=0.01)
    parser.add_argument(
        '--pretrain_dir',
        type=str,
        default='./data/experiments/pretrained/vgg6_cifar100_classif_80.pth')
    parser.add_argument('--dataset_root',
                        type=str,
                        default='./data/datasets/CIFAR/')
    parser.add_argument('--seed', default=1, type=int)

    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    seed_torch(args.seed)
    model = VGG(n_layer='5+1', out_dim=80).to(device)
    model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
    model.last = Identity()

    val_loader = CIFAR100Loader(root=args.dataset_root,
                                batch_size=args.batch_size,
                                split='train',
                                labeled=True,
                                aug=None,
                                shuffle=True,
                                mode='probe')
    eval_loader = CIFAR100Loader(root=args.dataset_root,
                                 batch_size=args.batch_size,
                                 split='train',
                                 labeled=False,
                                 aug=None,
                                 shuffle=False)
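
Identity here is presumably the usual pass-through module used to strip the
trained classifier head so the network returns features; a minimal sketch:

import torch.nn as nn


class Identity(nn.Module):
    # Pass-through replacement for a classification head.
    def forward(self, x):
        return x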
Example #3
    model = ResNet(depth=args.depth)

if torch.cuda.is_available():
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=momentum, weight_decay=weight_decay)

# load pretrained weights from a checkpoint
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))


# Additional sparsity-inducing penalty that drives BN scale weights toward 0
def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s * torch.sign(m.weight.data))  # L1 subgradient


def train(epoch):
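    # Sketch of a typical network-slimming training step (train_loader and
    # F = torch.nn.functional are assumed from the surrounding script):
    # updateBN() runs between backward() and optimizer.step() so the L1
    # subgradient is added to the BN scale gradients.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        updateBN()
        optimizer.step()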
Example #4
elif opt.model == 'resnet':
    model_ft = resnet50(pretrained=False, num_classes=10)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 10)
    net = model_ft.cuda()
    #net = nn.DataParallel(model_ft.cuda())
elif opt.model == 'wide_resnet':
    from models.wideresnet import *
    net = WideResNet().cuda()  # wrapped in nn.DataParallel below
else:
    raise NotImplementedError('Invalid model')

if opt.resume and opt.resume_from:
    print(f'==> Resume from {opt.resume_from}')
    net.load_state_dict(torch.load(opt.resume_from))

net = nn.DataParallel(net)

#cudnn.benchmark = True

# Loss function
criterion = nn.CrossEntropyLoss()


# label smoothing
def LabelSmoothingLoss(outputs, targets):
    eps = 0.1
    batch_size, n_class = outputs.size()
    one_hot = torch.zeros_like(outputs).scatter(1, targets.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    # cross-entropy against the smoothed target distribution
    # (F = torch.nn.functional, imported in the surrounding script)
    log_prob = F.log_softmax(outputs, dim=1)
    return -(one_hot * log_prob).sum(dim=1).mean()
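
With eps = 0.1 and 10 classes, the smoothed target puts weight 0.9 on the true
class and 0.1/9 ≈ 0.011 on each of the others. A usage sketch (net, inputs and
targets assumed from the surrounding script):

outputs = net(inputs)  # (batch_size, 10) logits
loss = LabelSmoothingLoss(outputs, targets)
loss.backward()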
Example #5
if args.arch == "vgg":
    model = VGG(depth=args.depth, slim_channel=args.slim_channel, cfg=cfg)
else:
    # ResNet doesn't support slim channel strategy
    model = ResNet(depth=args.depth)

fuse_model_vgg(model)
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=momentum,
                      weight_decay=weight_decay)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.load_state_dict(checkpoint["state_dict"])


def train(epoch):
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        pred = output.max(1, keepdim=True)[1]
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        total += target.size(0)
        correct += pred.eq(target.view_as(pred)).sum().item()
Example #6
            # net = MobileNet()
            # net = MobileNetV2()
            # net = DPN92()
            # net = ShuffleNetG2()
            # net = SENet18()
            net = net.to(device)
            if device == 'cuda':
                net = torch.nn.DataParallel(net)
                cudnn.benchmark = True

            if args.resume:
                # Load checkpoint.
                print('==> Resuming from checkpoint..')
                assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
                checkpoint = torch.load('./checkpoint/'+exp_name+'.t7')
                net.load_state_dict(checkpoint['net'])
                best_acc = checkpoint['acc']
                start_epoch = checkpoint['epoch']

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

            train_loss = []
            train_acc = []

            val_loss = []
            val_acc = []



            for epoch in range(start_epoch, start_epoch + 200):
                train(epoch)  # per-epoch train/eval routines defined elsewhere in the script
                test(epoch)
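
The resume block above reads the keys 'net', 'acc' and 'epoch'; a hypothetical
save step writing a matching checkpoint (acc assumed to be the validation
accuracy computed during evaluation) would be:

state = {'net': net.state_dict(), 'acc': acc, 'epoch': epoch}
if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
torch.save(state, './checkpoint/' + exp_name + '.t7')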