def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='total batch size for training (default: 100)')
    parser.add_argument('--labeled-batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='labeled input batch size (default: 50)')
    parser.add_argument('--n-labeled',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of labelled examples (default: 50)')
    parser.add_argument('--alpha',
                        type=float,
                        default=1,
                        metavar='ALPHA',
                        help='Hyperparameter for the loss (default: 1)')
    parser.add_argument(
        '--beta',
        type=float,
        default=1,
        metavar='BETA',
        help='Hyperparameter for the distance to weight function (default: 1.0)'
    )
    parser.add_argument('--pure',
                        default=False,
                        type=str2bool,
                        metavar='BOOL',
                        help='Whether the unlabelled data is pure')
    parser.add_argument('--weights',
                        default='none',
                        choices=['encoding', 'raw', 'none'],
                        type=str,
                        metavar='S',
                        help='What weights to use.')
    parser.add_argument('--encoder',
                        default=None,
                        type=str,
                        metavar='S',
                        help='File name for the pretrained autoencoder.')
    parser.add_argument('--dim',
                        type=int,
                        default=20,
                        metavar='N',
                        help='The dimension of the encoding.')
    parser.add_argument('--output',
                        default='default_output.csv',
                        type=str,
                        metavar='S',
                        help='File name for the output.')
    parser.add_argument(
        '--exclude-unlabeled',
        default=False,
        type=str2bool,
        metavar='BOOL',
        help='exclude unlabeled examples from the training set')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        metavar='GAMMA',
                        help='Gamma for learning rate decay (default: 0.99)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        metavar='S',
                        help='random seed (default: 0)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--runs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='Number of runs')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    torch.manual_seed(args.seed)  # set seed for pytorch
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Only the first 5 classes are kept; the rest are treated as unanticipated classes.
    wanted_classes = {0, 1, 2, 3, 4}
    args.num_classes = len(wanted_classes)

    folder = os.path.expanduser('./mnist_results')
    os.makedirs(folder, exist_ok=True)  # tolerate an existing results directory
    output_path = os.path.join(folder, args.output)

    # Each run uses a fresh seed for creating the labelled/unlabelled training split.
    for seed in range(args.runs):

        device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

        train_dataset = mnist.MNIST('../mnist',
                                    dataset='train',
                                    weights=args.weights,
                                    encoder=args.encoder,
                                    n_labeled=args.n_labeled,
                                    wanted_classes=wanted_classes,
                                    pure=args.pure,
                                    download=True,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.1307, ),
                                                             (0.3081, ))
                                    ]),
                                    seed=seed,
                                    alpha=args.beta,  # the dataset's weighting alpha comes from --beta
                                    func='exp',
                                    dim=args.dim)

        if args.exclude_unlabeled:
            sampler = SubsetRandomSampler(range(args.n_labeled))
            batch_sampler = BatchSampler(sampler,
                                         args.batch_size,
                                         drop_last=False)
        else:
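            # TwoStreamBatchSampler is not defined in this listing; presumably (as in
            # the mean-teacher reference code) each batch mixes the two index pools:
            # unlabelled indices fill batch_size - labeled_batch_size slots and
            # labelled indices fill the remaining labeled_batch_size slots.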
            batch_sampler = data.TwoStreamBatchSampler(
                range(args.n_labeled,
                      len(train_dataset)), range(args.n_labeled),
                args.batch_size, args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_sampler=batch_sampler,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            mnist.MNIST('../mnist',
                        dataset='test',
                        wanted_classes=wanted_classes,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307, ), (0.3081, ))
                        ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

        model = Net(args.num_classes).to(device)
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=args.gamma)
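        # ExponentialLR multiplies the learning rate by gamma on every
        # scheduler.step(), i.e. once per epoch here.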
        for epoch in range(1, args.epochs + 1):
            # Ramp alpha up over the course of training (the constants 2 and 30
            # presumably bound the ramp-up window in epochs).
            alpha = utils.alpha_ramp_up(args.alpha, epoch, 2, 30)
            if not args.exclude_unlabeled:
                # Re-assign pseudo-labels to the unlabelled examples before each epoch.
                pseudo_label.assign_labels(
                    args, model, device, train_dataset,
                    range(args.n_labeled, len(train_dataset)))

            start = time.time()
            pseudo_label.train(args, model, device, train_loader, optimizer,
                               epoch, alpha)
            print('\nTraining one epoch took: {:.4f} seconds.\n'.format(
                time.time() - start))
            accuracy = pseudo_label.test(args, model, device, test_loader)
            lr_scheduler.step()

        with open(output_path, 'a', newline='') as write_file:
            writer = csv.writer(write_file)
            writer.writerow([seed, accuracy])

        if args.save_model:
            torch.save(model.state_dict(), "mnist_cnn.pt")
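

# `str2bool` is used above but not defined in this listing. A minimal sketch of
# the usual argparse helper (the exact accepted spellings are an assumption):
def str2bool(v):
    """Parse a command-line string such as 'true'/'false' into a bool."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')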
Example #2
def main():
    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='total batch size for training (default: 100)')
    parser.add_argument('--labeled-batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='labeled input batch size (default: 50)')
    parser.add_argument('--n-labeled',
                        type=int,
                        default=3000,
                        metavar='N',
                        help='number of labelled examples (default: 3000)')
    parser.add_argument('--alpha',
                        type=float,
                        default=0,
                        metavar='ALPHA',
                        help='Hyperparameter for the loss (default: 0)')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=0.999,
                        metavar='DECAY',
                        help='EMA model decay rate (default: 0.999)')
    parser.add_argument(
        '--beta',
        type=float,
        default=1,
        metavar='BETA',
        help='Hyperparameter for the distance to weight function (default: 1.0)'
    )
    parser.add_argument('--pure',
                        default=False,
                        type=str2bool,
                        metavar='BOOL',
                        help='Whether the unlabelled data is pure')
    parser.add_argument('--weights',
                        default='none',
                        choices=['encoding', 'raw', 'none'],
                        type=str,
                        metavar='S',
                        help='What weights to use.')
    parser.add_argument('--consistency-type',
                        default='mse',
                        choices=['mse', 'kl'],
                        type=str,
                        metavar='S',
                        help='Type of consistency loss function to use.')
    parser.add_argument('--encoder',
                        default=None,
                        type=str,
                        metavar='S',
                        help='File name for the pretrained autoencoder.')
    parser.add_argument('--output',
                        default='default_output.csv',
                        type=str,
                        metavar='S',
                        help='File name for the output.')
    parser.add_argument('--rampup-period',
                        type=int,
                        default=10,
                        metavar='N',
                        help='The length of the ramp up period for alpha.')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        metavar='GAMMA',
                        help='Gamma for learning rate decay (default: 0.99)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        metavar='S',
                        help='random seed (default: 0)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--runs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='Number of runs')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--validation',
                        default=False,
                        type=str2bool,
                        metavar='BOOL',
                        help='Use the validation set instead of the test set')

    args = parser.parse_args()
    torch.manual_seed(args.seed)  # set seed for pytorch
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    validation_test = 'validation' if args.validation else 'test'
    # Only the first 5 classes are kept; the rest are treated as unanticipated classes.
    wanted_classes = {0, 1, 2, 3, 4}
    args.num_classes = len(wanted_classes)

    folder = os.path.expanduser('./meanteacher_cifar10_results')
    os.makedirs(folder, exist_ok=True)  # tolerate an existing results directory
    output_path = os.path.join(folder, args.output)

    # Each run uses a fresh seed for creating the labelled/unlabelled training split.
    for seed in range(args.runs):

        device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_dataset = cifar10.CIFAR10('../cifar10',
                                        dataset='train',
                                        weights=args.weights,
                                        encoder=args.encoder,
                                        n_labeled=args.n_labeled,
                                        wanted_classes=wanted_classes,
                                        pure=args.pure,
                                        download=True,
                                        transform=transforms.Compose(
                                            [transforms.ToTensor()]),
                                        seed=seed,
                                        alpha=args.beta,
                                        func='exp')

        batch_sampler = data.TwoStreamBatchSampler(
            range(args.n_labeled, len(train_dataset)), range(args.n_labeled),
            args.batch_size, args.labeled_batch_size)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_sampler=batch_sampler,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            cifar10.CIFAR10('../cifar10',
                            dataset=validation_test,
                            wanted_classes=wanted_classes,
                            transform=transforms.Compose(
                                [transforms.ToTensor()])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

        model = Cifar10CNN(args.num_classes).to(device)
        ema_model = Cifar10CNN(args.num_classes).to(device)
        ema_model = mean_teacher.detach_model_params(
            ema_model)  # stop backprop for the ema_model (teacher model)
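        # The teacher's weights are presumably updated inside mean_teacher.train as
        # an exponential moving average of the student's, controlled by --ema-decay.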

        # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=args.gamma)
        global_step = 0

        for epoch in range(1, args.epochs + 1):
            alpha = mean_teacher_alpha_rampup(args.alpha, epoch,
                                              args.rampup_period)
            start = time.time()
            global_step = mean_teacher.train(args, model, ema_model, device,
                                             train_loader, optimizer, epoch,
                                             alpha, global_step)
            print('\nTraining one epoch took: {:.4f} seconds.\n'.format(
                time.time() - start))
            accuracy = mean_teacher.test(args, model, device, test_loader)
            lr_scheduler.step()

        with open(output_path, 'a', newline='') as write_file:
            writer = csv.writer(write_file)
            writer.writerow([seed, accuracy])

        if args.save_model:
            torch.save(model.state_dict(), "meanteacher_cifar10_model.pt")
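

# `mean_teacher_alpha_rampup` is not shown in this listing. A plausible sketch,
# assuming a linear ramp-up of the consistency weight alpha over the first
# `rampup_period` epochs (the Mean Teacher paper uses a sigmoid schedule, so the
# real function may differ):
def mean_teacher_alpha_rampup(alpha_max, epoch, rampup_period):
    """Scale alpha from 0 to alpha_max linearly over rampup_period epochs."""
    if epoch >= rampup_period:
        return alpha_max
    return alpha_max * epoch / rampup_period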