def train(self):
    switches = None
    for epoch in range(self.pdarts_epoch):
        # Grow the network: each PDARTS epoch trains a deeper model.
        layers = self.init_layers + self.pdarts_num_layers[epoch]
        model, criterion, optim, lr_scheduler = self.model_creator(layers)
        self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

        for callback in self.callbacks:
            callback.build(model, self.mutator, self)
            callback.on_epoch_begin(epoch)

        darts_callbacks = []
        if lr_scheduler is not None:
            darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

        # Run a standard DARTS search on the grown model.
        self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                    callbacks=darts_callbacks, **self.darts_parameters)
        logger.info("start pdarts training epoch %s...", epoch)
        self.trainer.train()

        # Drop the weakest candidate operations; the surviving switches seed the next epoch.
        switches = self.mutator.drop_paths()

        for callback in self.callbacks:
            callback.on_epoch_end(epoch)
                        momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

if args.v1:
    from nni.algorithms.nas.pytorch.darts import DartsTrainer
    trainer = DartsTrainer(model,
                           device=torch.device("cuda:{}".format(args.gpu)),
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=[
                               LRSchedulerCallback(lr_scheduler),
                               ArchitectureCheckpoint("./checkpoints")
                           ])
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
else:
    from nni.retiarii.oneshot.pytorch import DartsTrainer
    trainer = DartsTrainer(model=model,
                           loss=criterion,
model = DartsStackedCells(3, args.channels, 10, args.layers, DartsCell)
criterion = nn.CrossEntropyLoss()

optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

trainer = DartsTrainer(model,
                       loss=criterion,
                       metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                       optimizer=optim,
                       num_epochs=args.epochs,
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       batch_size=args.batch_size,
                       log_frequency=args.log_frequency,
                       unrolled=args.unrolled,
                       callbacks=[
                           LRSchedulerCallback(lr_scheduler),
                           ArchitectureCheckpoint("./checkpoints")
                       ])
if args.visualization:
    trainer.enable_visualization()
trainer.train()
model = DartsStackedCells(3, args.channels, 10, args.layers)
model = model.to(device)
trainset, testset = getDatasets()

if __name__ == '__main__':
    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    optimizer = optim.Adamax(model.parameters(), lr=0.025)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           criterion,
                           lambda outputs, labels: getAccuracy(outputs, labels),
                           optimizer,
                           args.epochs,
                           trainset,
                           testset,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=[
                               LRSchedulerCallback(lr_scheduler),
                               ArchitectureCheckpoint('./checkpoints')
                           ])
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
class PdartsTrainer(BaseTrainer):
    """
    This trainer implements the PDARTS algorithm.

    PDARTS is based on the DARTS algorithm and adds a network-growth approach to find
    deeper and better networks. The class relies on the pdarts_num_layers and
    pdarts_num_to_drop parameters to control how the network grows:
    pdarts_num_layers specifies how many layers are added on top of the first epoch's
    network, and pdarts_num_to_drop specifies how many candidate operations are dropped
    in each epoch, so that the grown network stays at a similar size.
    """

    def __init__(self, model_creator, init_layers, metrics, num_epochs, dataset_train, dataset_valid,
                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, unrolled=False):
        super(PdartsTrainer, self).__init__()
        self.model_creator = model_creator
        self.init_layers = init_layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled
        }
        self.callbacks = callbacks if callbacks is not None else []

    def train(self):
        switches = None
        for epoch in range(self.pdarts_epoch):
            layers = self.init_layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                        callbacks=darts_callbacks, **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)
            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        self.trainer.validate()

    def export(self, file):
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")
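For reference, a minimal usage sketch of PdartsTrainer follows. The model_creator contract, a callable that takes a layer count and returns (model, criterion, optimizer, lr_scheduler), is inferred from the train() method above; DartsStackedCells, DartsCell, accuracy, dataset_train and dataset_valid are borrowed from the surrounding examples, and all hyperparameter values are illustrative assumptions rather than values from the source.

# Hedged usage sketch for PdartsTrainer; hyperparameters are illustrative assumptions.
import torch
import torch.nn as nn

def model_creator(layers):
    # Contract inferred from PdartsTrainer.train():
    # return (model, criterion, optimizer, lr_scheduler) for a network with `layers` cells.
    model = DartsStackedCells(3, 16, 10, layers, DartsCell)
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50, eta_min=0.001)
    return model, criterion, optim, lr_scheduler

trainer = PdartsTrainer(model_creator,
                        init_layers=5,  # with pdarts_num_layers=[0, 6, 12] the epochs use 5, 11 and 17 layers
                        metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                        num_epochs=50,
                        dataset_train=dataset_train,
                        dataset_valid=dataset_valid,
                        pdarts_num_layers=[0, 6, 12],
                        pdarts_num_to_drop=[3, 2, 1],
                        batch_size=64,
                        log_frequency=10)
trainer.train()
trainer.export("final_architecture.json")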
# If classes are not equally distributed, give each class a weight.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight = torch.tensor([0.5, 4.0], dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=weight)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs, eta_min=0.001)

if args.train_mode == 'search':
    from nni.algorithms.nas.pytorch.darts import DartsTrainer
    from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: f1_(output, target),
                           optimizer=optimizer,
                           device=device,
                           batch_size=args.batch_size,
                           num_epochs=args.num_epochs,
                           dataset_train=train_dataset,
                           dataset_valid=test_dataset,
                           log_frequency=args.log_frequency)
    if args.visualization:
        trainer.enable_visualization()

    logger.info('Start to train with DARTS...')
    trainer.train()
    logger.info('Training done')

    trainer.export(file=args.arch_path)
    logger.info('Best architecture exported in %s', args.arch_path)
elif args.train_mode == 'retrain':
    from retrain import Retrain
    from nni.nas.pytorch.fixed import apply_fixed_architecture
if __name__ == "__main__": transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) dataset_valid = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) net = Net() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) trainer = DartsTrainer(net, loss=criterion, metrics=accuracy, optimizer=optimizer, num_epochs=2, dataset_train=dataset_train, dataset_valid=dataset_valid, batch_size=64, log_frequency=10) trainer.enable_visualization() trainer.train() trainer.export("checkpoint.json")