Example #1
    def train(self):
        layers = self.layers
        switches = None
        for epoch in range(self.pdarts_epoch):

            # grow the super-network: add this stage's extra layers on top of the base depth
            layers = self.layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                        callbacks=darts_callbacks, **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)

            self.trainer.train()

            # prune the weakest candidate operations; the returned switches seed the next stage's mutator
            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)
Example #2
def main(args):
    reset_seed(args.seed)
    prepare_logger(args)

    logger.info("These are the hyper-parameters you want to tune:\n%s",
                pprint.pformat(vars(args)))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_loader, test_loader = data_preprocess(args)
    # model = models.__dict__[args.model](num_classes=10)
    model = CNN(32, 3, args.channels, 10, args.layers)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
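    # build the optimizer named on the command line; Adam takes no momentum,
    # while SGD and RMSprop share the momentum-based call signature below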
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'sgd':
            optimizer_cls = optim.SGD
        elif args.optimizer == 'rmsprop':
            optimizer_cls = optim.RMSprop
        optimizer = optimizer_cls(model.parameters(),
                                  lr=args.initial_lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)

    if args.lr_scheduler == 'cosin':  # note: 'cosin' (not 'cosine') is the spelling this script checks for
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min=args.ending_lr)
    elif args.lr_scheduler == 'linear':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=15,
                                              gamma=0.1)

    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target),
        optimizer=optimizer,
        num_epochs=args.epochs,
        dataset_train=train_loader,
        dataset_valid=test_loader,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[
            LRSchedulerCallback(scheduler),
            ArchitectureCheckpoint("./checkpoints_layer5")
        ])

    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
Example #3
class PdartsTrainer(BaseTrainer):

    def __init__(self, model_creator, layers, metrics,
                 num_epochs, dataset_train, dataset_valid,
                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
        super(PdartsTrainer, self).__init__()
        self.model_creator = model_creator
        self.layers = layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency
        }
        self.callbacks = callbacks if callbacks is not None else []

    def train(self):
        layers = self.layers
        switches = None
        for epoch in range(self.pdarts_epoch):

            layers = self.layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                        callbacks=darts_callbacks, **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)

            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        self.trainer.validate()

    def export(self, file):
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")
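For context, here is a hedged sketch (not from the repository) of the model_creator callable that PdartsTrainer above expects: train() unpacks its return value as (model, criterion, optim, lr_scheduler). CNN, accuracy, and the datasets helper are assumed to be the same ones used in the other examples on this page, and the hyper-parameter values are illustrative only.

import torch
import torch.nn as nn

def model_creator(layers):
    # the search CNN, loss, optimizer and scheduler that each PDARTS stage trains
    model = CNN(32, 3, 16, 10, layers)
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(model.parameters(), 0.025,
                            momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50, eta_min=0.001)
    return model, criterion, optim, lr_scheduler

dataset_train, dataset_valid = datasets.get_dataset("cifar10")
trainer = PdartsTrainer(model_creator, layers=5, metrics=accuracy,
                        num_epochs=50, dataset_train=dataset_train,
                        dataset_valid=dataset_valid)
trainer.train()
trainer.export("final_arch.json")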
Example #4
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = CNN(32, 3, 16, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # SGD hyper-parameters (lr 0.025, momentum 0.9, weight decay 3e-4) follow the original DARTS CIFAR-10 search setup
    optim = torch.optim.SGD(model.parameters(),
                            0.025,
                            momentum=0.9,
                            weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim,
                                                              args.epochs,
                                                              eta_min=0.001)

    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target, topk=(1, )),
        optimizer=optim,
        num_epochs=args.epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        callbacks=[
            LRSchedulerCallback(lr_scheduler),
            ArchitectureCheckpoint("./checkpoints")
        ])
    trainer.train()
Example #5
def main(args):
    reset_seed(args.seed)
    prepare_logger(args)

    logger.info("These are the hyper-parameters you want to tune:\n%s",
                pprint.pformat(vars(args)))

    if args.model == 'nas':
        logger.info("Using NAS.\n")
        if args.fix_arch:
            if not os.path.exists(args.arc_checkpoint):
                print(args.arc_checkpoint,
                      'does not exist, not fixing the architecture')
                args.fix_arch = False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # three paths below: NAS architecture search, retraining a previously exported (fixed) architecture, or an ordinary model looked up by name
    if args.model == 'nas':
        if not args.fix_arch:
            model = CNN(32, 3, args.channels, 10, args.layers)
            trainset, testset = data_preprocess(args)
        else:
            model = CNN(32, 3, args.channels, 10, args.layers)
            apply_fixed_architecture(model, args.arc_checkpoint)
            model.to(device)
            train_loader, test_loader = data_preprocess(args)
    else:
        train_loader, test_loader = data_preprocess(args)
        model = models.__dict__[args.model]()
        model.to(device)

    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.initial_lr,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'sgd':
            optimizer_cls = optim.SGD
        elif args.optimizer == 'rmsprop':
            optimizer_cls = optim.RMSprop
        optimizer = optimizer_cls(model.parameters(),
                                  lr=args.initial_lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     args.epochs,
                                                     eta_min=args.ending_lr)

    if args.model == 'nas' and not args.fix_arch:
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracyTopk(
                                   output, target, topk=(1, )),
                               optimizer=optimizer,
                               num_epochs=args.epochs,
                               dataset_train=trainset,
                               dataset_valid=testset,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[
                                   LRSchedulerCallback(scheduler),
                                   ArchitectureCheckpoint("./checkpoints")
                               ])
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
        trainer.export("final_arch.json")
    else:
        for epoch in range(1, args.epochs + 1):
            train(model, train_loader, criterion, optimizer, scheduler, args,
                  epoch, device)
            top1, _ = test(model, test_loader, criterion, args, epoch, device)
            nni.report_intermediate_result(top1)
        logger.info("Final accuracy is: %.6f", top1)
        nni.report_final_result(top1)
Example #6
if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset_train = torchvision.datasets.CIFAR10(root="./data",
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
    dataset_valid = torchvision.datasets.CIFAR10(root="./data",
                                                 train=False,
                                                 download=True,
                                                 transform=transform)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    trainer = DartsTrainer(net,
                           loss=criterion,
                           metrics=accuracy,
                           optimizer=optimizer,
                           num_epochs=2,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=64,
                           log_frequency=10)
    trainer.enable_visualization()
    trainer.train()
    trainer.export("checkpoint.json")
Example #7
File: trainer.py Project: zyzzu/nni
class PdartsTrainer(BaseTrainer):
    """
    This trainer implements the PDARTS algorithm.
    PDARTS bases on DARTS algorithm, and provides a network growth approach to find deeper and better network.
    This class relies on pdarts_num_layers and pdarts_num_to_drop parameters to control how network grows.
    pdarts_num_layers means how many layers more than first epoch.
    pdarts_num_to_drop means how many candidate operations should be dropped in each epoch.
        So that the grew network can in similar size.
    """
    def __init__(self,
                 model_creator,
                 init_layers,
                 metrics,
                 num_epochs,
                 dataset_train,
                 dataset_valid,
                 pdarts_num_layers=[0, 6, 12],
                 pdarts_num_to_drop=[3, 2, 1],
                 mutator=None,
                 batch_size=64,
                 workers=4,
                 device=None,
                 log_frequency=None,
                 callbacks=None,
                 unrolled=False):
        super(PdartsTrainer, self).__init__()
        self.model_creator = model_creator
        self.init_layers = init_layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled
        }
        self.callbacks = callbacks if callbacks is not None else []

    def train(self):

        switches = None
        for epoch in range(self.pdarts_epoch):

            layers = self.init_layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop,
                                         switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model,
                                        mutator=self.mutator,
                                        loss=criterion,
                                        optimizer=optim,
                                        callbacks=darts_callbacks,
                                        **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)

            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        self.trainer.validate()

    def export(self, file):
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export,
                      f,
                      indent=2,
                      sort_keys=True,
                      cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")
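As a worked illustration of the docstring above, the sketch below prints the schedule implied by the default arguments in this example. The starting depth of 5 and the pool of 8 candidate operations per edge (the standard DARTS CIFAR-10 search space) are not stated in the code and are assumptions for illustration only.

# Hedged sketch of the PDARTS growth/pruning schedule; starting depth and op count are assumed.
init_layers = 5                      # hypothetical starting depth
pdarts_num_layers = [0, 6, 12]       # extra layers added at each stage (default above)
pdarts_num_to_drop = [3, 2, 1]       # candidate ops pruned after each stage (default above)
ops_remaining = 8                    # assumed size of the DARTS operation pool

for stage, (extra, drop) in enumerate(zip(pdarts_num_layers, pdarts_num_to_drop)):
    layers = init_layers + extra
    print(f"stage {stage}: {layers}-layer super-net, {ops_remaining} candidate ops per edge")
    ops_remaining -= drop            # drop_paths() shrinks the switches used by the next stage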