def __init__(self, optimizer, config, lightweight_output=False):
    """
    Set up the trainer state.

    Args:
        optimizer: A NASLib optimizer. Its ``get_model_size()`` is called
            once here; the result is logged and recorded under
            ``errors_dict['params']``.
        config (AttrDict): The configuration loaded from a yaml file, e.g.
            via `utils.get_config_from_args()`. ``config.search.epochs`` is
            cached as ``self.epochs``.
        lightweight_output (bool): Stored as ``self.lightweight_output``;
            it is not read by this method itself.
    """
    self.optimizer = optimizer
    self.config = config
    self.epochs = config.search.epochs
    self.lightweight_output = lightweight_output

    # Use the first CUDA device when one is available, otherwise the CPU.
    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda:0" if use_cuda else "cpu")

    # Running averages (top-1, top-5, loss) for the training and validation splits.
    self.train_top1, self.train_top5, self.train_loss = (
        utils.AverageMeter() for _ in range(3)
    )
    self.val_top1, self.val_top5, self.val_loss = (
        utils.AverageMeter() for _ in range(3)
    )

    model_size = optimizer.get_model_size()
    logger.info("param size = %fMB", model_size)

    # One empty history list per tracked metric (filled elsewhere), plus the
    # scalar model size under 'params'. Key order matches the original literal.
    history_keys = (
        "train_acc",
        "train_loss",
        "valid_acc",
        "valid_loss",
        "test_acc",
        "test_loss",
        "runtime",
        "train_time",
        "arch_eval",
    )
    history = {key: [] for key in history_keys}
    history["params"] = model_size
    self.errors_dict = utils.AttrDict(history)
def evaluate(
    self,
    retrain=True,
    search_model="",
    resume_from="",
    best_arch=None,
):
    """
    Evaluate the final architecture as given from the optimizer.

    If the architecture is queryable (it has an interface to a benchmark),
    query its test accuracy. Otherwise optionally retrain it as defined in
    ``config.evaluation`` and then measure top-1/top-5 accuracy on the test
    loader.

    Args:
        retrain (bool): Reset the weights from the architecture search and
            retrain from scratch with drop path enabled.
        search_model (str): Path to checkpoint file that was created during
            search. If not provided, then try to load 'model_final.pth' from
            ``<config.save>/search``. Only used when ``best_arch`` is None.
        resume_from (str): Resume retraining from the given checkpoint file.
        best_arch: Parsed model you want to directly evaluate and ignore the
            final model from the optimizer.
    """
    logger.info("Start evaluation")
    # Explicit None check: a model object may define __len__ (e.g. a graph),
    # in which case truthiness would not reliably mean "was passed".
    if best_arch is None:
        if not search_model:
            search_model = os.path.join(self.config.save, "search", "model_final.pth")
        self._setup_checkpointers(search_model)  # required to load the architecture

        best_arch = self.optimizer.get_final_architecture()
    logger.info("Final architecture:\n" + best_arch.modules_str())

    if best_arch.QUERYABLE:
        metric = Metric.TEST_ACCURACY
        result = best_arch.query(metric=metric, dataset=self.config.dataset)
        logger.info("Queried results ({}): {}".format(metric, result))
    else:
        best_arch.to(self.device)
        if retrain:
            logger.info("Starting retraining from scratch")
            best_arch.reset_weights(inplace=True)

            self.train_queue, self.valid_queue, self.test_queue = self.build_eval_dataloaders(
                self.config
            )

            optim = self.build_eval_optimizer(best_arch.parameters(), self.config)
            scheduler = self.build_eval_scheduler(optim, self.config)

            start_epoch = self._setup_checkpointers(
                resume_from,
                search=False,
                period=self.config.evaluation.checkpoint_freq,
                model=best_arch,  # checkpointables start here
                optim=optim,
                scheduler=scheduler,
            )

            grad_clip = self.config.evaluation.grad_clip
            loss = torch.nn.CrossEntropyLoss()

            best_arch.train()
            self.train_top1.reset()
            self.train_top5.reset()
            self.val_top1.reset()
            self.val_top5.reset()

            # Enable drop path: wrap every op in the optimizer scope so its
            # drop probability can be updated per epoch below.
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set("op", DropPathWrapper(edge.data.op)),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True,
            )

            # train from scratch
            epochs = self.config.evaluation.epochs
            for e in range(start_epoch, epochs):
                if torch.cuda.is_available():
                    log_first_n(
                        logging.INFO,
                        "cuda consumption\n {}".format(torch.cuda.memory_summary()),
                        n=20,
                    )

                # Drop path probability grows linearly with the epoch index.
                drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
                best_arch.update_edges(
                    update_func=lambda edge: edge.data.set("drop_path_prob", drop_path_prob),
                    scope=best_arch.OPTIMIZER_SCOPE,
                    private_edge_data=True,
                )

                # Train queue
                for i, (input_train, target_train) in enumerate(self.train_queue):
                    input_train = input_train.to(self.device)
                    target_train = target_train.to(self.device, non_blocking=True)

                    optim.zero_grad()
                    logits_train = best_arch(input_train)
                    train_loss = loss(logits_train, target_train)
                    # darts specific stuff; the attribute name is spelled as on the model.
                    if hasattr(best_arch, "auxilary_logits"):
                        log_first_n(logging.INFO, "Auxiliary is used", n=10)
                        auxiliary_loss = loss(best_arch.auxilary_logits(), target_train)
                        train_loss += self.config.evaluation.auxiliary_weight * auxiliary_loss
                    train_loss.backward()
                    if grad_clip:
                        torch.nn.utils.clip_grad_norm_(best_arch.parameters(), grad_clip)
                    optim.step()

                    self._store_accuracies(logits_train, target_train, "train")
                    log_every_n_seconds(
                        logging.INFO,
                        "Epoch {}-{}, Train loss: {:.5}, learning rate: {}".format(
                            e, i, train_loss, scheduler.get_last_lr()
                        ),
                        n=5,
                    )

                # Validation queue: only logged, never used for gradient updates.
                if self.valid_queue:
                    # Switch to eval mode so layers with train-time behaviour
                    # (BatchNorm, dropout) behave as at inference, then restore.
                    best_arch.eval()
                    for input_valid, target_valid in self.valid_queue:
                        # Move batches like the training loop does (respects
                        # self.device instead of assuming CUDA; keeps integer targets).
                        input_valid = input_valid.to(self.device)
                        target_valid = target_valid.to(self.device, non_blocking=True)

                        # just log the validation accuracy
                        with torch.no_grad():
                            logits_valid = best_arch(input_valid)
                            self._store_accuracies(logits_valid, target_valid, "val")
                    best_arch.train()

                scheduler.step()
                self.periodic_checkpointer.step(e)
                self._log_and_reset_accuracies(e)

            # Disable drop path: unwrap the DropPathWrapper ops installed above.
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set("op", edge.data.op.get_embedded_ops()),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True,
            )
        elif getattr(self, "test_queue", None) is None:
            # Without retraining no eval loaders are built above, so the test
            # loader may not exist yet; build it instead of failing below.
            _, _, self.test_queue = self.build_eval_dataloaders(self.config)

        # measure final test accuracy
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()

        best_arch.eval()

        for i, (input_test, target_test) in enumerate(self.test_queue):
            input_test = input_test.to(self.device)
            target_test = target_test.to(self.device, non_blocking=True)

            n = input_test.size(0)

            with torch.no_grad():
                logits = best_arch(input_test)

                prec1, prec5 = utils.accuracy(logits, target_test, topk=(1, 5))
                top1.update(prec1.data.item(), n)
                top5.update(prec5.data.item(), n)

            log_every_n_seconds(
                logging.INFO,
                "Inference batch {} of {}.".format(i, len(self.test_queue)),
                n=5,
            )

        logger.info(
            "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}".format(
                top1.avg, top5.avg
            )
        )
def evaluate(
    self,
    retrain=True,
    search_model="",
    resume_from="",
    best_arch=None,
):
    """
    Evaluate the final architecture, optionally retraining it distributed.

    When ``retrain`` is set, the retraining itself is delegated to
    ``self.main_worker``. It is either called directly or launched once per
    visible GPU through ``mp.spawn``, depending on
    ``config.evaluation.multiprocessing_distributed``. Afterwards, if the
    trainer is not queryable, drop path is removed from ``best_arch`` and its
    top-1/top-5 accuracy is measured on ``self.test_queue``.

    Args:
        retrain (bool): Retrain the architecture via ``main_worker``.
        search_model (str): Path to a checkpoint file created during search;
            forwarded unchanged to ``main_worker``.
        resume_from (str): Resume retraining from the given checkpoint file;
            stored on ``config.evaluation.resume_from``.
        best_arch: Parsed model you want to directly evaluate and ignore the
            final model from the optimizer; forwarded to ``main_worker`` and
            used for the final test measurement.
    """
    self.config.evaluation.resume_from = resume_from
    if retrain:
        eval_cfg = self.config.evaluation

        if self.config.gpu is not None:
            logger.warning(
                "You have chosen a specific GPU. This will completely "
                "disable data parallelism."
            )

        if eval_cfg.dist_url == "env://" and eval_cfg.world_size == -1:
            eval_cfg.world_size = int(os.environ["WORLD_SIZE"])

        eval_cfg.distributed = eval_cfg.world_size > 1 or eval_cfg.multiprocessing_distributed
        ngpus_per_node = torch.cuda.device_count()

        if eval_cfg.multiprocessing_distributed:
            # Since we have ngpus_per_node processes per node, the
            # total world_size needs to be adjusted.
            eval_cfg.world_size = ngpus_per_node * eval_cfg.world_size
            # Use torch.multiprocessing.spawn to launch one main_worker
            # process per GPU. NOTE(review): spawned processes receive copies
            # of their arguments, so training there does not update the
            # best_arch / test_queue used below in this process — confirm
            # this is intended.
            mp.spawn(
                self.main_worker,
                nprocs=ngpus_per_node,
                args=(ngpus_per_node, eval_cfg, search_model, best_arch),
            )
        else:
            # Simply call main_worker in this process.
            self.main_worker(self.config.gpu, ngpus_per_node, eval_cfg, search_model, best_arch)

    # NOTE(review): QUERYABLE is read from the trainer here, not from
    # best_arch, and this method never sets it — confirm it is assigned
    # elsewhere (presumably in main_worker). Also, best_arch defaults to None,
    # which would fail at update_edges below.
    if not self.QUERYABLE:
        # Disable drop path: unwrap the wrapped ops back to the embedded ones.
        best_arch.update_edges(
            update_func=lambda edge: edge.data.set("op", edge.data.op.get_embedded_ops()),
            scope=best_arch.OPTIMIZER_SCOPE,
            private_edge_data=True,
        )

        # measure final test accuracy
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()

        best_arch.eval()

        for i, (input_test, target_test) in enumerate(self.test_queue):
            input_test = input_test.to(self.device)
            target_test = target_test.to(self.device, non_blocking=True)

            n = input_test.size(0)

            with torch.no_grad():
                logits = best_arch(input_test)

                prec1, prec5 = utils.accuracy(logits, target_test, topk=(1, 5))
                top1.update(prec1.data.item(), n)
                top5.update(prec5.data.item(), n)

            log_every_n_seconds(
                logging.INFO,
                "Inference batch {} of {}.".format(i, len(self.test_queue)),
                n=5,
            )

        logger.info(
            "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}".format(
                top1.avg, top5.avg
            )
        )