예제 #1
0
def main(args):
    """Run a DARTS architecture search over the ``CNN`` search space.

    Seeds the RNGs, configures logging, builds the data loaders and the
    ``CNN`` search model, creates the optimizer and learning-rate scheduler
    selected in ``args``, and trains with ``DartsTrainer``. Architecture
    checkpoints are written to ``./checkpoints_layer5``.

    Args:
        args: parsed command-line namespace. Fields read here: ``seed``,
            ``channels``, ``layers``, ``optimizer`` ('adam', 'sgd' or
            'rmsprop'), ``initial_lr``, ``weight_decay``, ``momentum``
            (sgd/rmsprop only), ``lr_scheduler`` ('cosin' or 'linear'),
            ``epochs``, ``ending_lr`` (cosine only), ``batch_size``,
            ``log_frequency``, ``unrolled`` and ``visualization``.

    Raises:
        ValueError: if ``args.optimizer`` or ``args.lr_scheduler`` is not one
            of the supported choices.
    """
    reset_seed(args.seed)
    prepare_logger(args)

    logger.info("These are the hyper-parameters you want to tune:\n%s",
                pprint.pformat(vars(args)))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_loader, test_loader = data_preprocess(args)
    model = CNN(32, 3, args.channels, 10, args.layers)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer in ('sgd', 'rmsprop'):
        optimizer_cls = optim.SGD if args.optimizer == 'sgd' else optim.RMSprop
        optimizer = optimizer_cls(model.parameters(),
                                  lr=args.initial_lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    else:
        # Without this branch an unknown name left `optimizer_cls` unbound
        # and failed later with an obscure UnboundLocalError.
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}; "
            "expected 'adam', 'sgd' or 'rmsprop'")

    if args.lr_scheduler == 'cosin':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min=args.ending_lr)
    elif args.lr_scheduler == 'linear':
        # NOTE: despite the option name this is a step decay
        # (lr *= 0.1 every 15 epochs), not a linear schedule.
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=15,
                                              gamma=0.1)
    else:
        raise ValueError(
            f"Unsupported lr_scheduler {args.lr_scheduler!r}; "
            "expected 'cosin' or 'linear'")

    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target),
        optimizer=optimizer,
        num_epochs=args.epochs,
        dataset_train=train_loader,
        dataset_valid=test_loader,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[
            LRSchedulerCallback(scheduler),
            ArchitectureCheckpoint("./checkpoints_layer5")
        ])

    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
 def generate_callbacks(self):
     """Create the callbacks used during the search phase.

     Builds, in this order, a ``CheckpointCallback`` that saves
     ``best_search.pth`` using the mode from
     ``self.cfg.callback.checkpoint.mode``, an ``ArchitectureCheckpoint``,
     and a ``RelevanceCallback`` whose file name comes from
     ``self.cfg.callback.relevance.filename``. All three write under
     ``self.cfg.logger.path``. Each callback is also stored on ``self`` as
     ``ckpt_callback``, ``arch_callback`` and ``relevance_callback``.

     Returns:
         list: ``[self.ckpt_callback, self.arch_callback,
         self.relevance_callback]``.
     """
     save_dir = self.cfg.logger.path

     self.ckpt_callback = CheckpointCallback(
         checkpoint_dir=save_dir,
         name='best_search.pth',
         mode=self.cfg.callback.checkpoint.mode,
     )
     self.arch_callback = ArchitectureCheckpoint(save_dir)
     self.relevance_callback = RelevanceCallback(
         save_path=save_dir,
         filename=self.cfg.callback.relevance.filename,
     )

     return [self.ckpt_callback, self.arch_callback, self.relevance_callback]
예제 #3
0
                                   tanh_constant=1.1,
                                   cell_exit_extra_step=True)
    else:
        raise AssertionError

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                0.05,
                                momentum=0.9,
                                weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=num_epochs,
                                                              eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optimizer,
                               callbacks=[
                                   LRSchedulerCallback(lr_scheduler),
                                   ArchitectureCheckpoint("./checkpoints")
                               ],
                               batch_size=args.batch_size,
                               num_epochs=num_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency,
                               mutator=mutator)
    trainer.train()
예제 #4
0
    criterion = nn.CrossEntropyLoss()

    if args.arch is not None:
        logger.info('model retraining...')
        with open(args.arch, 'r') as f:
            arch = json.load(f)
        for trial in query_nb201_trial_stats(arch, 200, 'cifar100'):
            pprint.pprint(trial)
        apply_fixed_architecture(model, args.arch)
        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
        dataloader_valid = DataLoader(dataset_valid, batch_size=args.batch_size, shuffle=True, num_workers=0)
        train(args, model, dataloader_train, dataloader_valid, criterion, optim,
              torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        exit(0)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                               reward_function=reward_accuracy,
                               optimizer=optim,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                               batch_size=args.batch_size,
                               num_epochs=args.epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency)

    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
예제 #5
0
    def model_creator(layers):
        """Build a fresh search model together with its training components.

        Args:
            layers: layer count passed through to ``CNN`` (channel and
                node counts come from ``args.channels`` / ``args.nodes``).

        Returns:
            tuple: ``(model, criterion, optimizer, lr_scheduler)``. The
            optimizer is SGD (lr 0.025, momentum 0.9, weight decay 3e-4).
            The scheduler is cosine annealing over ``args.epochs`` down to
            an lr of 0.001.
        """
        net = CNN(32, 3, args.channels, 10, layers, n_nodes=args.nodes)
        loss_fn = nn.CrossEntropyLoss()

        sgd = torch.optim.SGD(net.parameters(),
                              lr=0.025,
                              momentum=0.9,
                              weight_decay=3.0E-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            sgd, T_max=args.epochs, eta_min=0.001)

        return net, loss_fn, sgd, cosine

    logger.info("initializing trainer")
    trainer = PdartsTrainer(
        model_creator,
        init_layers=args.init_layers,
        metrics=lambda output, target: accuracy(output, target, topk=(1, )),
        pdarts_num_layers=args.add_layers,
        pdarts_num_to_drop=args.dropped_ops,
        num_epochs=args.epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[ArchitectureCheckpoint("./checkpoints")])
    logger.info("training")
    trainer.train()
예제 #6
0
def main(args):
    """Train the model selected by ``args``, optionally via DARTS search.

    Modes:

    * ``args.model == 'nas'`` without a usable fixed architecture: run a
      DARTS search with ``DartsTrainer``. Checkpoints go to
      ``./checkpoints`` and the result is exported to ``final_arch.json``.
    * ``args.model == 'nas'`` with ``args.fix_arch`` and an existing
      ``args.arc_checkpoint``: build the ``CNN``, apply the fixed
      architecture and train it with the local ``train``/``test`` loop.
    * any other ``args.model``: instantiate ``models.__dict__[args.model]()``
      and train it with the same loop.

    In the loop modes, each epoch's top-1 accuracy is reported to NNI as an
    intermediate result. The last epoch's value is reported as the final
    result.

    Args:
        args: parsed command-line namespace. It is mutated:
            ``args.fix_arch`` is reset to False when the checkpoint file
            does not exist.

    Raises:
        ValueError: if ``args.optimizer`` is not 'adam', 'sgd' or 'rmsprop'.
    """
    reset_seed(args.seed)
    prepare_logger(args)

    logger.info("These are the hyper-parameters you want to tune:\n%s",
                pprint.pformat(vars(args)))

    if args.model == 'nas':
        logger.info("Using NAS.\n")
        if args.fix_arch and not os.path.exists(args.arc_checkpoint):
            print(args.arc_checkpoint,
                  'does not exist, don not fix archetect')
            args.fix_arch = False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.model == 'nas':
        model = CNN(32, 3, args.channels, 10, args.layers)
        if args.fix_arch:
            apply_fixed_architecture(model, args.arc_checkpoint)
            model.to(device)
            train_loader, test_loader = data_preprocess(args)
        else:
            # The search path hands raw datasets to DartsTrainer.
            trainset, testset = data_preprocess(args)
    else:
        train_loader, test_loader = data_preprocess(args)
        model = models.__dict__[args.model]()
        model.to(device)

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer in ('sgd', 'rmsprop'):
        optimizer_cls = optim.SGD if args.optimizer == 'sgd' else optim.RMSprop
        optimizer = optimizer_cls(model.parameters(),
                                  lr=args.initial_lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    else:
        # Without this branch an unknown name left `optimizer_cls` unbound.
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}; "
            "expected 'adam', 'sgd' or 'rmsprop'")
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     args.epochs,
                                                     eta_min=args.ending_lr)

    if args.model == 'nas' and not args.fix_arch:
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracyTopk(
                                   output, target, topk=(1, )),
                               optimizer=optimizer,
                               num_epochs=args.epochs,
                               dataset_train=trainset,
                               dataset_valid=testset,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[
                                   LRSchedulerCallback(scheduler),
                                   ArchitectureCheckpoint("./checkpoints")
                               ])
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
        trainer.export("final_arch.json")
    else:
        top1 = None
        for epoch in range(1, args.epochs + 1):
            train(model, train_loader, criterion, optimizer, scheduler, args,
                  epoch, device)
            top1, _ = test(model, test_loader, criterion, args, epoch, device)
            nni.report_intermediate_result(top1)
        if top1 is None:
            # An empty loop (epochs < 1) would otherwise leave top1 unbound.
            logger.warning("No epochs were run (epochs=%s); nothing to report.",
                           args.epochs)
            return
        logger.info("Final accuracy is: %.6f", top1)
        nni.report_final_result(top1)
예제 #7
0
# Search model and datasets are built at import time (module-level globals).
model = DartsStackedCells(3, args.channels, 10, args.layers).to(device)
trainset, testset = getDatasets()

if __name__ == '__main__':
    # DARTS search: Adamax optimizer with a cosine-annealed learning rate.
    # (An SGD setup — lr 0.025, momentum 0.9, weight decay 3e-4 — was used
    # here at some point as an alternative.)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adamax(model.parameters(), lr=0.025)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=0.001)

    search_callbacks = [
        LRSchedulerCallback(lr_scheduler),
        ArchitectureCheckpoint('./checkpoints'),
    ]
    trainer = DartsTrainer(model, criterion,
                           lambda outputs, labels: getAccuracy(outputs, labels),
                           optimizer, args.epochs, trainset, testset,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=search_callbacks)

    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
예제 #8
0
파일: search.py 프로젝트: curiousjit/nni
    model = CNN(32, 3, args.channels, 6, args.layers)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(),
                            0.025,
                            momentum=0.9,
                            weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                              args.epochs,
                                                              eta_min=0.001)

    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target, topk=(1, )),
        optimizer=optim,
        num_epochs=args.epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[
            LRSchedulerCallback(lr_scheduler),
            ArchitectureCheckpoint("./checkpoints_aoi")
        ])
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
    trainer.export(file="final_architecture.json")
예제 #9
0
    if args.v1:
        from nni.algorithms.nas.pytorch.darts import DartsTrainer
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(
                                   output, target, topk=(1, )),
                               optimizer=optim,
                               num_epochs=args.epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[
                                   LRSchedulerCallback(lr_scheduler),
                                   ArchitectureCheckpoint(args.checkpath)
                               ],
                               workers=0)
        if args.visualization:
            trainer.enable_visualization()

        trainer.train()
    else:
        from nni.retiarii.oneshot.pytorch import DartsTrainer
        trainer = DartsTrainer(model=model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(
                                   output, target, topk=(1, )),
                               optimizer=optim,
                               num_epochs=args.epochs,
                               dataset=dataset_train,