Example #1
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Project-local helpers such as LeNet, VAT, evaluate_classifier, loadsave,
# toRGB, loss_plot and barchartplot are assumed to be importable from the
# surrounding project; sketches of several of them appear below.


# Purely supervised baseline: trains on the labelled loader only.
def train(model, optimizer, criterion, trainloader, valloader, testloader, epochs, device, root):
    best_acc = 0.0
    supervised_loss = []

    for epoch in range(epochs):  # loop over the dataset multiple times
        for i, data in enumerate(trainloader):

            l_x, l_y = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = model(l_x)
            sup_loss = criterion(outputs, l_y)
            loss = sup_loss

            loss.backward()
            optimizer.step()

        # Calculate validation accuracy once per epoch
        val_acc = evaluate_classifier(model, valloader, device)
        print('Epoch: {}, Val_acc: {:.3} Sup_loss: {:.3}'.format(epoch, val_acc, sup_loss.item()))

        supervised_loss.append(sup_loss.item())

        if val_acc > best_acc:
            loadsave(model, optimizer, "Lenet", root=root, mode='save')
            best_acc = val_acc

    return supervised_loss
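
# The evaluate_classifier helper is not shown on this page. A minimal sketch,
# assuming the single-loader variant simply returns top-1 accuracy over the
# loader (the two-loader variant used in a later example would return one
# accuracy per loader); the project's real implementation may differ:
def evaluate_classifier(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    model.train()
    return correct / total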


Example #2
# Builds a labelled/unlabelled split of SVHN and trains LeNet with VAT.
def main(args):
    transform_SVHN = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset_SVHN = torchvision.datasets.SVHN(root=args.dataset_path[0], split='train', download=True, transform=transform_SVHN)
    testset_SVHN = torchvision.datasets.SVHN(root=args.dataset_path[0], split='test', download=True, transform=transform_SVHN)

    train_labelled_size = int(0.6 * len(trainset_SVHN))
    train_unlabelled_size = len(trainset_SVHN) - train_labelled_size
    val_size = int(0.2 * len(testset_SVHN))
    test_size = len(testset_SVHN) - val_size
    trainset_labelled, trainset_unlabelled = torch.utils.data.random_split(trainset_SVHN, [train_labelled_size, train_unlabelled_size])
    valset, testset = torch.utils.data.random_split(testset_SVHN, [val_size, test_size])

    # Increasing the batch size would reduce training time. The labelled and unlabelled loaders can also use different batch sizes, e.g. 32 for the supervised pass and 128 for VAT.
    train_labelled_loader = DataLoader(trainset_labelled, batch_size=32, shuffle=True, num_workers=2)
    train_unlabelled_loader = DataLoader(trainset_unlabelled, batch_size=32, shuffle=True, num_workers=2)
    valloader = DataLoader(valset, batch_size=1, shuffle=True, num_workers=2)
    testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: " + str(device))

    lenet0 = LeNet(device)
    lenet0 = lenet0.to(device)
    print(lenet0)

    criterion = nn.CrossEntropyLoss()
    criterion_VAT = VAT(device, eps=args.eps, xi=args.xi, k=args.k, use_entmin=args.use_entmin)
    optimizer = optim.Adam(lenet0.parameters(), lr=args.lr) # Should implement lr scheduler.
    # optimizer = optim.SGD(lenet0.parameters(), lr=args.lr, momentum=0.9)
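    # An LR scheduler could be attached here; the scheduler variant of train
    # below creates, for example:
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.5)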

    if args.eval_only:
        loadsave(lenet0, optimizer, "VATcheck", root=args.weights_path[0], mode='load')

    else:
        supervised_loss, unsupervised_loss = train(lenet0, optimizer, criterion, criterion_VAT, train_labelled_loader, train_unlabelled_loader, valloader, testloader, args.alpha, args.epochs, device, args.weights_path[0])

        plt.subplot(2,1,1)
        plt.plot(supervised_loss)
        plt.title("Supervised loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.grid(True)

        plt.subplot(2,1,2)
        plt.plot(unsupervised_loss)
        plt.title("Unsupervised loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.grid(True)

        plt.tight_layout()
        plt.show()

        loadsave(lenet0, optimizer, "VATcheck", root=args.weights_path[0], mode='load')

    vat_acc = evaluate_classifier(lenet0, testloader, device)
    print("Accuracy of the network on SVHN is %d%%\n" % (vat_acc * 100))

    barchartplot(lenet0, testloader, device)
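
# The loadsave helper is also defined elsewhere in the project. A minimal
# sketch matching the calls above, assuming a checkpoint file "<name>.pt"
# under root; the real helper may store extra state such as the epoch:
import os

def loadsave(model, optimizer, name, root, mode='save'):
    path = os.path.join(root, name + '.pt')
    if mode == 'save':
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()}, path)
    else:
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])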


Example #3
# Variant of train: labelled SVHN batches plus unlabelled MNIST batches.
def train(model, optimizer, criterion, criterion_VAT, trainloader_SVHN,
          trainloader_MNIST, valloader, testloader_SVHN, alpha, epochs, device,
          root):
    best_acc = 0.0
    supervised_loss = []
    unsupervised_loss = []

    for epoch in range(epochs):  # loop over the dataset multiple times
        dataloader_iterator = iter(trainloader_MNIST)

        for i, data in enumerate(trainloader_SVHN):
            # Cycle the unlabelled loader when it is exhausted before the
            # labelled one.
            try:
                data2 = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(trainloader_MNIST)
                data2 = next(dataloader_iterator)

            l_x, l_y = data[0].to(device), data[1].to(device)
            ul_x = data2[0].to(device)  # unlabelled inputs; the labels are unused
            optimizer.zero_grad()

            outputs = model(l_x)
            sup_loss = criterion(outputs, l_y)
            # VAT smoothness penalty on the unlabelled batch, weighted by alpha
            unsup_loss = alpha * criterion_VAT(model, ul_x)
            loss = sup_loss + unsup_loss

            loss.backward()
            optimizer.step()

        # Validation accuracy on MNIST and accuracy on the original SVHN test set
        vat_acc, org_acc = evaluate_classifier(model, valloader,
                                               testloader_SVHN, device)
        print(
            'Epoch: {}, Val_acc: {:.3} Org_acc: {:.3} Sup_loss: {:.3} Unsup_loss: {:.3}'
            .format(epoch, vat_acc, org_acc, sup_loss.item(),
                    unsup_loss.item()))

        supervised_loss.append(sup_loss.item())
        unsupervised_loss.append(unsup_loss.item())

        # if (vat_acc > best_acc):
        #   loadsave(model, optimizer, "LenetVAT", root=root, mode='save')
        #   best_acc = vat_acc

    loadsave(model, optimizer, "LenetVAT", root=root, mode='save')
    return supervised_loss, unsupervised_loss
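
# The VAT criterion itself is defined elsewhere in the project. A compact
# sketch of virtual adversarial training (Miyato et al., 2018) that is
# call-compatible with criterion_VAT(model, ul_x) above; eps, xi, k and
# use_entmin mirror the constructor arguments used in main, but the project's
# real class may differ in detail:
import torch.nn.functional as F

class VAT(nn.Module):
    def __init__(self, device, eps=1.0, xi=1e-6, k=1, use_entmin=False):
        super().__init__()
        self.device = device
        self.eps = eps          # radius of the adversarial perturbation
        self.xi = xi            # step size for the power iteration
        self.k = k              # number of power-iteration steps
        self.use_entmin = use_entmin

    def forward(self, model, x):
        # Reference distribution p(y|x), treated as a constant.
        with torch.no_grad():
            p = F.softmax(model(x), dim=1)

        # Power iteration: find the direction d that locally maximises the
        # KL divergence between p(y|x) and p(y|x + xi*d).
        d = torch.randn_like(x)
        for _ in range(self.k):
            d = self.xi * F.normalize(d.flatten(1), dim=1).view_as(x)
            d.requires_grad_()
            log_p_hat = F.log_softmax(model(x + d), dim=1)
            adv_kl = F.kl_div(log_p_hat, p, reduction='batchmean')
            d = torch.autograd.grad(adv_kl, d)[0]

        # Local distributional smoothness at x + r_adv.
        r_adv = self.eps * F.normalize(d.flatten(1), dim=1).view_as(x)
        log_p_hat = F.log_softmax(model(x + r_adv), dim=1)
        loss = F.kl_div(log_p_hat, p, reduction='batchmean')

        if self.use_entmin:
            # Conditional entropy minimisation on the unlabelled batch.
            p_x = F.softmax(model(x), dim=1)
            loss = loss - (p_x * torch.log(p_x + 1e-8)).sum(dim=1).mean()
        return loss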


Example #4
# Variant of train: labelled/unlabelled SVHN loaders plus an LR scheduler.
def train(model, optimizer, criterion, criterion_VAT, train_labelled_loader, train_unlabelled_loader, valloader, testloader, alpha, epochs, device, root):
    best_acc = 0.0
    supervised_loss = []
    unsupervised_loss = []

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.5)

    for epoch in range(epochs):  # loop over the dataset multiple times
        dataloader_iterator = iter(train_unlabelled_loader)

        for i, data in enumerate(train_labelled_loader):
            try:
                data2 = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_unlabelled_loader)
                data2 = next(dataloader_iterator)

            l_x, l_y = data[0].to(device), data[1].to(device)
            ul_x = data2[0].to(device)
            optimizer.zero_grad()

            outputs = model(l_x)
            sup_loss = criterion(outputs, l_y)
            unsup_loss = alpha * criterion_VAT(model, ul_x)
            loss = sup_loss + unsup_loss

            loss.backward()
            optimizer.step()

        scheduler.step()

        # Calculating validation accuracy
        vat_acc = evaluate_classifier(model, valloader, device)
        print('Epoch: {}, Val_acc: {:.3} Sup_loss: {:.3} Unsup_loss: {:.3}'.format(epoch, vat_acc, sup_loss.item(), unsup_loss.item()))

        supervised_loss.append(sup_loss.item())
        unsupervised_loss.append(unsup_loss.item())

        if vat_acc > best_acc:
            loadsave(model, optimizer, "VATcheck", root=root, mode='save')
            best_acc = vat_acc

    return supervised_loss, unsupervised_loss
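
# The MNIST pipeline in the next example relies on a small toRGB transform
# that is not shown on this page. A minimal sketch, assuming it converts the
# single-channel PIL image to 3-channel RGB so MNIST matches SVHN's input
# shape:
class toRGB:
    def __call__(self, img):
        return img.convert('RGB')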


Example #5
# Variant of main: trains on labelled SVHN plus unlabelled MNIST, then
# evaluates on both test sets.
def main(args):
    transform_SVHN = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_MNIST = transforms.Compose([
        toRGB(),
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset_SVHN = torchvision.datasets.SVHN(root=args.dataset_path[0],
                                              split='train',
                                              download=True,
                                              transform=transform_SVHN)
    fullset_MNIST = torchvision.datasets.MNIST(root=args.dataset_path[0],
                                               train=True,
                                               download=True,
                                               transform=transform_MNIST)
    testset = torchvision.datasets.MNIST(root=args.dataset_path[0],
                                         train=False,
                                         download=True,
                                         transform=transform_MNIST)
    testset_SVHN = torchvision.datasets.SVHN(root=args.dataset_path[0],
                                             split='test',
                                             download=True,
                                             transform=transform_SVHN)

    train_size = int(0.8 * len(fullset_MNIST))
    val_size = len(fullset_MNIST) - train_size
    trainset_MNIST, valset = torch.utils.data.random_split(
        fullset_MNIST, [train_size, val_size])

    # Increasing the batch size would reduce training time. The labelled and unlabelled loaders can also use different batch sizes, e.g. 32 for the supervised pass and 128 for VAT.
    trainloader_SVHN = DataLoader(trainset_SVHN,
                                  batch_size=32,
                                  shuffle=True,
                                  num_workers=2)
    trainloader_MNIST = DataLoader(trainset_MNIST,
                                   batch_size=32,
                                   shuffle=True,
                                   num_workers=2)
    valloader = DataLoader(valset, batch_size=1, shuffle=True, num_workers=2)
    testloader = DataLoader(testset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=2)
    testloader_SVHN = DataLoader(testset_SVHN,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: " + str(device))

    lenet0 = LeNet(device)
    lenet0 = lenet0.to(device)
    print(lenet0)

    criterion = nn.CrossEntropyLoss()
    criterion_VAT = VAT(device,
                        eps=args.eps,
                        xi=args.xi,
                        k=args.k,
                        use_entmin=args.use_entmin)
    optimizer = optim.Adam(lenet0.parameters(),
                           lr=args.lr)  # Should implement lr scheduler.
    # optimizer = optim.SGD(lenet0.parameters(), lr=args.lr, momentum=0.9)

    if args.eval_only:
        loadsave(lenet0,
                 optimizer,
                 "LenetVAT",
                 root=args.weights_path[0],
                 mode='load')

    else:
        supervised_loss, unsupervised_loss = train(
            lenet0, optimizer, criterion, criterion_VAT, trainloader_SVHN,
            trainloader_MNIST, valloader, testloader_SVHN, args.alpha,
            args.epochs, device, args.weights_path[0])
        loss_plot(supervised_loss, unsupervised_loss)

    vat_acc, org_acc = evaluate_classifier(lenet0, testloader, testloader_SVHN,
                                           device)
    print(
        "Accuracy of the network on MNIST is %d%%\nAccuracy of the network on SVHN is %d%%\n"
        % (vat_acc * 100, org_acc * 100))

    barchartplot(lenet0, testloader, device)
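
# The loss_plot helper called in this main is defined elsewhere; it
# presumably reproduces the two-panel figure drawn inline in the earlier
# example. A sketch under that assumption:
def loss_plot(supervised_loss, unsupervised_loss):
    plt.subplot(2, 1, 1)
    plt.plot(supervised_loss)
    plt.title("Supervised loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(unsupervised_loss)
    plt.title("Unsupervised loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.grid(True)

    plt.tight_layout()
    plt.show()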