示例#1
0
def main():
    """Train the network selected by ``args.model`` on the CIFAR-10 loaders.

    ``args.model`` must be one of ``"resnet"``, ``"senet"`` or ``"gcn"``.
    A ``Trainer`` is built around the chosen model, and for every item
    yielded by the TQDM reporter (constructed over ``range(args.epochs)``)
    one ``trainer.train`` pass is followed by one ``trainer.test`` pass.

    Raises:
        TypeError: If ``args.model`` is not one of the recognised names.
    """
    train_loader, test_loader = cifar10_loaders(args.batch_size)

    # Each entry is wrapped in a lambda so that nothing is looked up or
    # instantiated for the architectures that were not selected.
    builders = {
        "resnet": lambda: resnet20(),
        "senet": lambda: se_resnet20(num_classes=10, reduction=args.reduction),
        "gcn": lambda: resnet20_gcn(),
    }
    build_model = builders.get(args.model)
    if build_model is None:
        raise TypeError(f"{args.model} is not valid argument")
    model = build_model()

    # No parameters are passed here; the optimizer and scheduler are handed
    # to the Trainer below, which is given the model.
    sgd = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4)
    step_lr = lr_scheduler.StepLR(80, 0.1)

    progress = reporter.TQDMReporter(
        range(args.epochs),
        callbacks=[callbacks.AccuracyCallback()],
        save_dir='logs/',
        report_freq=-1,
    )
    # tb_rep = reporter.TensorboardReporter(callbacks=[callbacks.AccuracyCallback(), callbacks.LossCallback()], save_dir='logs/')

    trainer = Trainer(
        model,
        sgd,
        F.cross_entropy,
        scheduler=step_lr,
        callbacks=progress,
    )
    # The reporter doubles as the loop iterable.
    for _ in progress:
        trainer.train(train_loader)
        trainer.test(test_loader)
示例#2
0
def main():
    """Train a ResNet20 or an SE-ResNet20 on the CIFAR-10 loaders.

    ``args.baseline`` selects the plain ``resnet20``; otherwise
    ``se_resnet20`` is built with ``args.reduction``.  The ``Trainer`` is
    used as a context manager, and for every item yielded by the TQDM
    reporter (constructed over ``range(args.epochs)``) one ``trainer.train``
    pass is followed by one ``trainer.test`` pass.
    """
    train_loader, test_loader = cifar10_loaders(args.batch_size)
    # Shows which sampler the test loader uses.
    print(test_loader.sampler)

    model = (
        resnet20()
        if args.baseline
        else se_resnet20(num_classes=10, reduction=args.reduction)
    )

    sgd = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4)
    step_lr = lr_scheduler.StepLR(80, 0.1)

    # The reporter is both a Trainer callback and the loop iterable.
    progress = reporters.TQDMReporter(range(args.epochs), callbacks.AccuracyCallback())
    trainer_callbacks = [progress, callbacks.AccuracyCallback()]

    with Trainer(
        model,
        sgd,
        F.cross_entropy,
        scheduler=step_lr,
        callbacks=trainer_callbacks,
    ) as trainer:
        for _ in progress:
            trainer.train(train_loader)
            trainer.test(test_loader)
示例#3
0
def main():
    """Fine-tune an SE-ResNet20 on CIFAR-10 starting from baseline weights.

    The function:

    1. Builds CIFAR-10 train/test loaders (batch size 4, per-channel
       normalisation with mean 0.5 and std 0.5).
    2. Loads ``checkpoint/cifar_baseline.pth`` into a plain ``resnet20``,
       copies every parameter whose name also exists in the baseline into
       the SE model, and freezes those parameters.
    3. Calls ``test`` once, then trains the remaining (unfrozen) parameters
       with SGD for 20 epochs, calling ``test`` after every epoch.
    4. Saves the resulting state dict to ``./cifar_transfer.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=4,
                                             shuffle=False,
                                             num_workers=2)

    model_baseline = resnet20().to(device)
    model = se_resnet20(num_classes=10, reduction=args.reduction).to(device)

    # map_location keeps the load working on a CPU-only host even when the
    # checkpoint was written from a CUDA device.
    model_baseline.load_state_dict(
        torch.load('checkpoint/cifar_baseline.pth', map_location=device))

    # Build the baseline state dict once instead of twice per parameter.
    baseline_state = model_baseline.state_dict()
    # NOTE(review): only entries of named_parameters() are transferred;
    # buffers (e.g. normalisation running statistics, if the models have
    # any) are not copied -- confirm this is intended.
    for name, param in model.named_parameters():
        if name in baseline_state:
            param.data = baseline_state[name].detach().clone()
            # Shared weights stay fixed; only the new parameters are trained.
            param.requires_grad = False

    # Evaluation before any fine-tuning.
    test(model, testloader)

    criterion = nn.CrossEntropyLoss()
    # Hand only the still-trainable parameters to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

    num_epochs = 20
    log_every = 2000  # report the mean loss every `log_every` mini-batches

    # NOTE(review): `test` is defined elsewhere; if it switches the model to
    # eval mode without restoring it, the epochs below run in eval mode --
    # confirm against its implementation.
    for epoch in range(num_epochs):  # loop over the dataset multiple times
        print(f"Start training epoch {epoch}")
        running_loss = 0.0
        print(len(trainloader))
        for i, data in enumerate(trainloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        test(model, testloader)

    print('Finished Training')
    save_path = './cifar_transfer.pth'
    torch.save(model.state_dict(), save_path)