Example #1
File: trainer.py  Project: kashankrm/NASLib
    def __init__(self, optimizer, config, lightweight_output=False):
        """
        Initializes the trainer.

        Args:
            optimizer: A NASLib optimizer
            config (AttrDict): The configuration loaded from a yaml file, e.g.
                via `utils.get_config_from_args()`
            lightweight_output (bool): If True, store a reduced set of outputs
                while training.
        """
        self.optimizer = optimizer
        self.config = config
        self.epochs = self.config.search.epochs
        self.lightweight_output = lightweight_output

        # preparations
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # measuring stuff
        self.train_top1 = utils.AverageMeter()
        self.train_top5 = utils.AverageMeter()
        self.train_loss = utils.AverageMeter()
        self.val_top1 = utils.AverageMeter()
        self.val_top5 = utils.AverageMeter()
        self.val_loss = utils.AverageMeter()

        n_parameters = optimizer.get_model_size()
        logger.info("param size = %fMB", n_parameters)
        self.errors_dict = utils.AttrDict(
            {'train_acc': [],
             'train_loss': [],
             'valid_acc': [],
             'valid_loss': [],
             'test_acc': [],
             'test_loss': [],
             'runtime': [],
             'train_time': [],
             'arch_eval': [],
             'params': n_parameters}
        )
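
A minimal usage sketch for this constructor. The imports and the `adapt_search_space` call below follow NASLib's public API, but the exact module paths are assumptions and may differ between forks:

# Hypothetical driver script; import paths assume NASLib's public API
# and may differ between forks of the project.
from naslib.optimizers import DARTSOptimizer
from naslib.search_spaces import DartsSearchSpace
from naslib.utils import utils

config = utils.get_config_from_args()        # must provide config.search.epochs etc.
search_space = DartsSearchSpace()

optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)   # attach the space before training

trainer = Trainer(optimizer, config, lightweight_output=True)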
Example #2
File: trainer.py  Project: jackyvan/NASLib
    def evaluate(
        self,
        retrain=True,
        search_model="",
        resume_from="",
        best_arch=None,
    ):
        """
        Evaluate the final architecture as given from the optimizer.

        If the search space has an interface to a benchmark then query that.
        Otherwise train as defined in the config.

        Args:
            retrain (bool): Reset the weights from the architecture search
            search_model (str): Path to checkpoint file that was created during
                search. If not provided, then try to load 'model_final.pth' from search
            resume_from (str): Resume retraining from the given checkpoint file.
            best_arch: Parsed model you want to directly evaluate and ignore the final model
                from the optimizer.
        """
        logger.info("Start evaluation")
        if not best_arch:
            if not search_model:
                search_model = os.path.join(self.config.save, "search",
                                            "model_final.pth")
            # required to load the architecture
            self._setup_checkpointers(search_model)

            best_arch = self.optimizer.get_final_architecture()
        logger.info("Final architecture:\n" + best_arch.modules_str())

        if best_arch.QUERYABLE:
            metric = Metric.TEST_ACCURACY
            result = best_arch.query(metric=metric,
                                     dataset=self.config.dataset)
            logger.info("Queried results ({}): {}".format(metric, result))
        else:
            best_arch.to(self.device)
            if retrain:
                logger.info("Starting retraining from scratch")
                best_arch.reset_weights(inplace=True)

                self.train_queue, self.valid_queue, self.test_queue = self.build_eval_dataloaders(
                    self.config)

                optim = self.build_eval_optimizer(best_arch.parameters(),
                                                  self.config)
                scheduler = self.build_eval_scheduler(optim, self.config)

                start_epoch = self._setup_checkpointers(
                    resume_from,
                    search=False,
                    period=self.config.evaluation.checkpoint_freq,
                    model=best_arch,  # checkpointables start here
                    optim=optim,
                    scheduler=scheduler)

                grad_clip = self.config.evaluation.grad_clip
                loss = torch.nn.CrossEntropyLoss()

                best_arch.train()
                self.train_top1.reset()
                self.train_top5.reset()
                self.val_top1.reset()
                self.val_top5.reset()

                # Enable drop path
                best_arch.update_edges(update_func=lambda edge: edge.data.set(
                    'op', DropPathWrapper(edge.data.op)),
                                       scope=best_arch.OPTIMIZER_SCOPE,
                                       private_edge_data=True)

                # train from scratch
                epochs = self.config.evaluation.epochs
                for e in range(start_epoch, epochs):
                    if torch.cuda.is_available():
                        log_first_n(logging.INFO,
                                    "cuda consumption\n {}".format(
                                        torch.cuda.memory_summary()),
                                    n=20)

                    # update drop path probability
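                    # (the probability ramps linearly from 0 toward the
                    # configured value as e approaches epochs)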
                    drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
                    best_arch.update_edges(
                        update_func=lambda edge: edge.data.set(
                            'drop_path_prob', drop_path_prob),
                        scope=best_arch.OPTIMIZER_SCOPE,
                        private_edge_data=True)

                    # Train queue
                    for i, (input_train,
                            target_train) in enumerate(self.train_queue):
                        input_train = input_train.to(self.device)
                        target_train = target_train.to(self.device,
                                                       non_blocking=True)

                        optim.zero_grad()
                        logits_train = best_arch(input_train)
                        train_loss = loss(logits_train, target_train)
                        if hasattr(best_arch,
                                   'auxilary_logits'):  # DARTS-specific auxiliary head
                            log_first_n(logging.INFO,
                                        "Auxiliary is used",
                                        n=10)
                            auxiliary_loss = loss(best_arch.auxilary_logits(),
                                                  target_train)
                            train_loss += self.config.evaluation.auxiliary_weight * auxiliary_loss
                        train_loss.backward()
                        if grad_clip:
                            torch.nn.utils.clip_grad_norm_(
                                best_arch.parameters(), grad_clip)
                        optim.step()

                        self._store_accuracies(logits_train, target_train,
                                               'train')
                        log_every_n_seconds(
                            logging.INFO,
                            "Epoch {}-{}, Train loss: {:.5}, learning rate: {}"
                            .format(e, i, train_loss, scheduler.get_last_lr()),
                            n=5)

                    # Validation queue
                    if self.valid_queue:
                        for i, (input_valid,
                                target_valid) in enumerate(self.valid_queue):

                            input_valid = input_valid.to(self.device).float()
                            target_valid = target_valid.to(self.device,
                                                           non_blocking=True)

                            # just log the validation accuracy
                            with torch.no_grad():
                                logits_valid = best_arch(input_valid)
                                self._store_accuracies(logits_valid,
                                                       target_valid, 'val')

                    scheduler.step()
                    self.periodic_checkpointer.step(e)
                    self._log_and_reset_accuracies(e)

            # Disable drop path
            best_arch.update_edges(update_func=lambda edge: edge.data.set(
                'op', edge.data.op.get_embedded_ops()),
                                   scope=best_arch.OPTIMIZER_SCOPE,
                                   private_edge_data=True)

            # measure final test accuracy
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()

            best_arch.eval()

            for i, data_test in enumerate(self.test_queue):
                input_test, target_test = data_test
                input_test = input_test.to(self.device)
                target_test = target_test.to(self.device, non_blocking=True)

                n = input_test.size(0)

                with torch.no_grad():
                    logits = best_arch(input_test)

                    prec1, prec5 = utils.accuracy(logits,
                                                  target_test,
                                                  topk=(1, 5))
                    top1.update(prec1.data.item(), n)
                    top5.update(prec5.data.item(), n)

                log_every_n_seconds(logging.INFO,
                                    "Inference batch {} of {}.".format(
                                        i, len(self.test_queue)),
                                    n=5)

            logger.info(
                "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}"
                .format(top1.avg, top5.avg))
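
For context, a hedged sketch of how this `evaluate` method is typically driven after a search run; the resume path below is hypothetical:

# Hypothetical driver; the checkpoint path is illustrative only.
trainer = Trainer(optimizer, config)
trainer.search()                  # writes search/model_final.pth under config.save
trainer.evaluate(retrain=True)    # queries a benchmark if QUERYABLE, else retrains

# Resuming an interrupted retraining run from a saved checkpoint:
trainer.evaluate(retrain=True, resume_from="path/to/eval_checkpoint.pth")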
Example #3
    def evaluate(
            self,
            retrain=True,
            search_model="",
            resume_from="",
            best_arch=None,
    ):
        """
        Evaluate the final architecture as given from the optimizer.

        If the search space has an interface to a benchmark then query that.
        Otherwise train as defined in the config.

        Args:
            retrain (bool): Reset the weights from the architecture search
            search_model (str): Path to checkpoint file that was created during
                search. If not provided, then try to load 'model_final.pth' from search
            resume_from (str): Resume retraining from the given checkpoint file.
            best_arch: Parsed model you want to directly evaluate and ignore the final model
                from the optimizer.
        """

        #best_arch.to(self.device)
        self.config.evaluation.resume_from = resume_from
        if retrain:
            if self.config.gpu is not None:
                logger.warning(
                    'You have chosen a specific GPU. This will completely '
                    'disable data parallelism.'
                )

            if self.config.evaluation.dist_url == "env://" and self.config.evaluation.world_size == -1:
                self.config.evaluation.world_size = int(os.environ["WORLD_SIZE"])

            self.config.evaluation.distributed = \
                self.config.evaluation.world_size > 1 or self.config.evaluation.multiprocessing_distributed
            ngpus_per_node = torch.cuda.device_count()

            if self.config.evaluation.multiprocessing_distributed:
                # Since we have ngpus_per_node processes per node, the
                # total world_size needs to be adjusted
                self.config.evaluation.world_size = ngpus_per_node * self.config.evaluation.world_size
                # Use torch.multiprocessing.spawn to launch distributed
                # processes: the main_worker process function
                mp.spawn(self.main_worker, nprocs=ngpus_per_node,
                         args=(ngpus_per_node, self.config.evaluation,
                               search_model, best_arch))
            else:
                # Simply call main_worker function
                self.main_worker(self.config.gpu, ngpus_per_node,
                                 self.config.evaluation,
                                 search_model, best_arch)

        if not best_arch.QUERYABLE:
            # Disable drop path
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set('op', edge.data.op.get_embedded_ops()),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True
            )

            # measure final test accuracy
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()

            best_arch.eval()

            for i, data_test in enumerate(self.test_queue):
                input_test, target_test = data_test
                input_test = input_test.to(self.device)
                target_test = target_test.to(self.device, non_blocking=True)

                n = input_test.size(0)

                with torch.no_grad():
                    logits = best_arch(input_test)

                    prec1, prec5 = utils.accuracy(logits, target_test, topk=(1, 5))
                    top1.update(prec1.data.item(), n)
                    top5.update(prec5.data.item(), n)

                log_every_n_seconds(
                    logging.INFO,
                    "Inference batch {} of {}.".format(
                        i, len(self.test_queue)
                    ), n=5
                )

            logger.info("Evaluation finished. Test accuracies: top-1 = {:.5}, \
                        top-5 = {:.5}".format(top1.avg, top5.avg))
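
The distributed branch above uses PyTorch's standard one-process-per-GPU launch. A self-contained sketch of that `mp.spawn` pattern follows; the worker body is illustrative, not NASLib code:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main_worker(gpu, ngpus_per_node, world_size):
    # spawn() prepends the process index, so `gpu` doubles as the rank
    # in this single-node illustration.
    dist.init_process_group(backend="nccl", init_method="env://",
                            world_size=world_size, rank=gpu)
    torch.cuda.set_device(gpu)
    # ... build the model, wrap it in DistributedDataParallel, evaluate ...
    dist.destroy_process_group()


if __name__ == "__main__":
    # init_method="env://" reads these variables, mirroring the
    # dist_url == "env://" check in the trainer above.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    ngpus = torch.cuda.device_count()
    mp.spawn(main_worker, nprocs=ngpus, args=(ngpus, ngpus))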