for seed in range(10):
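        # One model is trained per random seed (0-9), with the held-out test domain fixed.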
        args.seed = seed
        # Model name
        model_name = args.outpath + 'test_domain_' + str(args.list_test_domain[0]) + '_vae_seed_' + str(
            args.seed)
        print(model_name)

        # Set seed
        torch.manual_seed(args.seed)
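        # Disable cudnn autotuning so convolution algorithm selection does not vary across runs.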
        torch.backends.cudnn.benchmark = False
        np.random.seed(args.seed)

        # Load supervised training
        train_loader = data_utils.DataLoader(
            MnistRotated(args.list_train_domains, args.list_test_domain, args.num_supervised, args.seed, './../dataset/',
                         train=True),
            batch_size=args.batch_size,
            shuffle=True, **kwargs)

        # Load test
        test_loader = data_utils.DataLoader(
            MnistRotated(args.list_train_domains, args.list_test_domain, args.num_supervised, args.seed, './../dataset/',
                         train=False),
            batch_size=args.batch_size,
            shuffle=True, **kwargs)

        # setup the VAE
        model = VAE(args).to(device)

        # setup the optimizer
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
Example n. 2
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=500, metavar='N',
                        help='number of epochs to train (default: 500)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lamb', type=float, default=1.0, metavar='L',
                        help='weight for domain loss')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--list_train_domains', nargs='+', type=str, default=['0', '15', '45', '60'],
                        help='domains used during training')
    parser.add_argument('--list_test_domain', type=str, default='75',
                        help='domain used during testing')
    parser.add_argument('--num-supervised', default=1000, type=int,
                        help="total number of supervised examples (this value / 10 gives the per-class count)")

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")
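    # DataLoader extras: use one background worker and pinned host memory when CUDA is available.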
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    model = CNNModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

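    # The test domain is passed as a single string; wrap it in a list to match the
    # list-of-domains interface used by the data loaders.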
    args.list_test_domain = [args.list_test_domain]

    # Choose training domains
    all_training_domains = ['0', '15', '30', '45', '75']
    all_training_domains.remove(args.list_test_domain[0])
    args.list_train_domains = all_training_domains

    print(args.list_test_domain, args.list_train_domains)

    # Load supervised training
    train_loader = data_utils.DataLoader(
        MnistRotated(args.list_train_domains, args.list_test_domain, args.num_supervised, args.seed, './../../dataset/',
                     train=True),
        batch_size=args.batch_size,
        shuffle=True, **kwargs)

    print(len(train_loader.dataset))

    # Load test
    test_loader = data_utils.DataLoader(
        MnistRotated(args.list_train_domains, args.list_test_domain, args.num_supervised, args.seed, './../../dataset/',
                     train=False),
        batch_size=args.batch_size,
        shuffle=True, **kwargs)

    train(args, model, device, train_loader, test_loader, optimizer)
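

# A minimal, assumed entry point for running this example as a script:
if __name__ == '__main__':
    main()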
Example n. 3
    test_accuracy_y_list = []

    for i in range(10):
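        # Load each saved supervised-only model (one per seed) and evaluate it on the held-out domain.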
        model_name = 'test_domain_75_sup_only_seed_' + str(i) + '_add_30'
        model = torch.load(model_name + '.model')
        args = torch.load(model_name + '.config')

        args.cuda = not args.no_cuda and torch.cuda.is_available()
        device = torch.device("cuda" if args.cuda else "cpu")
        kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}

        # Load test
        test_loader_sup = data_utils.DataLoader(MnistRotated(
            args.list_train_domains,
            args.list_test_domain,
            args.num_supervised,
            args.seed,
            './../../../../dataset/',
            train=False),
                                                batch_size=args.batch_size,
                                                shuffle=True, **kwargs)

        # Set seed
        torch.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = False
        np.random.seed(args.seed)

        test_accuracy_d, test_accuracy_y = get_accuracy(
            test_loader_sup, model.classifier, args.batch_size)
        test_accuracy_y_list.append(test_accuracy_y)
Example n. 4
    print(args.list_test_domain, args.list_train_domains)

    # Set seed
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # Empty data loader dict
    data_loaders = {}

    # Load supervised training
    train_loader_sup = data_utils.DataLoader(MnistRotated(
        args.list_train_domains,
        args.list_test_domain,
        args.num_supervised,
        args.seed,
        './../dataset/',
        train=True),
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)

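    # Select another seed's data split (wrapping from seed 9 back to seed 0),
    # presumably used below as the additional, unsupervised training data.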
    if args.seed < 9:
        additional_data_index = args.seed + 1
    elif args.seed == 9:
        additional_data_index = 0

    # Load unsupervised training
    train_loader_unsup = data_utils.DataLoader(MnistRotated(
        args.list_train_domains,