Esempio n. 1
0
if __name__ == "__main__":
    # Preprocessing shared by both splits: tensor conversion followed by
    # Normalize, which receives the same (0.5, 0.5, 0.5) triple for both of
    # its positional arguments.
    half_per_channel = (0.5, 0.5, 0.5)
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(half_per_channel, half_per_channel)])

    # Both splits use the same CIFAR10 settings; only the `train` flag differs.
    cifar_options = dict(root="./data", download=True, transform=transform)
    dataset_train = torchvision.datasets.CIFAR10(train=True, **cifar_options)
    dataset_valid = torchvision.datasets.CIFAR10(train=False, **cifar_options)

    # Network, loss and optimizer handed to the trainer.
    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Collect the trainer settings in one place before constructing it.
    trainer_options = {
        "loss": criterion,
        "metrics": accuracy,
        "optimizer": optimizer,
        "num_epochs": 2,
        "dataset_train": dataset_train,
        "dataset_valid": dataset_valid,
        "batch_size": 64,
        "log_frequency": 10,
    }
    trainer = DartsTrainer(net, **trainer_options)

    # Turn on visualization, run the training, then export to a JSON file.
    trainer.enable_visualization()
    trainer.train()
    trainer.export("checkpoint.json")
Esempio n. 2
0
def main(args):
    """Run one job: architecture search with DartsTrainer or plain training.

    Depending on ``args.model`` and ``args.fix_arch`` this either

    * runs ``DartsTrainer`` on a ``CNN`` (``args.model == 'nas'`` and no
      fixed architecture) and exports the result to ``final_arch.json``, or
    * trains the model epoch by epoch with ``train``/``test`` and reports
      the per-epoch and final ``top1`` value through ``nni``.

    Args:
        args: Namespace-like object of hyper-parameters. Attributes read
            here: ``seed``, ``model``, ``fix_arch``, ``arc_checkpoint``,
            ``channels``, ``layers``, ``optimizer``, ``initial_lr``,
            ``momentum`` (SGD/RMSprop only), ``weight_decay``, ``epochs``,
            ``ending_lr``, ``batch_size``, ``log_frequency``, ``unrolled``
            and ``visualization``. ``args.fix_arch`` is reset to ``False``
            in place when ``args.arc_checkpoint`` does not exist.

    Raises:
        ValueError: If ``args.optimizer`` is not ``'adam'``, ``'sgd'`` or
            ``'rmsprop'``, or if plain training is requested with
            ``args.epochs < 1``.
    """
    reset_seed(args.seed)
    prepare_logger(args)

    logger.info("These are the hyper-parameters you want to tune:\n%s",
                pprint.pformat(vars(args)))

    use_nas = args.model == 'nas'
    if use_nas:
        logger.info("Using NAS.\n")
        if args.fix_arch and not os.path.exists(args.arc_checkpoint):
            # No checkpoint file to load, so fall back to searching.
            print(args.arc_checkpoint,
                  'does not exist, don not fix archetect')
            args.fix_arch = False
    # True only when an architecture still has to be searched.
    search_arch = use_nas and not args.fix_arch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The relative order of model construction and data_preprocess is kept
    # per branch, in case either one consumes the seeded random state.
    if use_nas:
        model = CNN(32, 3, args.channels, 10, args.layers)
        if search_arch:
            # The search branch hands these two objects to DartsTrainer as
            # dataset_train / dataset_valid.
            trainset, testset = data_preprocess(args)
        else:
            apply_fixed_architecture(model, args.arc_checkpoint)
            model.to(device)
            train_loader, test_loader = data_preprocess(args)
    else:
        train_loader, test_loader = data_preprocess(args)
        model = models.__dict__[args.model]()
        model.to(device)

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'adam':
        # Adam is built without the momentum argument.
        optimizer = optim.Adam(model.parameters(),
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    else:
        momentum_optimizers = {'sgd': optim.SGD, 'rmsprop': optim.RMSprop}
        optimizer_cls = momentum_optimizers.get(args.optimizer)
        if optimizer_cls is None:
            # Previously an unknown name left `optimizer_cls` unbound and
            # failed below with an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'adam', 'sgd' or "
                "'rmsprop'.".format(args.optimizer))
        optimizer = optimizer_cls(model.parameters(),
                                  lr=args.initial_lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     args.epochs,
                                                     eta_min=args.ending_lr)

    if search_arch:
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracyTopk(
                                   output, target, topk=(1, )),
                               optimizer=optimizer,
                               num_epochs=args.epochs,
                               dataset_train=trainset,
                               dataset_valid=testset,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[
                                   LRSchedulerCallback(scheduler),
                                   ArchitectureCheckpoint("./checkpoints")
                               ])
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
        trainer.export("final_arch.json")
    else:
        if args.epochs < 1:
            # With no epoch `top1` is never assigned, so the final report
            # below would fail with an UnboundLocalError.
            raise ValueError(
                "args.epochs must be at least 1 for training, got "
                "{!r}.".format(args.epochs))
        for epoch in range(1, args.epochs + 1):
            train(model, train_loader, criterion, optimizer, scheduler, args,
                  epoch, device)
            top1, _ = test(model, test_loader, criterion, args, epoch, device)
            nni.report_intermediate_result(top1)
        logger.info("Final accuracy is: %.6f", top1)
        nni.report_final_result(top1)
Esempio n. 3
0
    # Build the network that DartsTrainer will train.
    # NOTE(review): the positional arguments presumably mean input size 32,
    # 3 input channels and 6 output classes — confirm against CNN's signature.
    model = CNN(32, 3, args.channels, 6, args.layers)
    criterion = nn.CrossEntropyLoss()

    # SGD over the model's parameters with fixed hyper-parameters
    # (second positional argument 0.025, momentum 0.9, weight decay 3e-4).
    # NOTE(review): if this file also imports a module-level `optim` alias,
    # this local name hides it for the rest of the function — confirm.
    optim = torch.optim.SGD(model.parameters(),
                            0.025,
                            momentum=0.9,
                            weight_decay=3.0E-4)
    # Cosine-annealing scheduler for that optimizer, given args.epochs as its
    # second argument and eta_min=0.001.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                              args.epochs,
                                                              eta_min=0.001)

    # Wire model, loss, optimizer, datasets and callbacks into the trainer.
    # The metric is `accuracy` called with topk=(1, ); the callbacks wrap the
    # scheduler above and an ArchitectureCheckpoint on "./checkpoints_aoi".
    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target, topk=(1, )),
        optimizer=optim,
        num_epochs=args.epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[
            LRSchedulerCallback(lr_scheduler),
            ArchitectureCheckpoint("./checkpoints_aoi")
        ])
    # Visualization is enabled only when requested through args.
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
    # Export the trainer's result to a JSON file.
    trainer.export(file="final_architecture.json")