Example #1
    def infer(self, graph, criterion, valid_queue, *args, **kwargs):
        try:
            config = kwargs.get('config', graph.config)
            device = kwargs['device']
        except (KeyError, AttributeError) as e:
            # A bare `raise ('...')` is a TypeError in Python 3; raise a real
            # exception instead.
            raise ValueError(
                'No configuration specified in graph or kwargs') from e

        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        graph.eval()

        with torch.no_grad():
            for step, (input, target) in enumerate(valid_queue):
                input = input.to(device)
                target = target.to(device, non_blocking=True)
                # logits, _ = graph(input)
                logits = graph(input)
                loss = criterion(logits, target)

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = input.size(0)
                objs.update(loss.data.item(), n)
                top1.update(prec1.data.item(), n)
                top5.update(prec5.data.item(), n)

                if step % config.report_freq == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)

        return top1.avg, objs.avg
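
All of the examples on this page lean on two small helpers, utils.AvgrageMeter and utils.accuracy, whose definitions are not shown. The sketch below is a minimal reconstruction in the style of the DARTS utilities, assuming only the update(val, n) / .avg interface and the percent-valued top-k precision that the examples rely on:

class AvgrageMeter:
    """Running weighted average, e.g. of a loss or an accuracy."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg, self.sum, self.cnt = 0.0, 0.0, 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Top-k precision in percent for each k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

Because update(val, n) weights each value by the batch size, .avg is the true dataset-level mean rather than a mean of per-batch means.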
Example #2
    def train(self, epoch, graph, optimizer, criterion, train_queue,
              valid_queue, *args, **kwargs):
        try:
            config = kwargs.get('config', graph.config)
            device = kwargs['device']
            arch_optimizer = kwargs['arch_optimizer']
        except (KeyError, AttributeError) as e:
            # ModuleNotFoundError is for failed imports; a missing config or
            # device is better reported as a ValueError.
            raise ValueError(
                'No configuration specified in graph or kwargs') from e

        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        # Adjust arch optimizer for new search epoch
        arch_optimizer.new_epoch(epoch)

        start_time = time.time()
        for step, (input_train, target_train) in enumerate(train_queue):
            graph.train()
            n = input_train.size(0)

            input_train = input_train.to(device)
            target_train = target_train.to(device, non_blocking=True)

            # Architecture update
            arch_optimizer.forward_pass_adjustment()
            # Note: iter(valid_queue) is re-created every step, so this yields
            # a fresh random batch only if the DataLoader shuffles.
            input_valid, target_valid = next(iter(valid_queue))
            input_valid = input_valid.to(device)
            target_valid = target_valid.to(device, non_blocking=True)

            arch_optimizer.step(graph, criterion, input_train, target_train,
                                input_valid, target_valid, self.lr,
                                self.optimizer, config.unrolled)
            optimizer.zero_grad()

            # OP-weight update
            arch_optimizer.forward_pass_adjustment()
            logits = graph(input_train)
            loss = criterion(logits, target_train)
            loss.backward()
            nn.utils.clip_grad_norm_(graph.parameters(), config.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target_train, topk=(1, 5))
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % config.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)

        end_time = time.time()
        return top1.avg, objs.avg, end_time - start_time
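
The loop above touches arch_optimizer only through new_epoch, forward_pass_adjustment and step (plus an architectural_weights dict). A hypothetical no-op stub that satisfies this contract makes the expected interface explicit; the names come from the example, all internals are assumed:

class NoOpArchOptimizer:
    """Hypothetical stand-in illustrating the interface train() expects."""

    def __init__(self):
        # Mapping from edge/op names to architecture parameters.
        self.architectural_weights = {}

    def new_epoch(self, epoch):
        # Per-epoch hook, e.g. for annealing a sampling temperature.
        pass

    def forward_pass_adjustment(self):
        # Pre-forward hook, e.g. re-sample or re-normalize the architecture.
        pass

    def step(self, graph, criterion, input_train, target_train,
             input_valid, target_valid, lr, optimizer, unrolled):
        # A real optimizer (e.g. DARTS) updates the architecture weights
        # here using the validation batch.
        pass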
Example #3
    def _store_accuracies(self, logits, target, split):
        """Update the accuracy counters"""
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = logits.size(0)

        if split == 'train':
            self.train_top1.update(prec1.data.item(), n)
            self.train_top5.update(prec5.data.item(), n)
        elif split == 'val':
            self.val_top1.update(prec1.data.item(), n)
            self.val_top5.update(prec5.data.item(), n)
        else:
            raise ValueError(
                "Unknown split: {}. Expected either 'train' or 'val'".format(split))
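
_store_accuracies assumes the four meters already exist on the instance. A minimal, hypothetical initializer consistent with the meter interface used throughout these examples (Example #7 spells the class utils.AverageMeter):

    def _init_accuracy_meters(self):
        # Hypothetical helper: create (or reset) the counters that
        # _store_accuracies updates.
        self.train_top1 = utils.AvgrageMeter()
        self.train_top5 = utils.AvgrageMeter()
        self.val_top1 = utils.AvgrageMeter()
        self.val_top5 = utils.AvgrageMeter()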
Example #4
    def train_batch(self, arch):
        if self.steps % len(self.train_queue) == 0:
            self.scheduler.step()
            self.objs = utils.AvgrageMeter()
            self.top1 = utils.AvgrageMeter()
            self.top5 = utils.AvgrageMeter()
        # get_lr() is deprecated in newer PyTorch; get_last_lr() returns the
        # current learning rates.
        lr = self.scheduler.get_last_lr()[0]

        weights = self.get_weights_from_arch(arch)
        self.set_arch_model_weights(weights)

        step = self.steps % len(self.train_queue)
        input, target = next(self.train_iter)

        self.model.train()
        n = input.size(0)

        input = input.cuda()
        target = target.cuda(non_blocking=True)

        # get a random_ws minibatch from the search queue with replacement
        self.optimizer.zero_grad()
        logits = self.model(input, discrete=True)
        loss = self.criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
        self.optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        self.objs.update(loss.data.item(), n)
        self.top1.update(prec1.data.item(), n)
        self.top5.update(prec5.data.item(), n)

        if step % self.args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, self.objs.avg, self.top1.avg, self.top5.avg)

        self.steps += 1
        if self.steps % len(self.train_queue) == 0:
            # Save the model weights
            self.epochs += 1
            self.train_iter = iter(self.train_queue)
            valid_err = self.evaluate(arch)
            logging.info('epoch %d  |  train_acc %f  |  valid_acc %f',
                         self.epochs, self.top1.avg, 1 - valid_err)
            self.save(epoch=self.epochs)
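
train_batch performs exactly one optimization step and does its own epoch bookkeeping through self.steps, so a caller simply invokes it in a loop. A hypothetical driver, assuming a searcher object exposing the attributes used above:

# Hypothetical usage: one full epoch of single-batch steps.
for _ in range(len(searcher.train_queue)):
    searcher.train_batch(arch)
# On the last step train_batch itself advances self.epochs, refreshes
# train_iter, evaluates the architecture and saves a checkpoint.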
Example #5
    def train(self, epoch, graph, optimizer, criterion, train_queue,
              valid_queue, *args, **kwargs):
        try:
            config = kwargs.get('config', graph.config)
            device = kwargs['device']
        except (KeyError, AttributeError) as e:
            # As in Example #2: a missing config or device is a ValueError,
            # not a ModuleNotFoundError.
            raise ValueError(
                'No configuration specified in graph or kwargs') from e

        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        start_time = time.time()
        for step, (input, target) in enumerate(train_queue):
            graph.train()
            n = input.size(0)

            input = input.to(device)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad()
            # logits, logits_aux = graph(input)
            logits = graph(input)
            loss = criterion(logits, target)
            # if config.auxiliary:
            #    loss_aux = criterion(logits_aux, target)
            #    loss += config.auxiliary_weight * loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(graph.parameters(), config.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % config.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)

        end_time = time.time()
        return top1.avg, objs.avg, end_time - start_time
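
Together with infer from Example #1, this trainer is driven epoch by epoch. A sketch of such a loop, assuming trainer exposes both methods and a config with an epochs field (an assumption), and remembering that device must be passed through kwargs:

# Hypothetical training driver for the train()/infer() pair above.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for epoch in range(config.epochs):  # config.epochs is assumed
    train_acc, train_loss, seconds = trainer.train(
        epoch, graph, optimizer, criterion, train_queue, valid_queue,
        device=device)
    valid_acc, valid_loss = trainer.infer(
        graph, criterion, valid_queue, device=device)
    logging.info('epoch %d train_acc %f valid_acc %f (%.1fs)',
                 epoch, train_acc, valid_acc, seconds)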
Example #6
    def evaluate_test(self, arch, split=None, discrete=False, normalize=True):
        # Return error since we want to minimize obj val
        logging.info(arch)
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        weights = self.get_weights_from_arch(arch)
        self.set_arch_model_weights(weights)

        self.model.eval()

        if split is None:
            n_batches = 10
        else:
            n_batches = len(self.test_queue)

        for step in range(n_batches):
            try:
                input, target = next(self.test_iter)
            except StopIteration:
                logging.info('looping back over the test set')
                self.test_iter = iter(self.test_queue)
                input, target = next(self.test_iter)
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            # Evaluation does not need gradients.
            with torch.no_grad():
                logits = self.model(input, discrete=discrete, normalize=normalize)
                loss = self.criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % self.args.report_freq == 0:
                logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return 1 - 0.01 * top1.avg
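
Since top1.avg is a percentage, the return value 1 - 0.01 * top1.avg is the top-1 error rate in [0, 1], which is what a minimizing search loop wants. Converting back is one line:

# Hypothetical call: recover the accuracy from the minimized objective.
test_err = searcher.evaluate_test(arch, split='test')
test_acc = 100.0 * (1.0 - test_err)  # e.g. err 0.053 -> 94.7% top-1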
Example #7
    def evaluate(
        self,
        retrain=True,
        search_model="",
        resume_from="",
        best_arch=None,
    ):
        """
        Evaluate the final architecture as given from the optimizer.

        If the search space has an interface to a benchmark then query that.
        Otherwise train as defined in the config.

        Args:
            retrain (bool): Reset the weights from the architecture search.
            search_model (str): Path to checkpoint file that was created during
                search. If not provided, then try to load 'model_final.pth' from search
            resume_from (str): Resume retraining from the given checkpoint file.
            best_arch: Parsed model you want to directly evaluate and ignore the final model
                from the optimizer.
        """
        logger.info("Start evaluation")
        if not best_arch:

            if not search_model:
                search_model = os.path.join(self.config.save, "search",
                                            "model_final.pth")
            self._setup_checkpointers(
                search_model)  # required to load the architecture

            best_arch = self.optimizer.get_final_architecture()
        logger.info("Final architecture:\n" + best_arch.modules_str())

        if best_arch.QUERYABLE:
            metric = Metric.TEST_ACCURACY
            result = best_arch.query(metric=metric,
                                     dataset=self.config.dataset)
            logger.info("Queried results ({}): {}".format(metric, result))
        else:
            best_arch.to(self.device)
            if retrain:
                logger.info("Starting retraining from scratch")
                best_arch.reset_weights(inplace=True)

                self.train_queue, self.valid_queue, self.test_queue = self.build_eval_dataloaders(
                    self.config)

                optim = self.build_eval_optimizer(best_arch.parameters(),
                                                  self.config)
                scheduler = self.build_eval_scheduler(optim, self.config)

                start_epoch = self._setup_checkpointers(
                    resume_from,
                    search=False,
                    period=self.config.evaluation.checkpoint_freq,
                    model=best_arch,  # checkpointables start here
                    optim=optim,
                    scheduler=scheduler)

                grad_clip = self.config.evaluation.grad_clip
                loss = torch.nn.CrossEntropyLoss()

                best_arch.train()
                self.train_top1.reset()
                self.train_top5.reset()
                self.val_top1.reset()
                self.val_top5.reset()

                # Enable drop path
                best_arch.update_edges(update_func=lambda edge: edge.data.set(
                    'op', DropPathWrapper(edge.data.op)),
                                       scope=best_arch.OPTIMIZER_SCOPE,
                                       private_edge_data=True)

                # train from scratch
                epochs = self.config.evaluation.epochs
                for e in range(start_epoch, epochs):
                    if torch.cuda.is_available():
                        log_first_n(logging.INFO,
                                    "cuda consumption\n {}".format(
                                        torch.cuda.memory_summary()),
                                    n=20)

                    # update drop path probability
                    drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
                    best_arch.update_edges(
                        update_func=lambda edge: edge.data.set(
                            'drop_path_prob', drop_path_prob),
                        scope=best_arch.OPTIMIZER_SCOPE,
                        private_edge_data=True)

                    # Train queue
                    for i, (input_train,
                            target_train) in enumerate(self.train_queue):
                        input_train = input_train.to(self.device)
                        target_train = target_train.to(self.device,
                                                       non_blocking=True)

                        optim.zero_grad()
                        logits_train = best_arch(input_train)
                        train_loss = loss(logits_train, target_train)
                        if hasattr(best_arch,
                                   'auxilary_logits'):  # darts specific stuff
                            log_first_n(logging.INFO,
                                        "Auxiliary is used",
                                        n=10)
                            auxiliary_loss = loss(best_arch.auxilary_logits(),
                                                  target_train)
                            train_loss += self.config.evaluation.auxiliary_weight * auxiliary_loss
                        train_loss.backward()
                        if grad_clip:
                            torch.nn.utils.clip_grad_norm_(
                                best_arch.parameters(), grad_clip)
                        optim.step()

                        self._store_accuracies(logits_train, target_train,
                                               'train')
                        log_every_n_seconds(
                            logging.INFO,
                            "Epoch {}-{}, Train loss: {:.5}, learning rate: {}"
                            .format(e, i, train_loss, scheduler.get_last_lr()),
                            n=5)

                    # Validation queue
                    if self.valid_queue:
                        for i, (input_valid,
                                target_valid) in enumerate(self.valid_queue):

                            input_valid = input_valid.to(self.device)
                            # Keep integer class targets; no float cast is
                            # needed to compute accuracies.
                            target_valid = target_valid.to(self.device,
                                                           non_blocking=True)

                            # just log the validation accuracy
                            with torch.no_grad():
                                logits_valid = best_arch(input_valid)
                                self._store_accuracies(logits_valid,
                                                       target_valid, 'val')

                    scheduler.step()
                    self.periodic_checkpointer.step(e)
                    self._log_and_reset_accuracies(e)

            # Disable drop path
            best_arch.update_edges(update_func=lambda edge: edge.data.set(
                'op', edge.data.op.get_embedded_ops()),
                                   scope=best_arch.OPTIMIZER_SCOPE,
                                   private_edge_data=True)

            # measure final test accuracy
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()

            best_arch.eval()

            for i, data_test in enumerate(self.test_queue):
                input_test, target_test = data_test
                input_test = input_test.to(self.device)
                target_test = target_test.to(self.device, non_blocking=True)

                n = input_test.size(0)

                with torch.no_grad():
                    logits = best_arch(input_test)

                    prec1, prec5 = utils.accuracy(logits,
                                                  target_test,
                                                  topk=(1, 5))
                    top1.update(prec1.data.item(), n)
                    top5.update(prec5.data.item(), n)

                log_every_n_seconds(logging.INFO,
                                    "Inference batch {} of {}.".format(
                                        i, len(self.test_queue)),
                                    n=5)

            logger.info(
                "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}"
                .format(top1.avg, top5.avg))
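
The retraining loop wraps each edge op in a DropPathWrapper and anneals drop_path_prob linearly over the epochs. The wrapper is not shown on this page; the classic DARTS-style drop path it presumably applies is sketched below (an illustration, not the actual implementation):

import torch


def drop_path(x, drop_prob):
    """Zero a sample's whole branch with probability drop_prob and
    rescale the survivors so the expected value is unchanged."""
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        # One Bernoulli draw per sample, broadcast over the C/H/W dims.
        mask = torch.bernoulli(
            torch.full((x.size(0), 1, 1, 1), keep_prob, device=x.device))
        x = x.div(keep_prob).mul(mask)
    return x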
Example #8
    def evaluate(
            self,
            retrain=True,
            search_model="",
            resume_from="",
            best_arch=None,
    ):
        """
        Evaluate the final architecture as given from the optimizer.

        If the search space has an interface to a benchmark then query that.
        Otherwise train as defined in the config.

        Args:
            retrain (bool): Reset the weights from the architecture search.
            search_model (str): Path to checkpoint file that was created during
                search. If not provided, then try to load 'model_final.pth' from search.
            resume_from (str): Resume retraining from the given checkpoint file.
            best_arch: Parsed model you want to directly evaluate and ignore the final model
                from the optimizer.
        """

        #best_arch.to(self.device)
        self.config.evaluation.resume_from = resume_from
        if retrain:
            if self.config.gpu is not None:
                logger.warning(
                    'You have chosen a specific GPU. This will completely '
                    'disable data parallelism.')

            if self.config.evaluation.dist_url == "env://" and self.config.evaluation.world_size == -1:
                self.config.evaluation.world_size = int(os.environ["WORLD_SIZE"])

            self.config.evaluation.distributed = \
                self.config.evaluation.world_size > 1 or self.config.evaluation.multiprocessing_distributed
            ngpus_per_node = torch.cuda.device_count()

            if self.config.evaluation.multiprocessing_distributed:
                # Since we have ngpus_per_node processes per node, the
                # total world_size needs to be adjusted
                self.config.evaluation.world_size = ngpus_per_node * self.config.evaluation.world_size
                # Use torch.multiprocessing.spawn to launch distributed
                # processes: the main_worker process function
                mp.spawn(self.main_worker, nprocs=ngpus_per_node,
                         args=(ngpus_per_node, self.config.evaluation,
                               search_model, best_arch))
            else:
                # Simply call main_worker function
                self.main_worker(self.config.gpu, ngpus_per_node,
                                 self.config.evaluation,
                                 search_model, best_arch)

        if not self.QUERYABLE:
            # Disable drop path
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set('op', edge.data.op.get_embedded_ops()),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True
            )

            # measure final test accuracy
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()

            best_arch.eval()

            for i, data_test in enumerate(self.test_queue):
                input_test, target_test = data_test
                input_test = input_test.to(self.device)
                target_test = target_test.to(self.device, non_blocking=True)

                n = input_test.size(0)

                with torch.no_grad():
                    logits = best_arch(input_test)

                    prec1, prec5 = utils.accuracy(logits, target_test, topk=(1, 5))
                    top1.update(prec1.data.item(), n)
                    top5.update(prec5.data.item(), n)

                log_every_n_seconds(
                    logging.INFO,
                    "Inference batch {} of {}.".format(
                        i, len(self.test_queue)
                    ), n=5
                )

            logger.info("Evaluation finished. Test accuracies: top-1 = {:.5}, \
                        top-5 = {:.5}".format(top1.avg, top5.avg))
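
mp.spawn launches self.main_worker once per local GPU and prepends the process index to the argument tuple. The worker itself is not shown on this page; in stock torch.distributed terms it minimally has to do something like the following (a sketch; config.rank and the exact argument layout are assumptions):

import torch
import torch.distributed as dist


def main_worker(gpu, ngpus_per_node, config, search_model, best_arch):
    # Hypothetical sketch: global rank = node rank * gpus per node + local gpu.
    rank = config.rank * ngpus_per_node + gpu
    dist.init_process_group(backend='nccl',
                            init_method=config.dist_url,
                            world_size=config.world_size,
                            rank=rank)
    torch.cuda.set_device(gpu)
    best_arch = best_arch.cuda(gpu)
    # DDP all-reduces gradients across processes during backward().
    best_arch = torch.nn.parallel.DistributedDataParallel(
        best_arch, device_ids=[gpu])
    # ... retraining loop as in Example #7 ...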