Example #1
    def _verify_update_function(update_func: callable,
                                private_edge_data: bool):
        """
        Verify that the update function modifies only shared or only private
        edge data attributes, depending on the setting of
        `private_edge_data`.

        Args:
            update_func (callable): callable that expects one argument
                named `current_edge_data`.
            private_edge_data (bool): Whether the update function is applied
                to all graph instances, including copies, or just to one instance
                per graph.
        """

        test = EdgeData()
        test.set('shared', True, shared=True)
        test.set('op', [True])

        try:
            result = test.clone()
            update_func(current_edge_data=result)
        except Exception:
            log_first_n(
                logging.WARN,
                "Update function could not be verified. Be cautious with the "
                "setting of `private_edge_data` in `update_edges()`",
                n=5)
            return

        assert isinstance(
            result,
            EdgeData), "Update function does not return the edge data object."

        if private_edge_data:
            assert result._shared == test._shared, \
                "The update function changes shared data although `private_edge_data` set to True. " \
                "This is not the indended use of `update_edges`. The update function should only modify " \
                "private edge data."
        else:
            assert result._private == test._private, \
                "The update function changes private data although `private_edge_data` set to False. " \
                "This is not the indended use of `update_edges`. The update function should only modify " \
                "shared edge data."
Example #2
    def search(self, resume_from=""):
        """
        Start the architecture search.

        Generates a JSON file with training statistics.

        Args:
            resume_from (str): Checkpoint file to resume from. If not given then
                train from scratch.
        """
        logger.info("Start training")
        self.optimizer.before_training()
        checkpoint_freq = self.config.search.checkpoint_freq
        if self.optimizer.using_step_function:
            self.scheduler = self.build_search_scheduler(
                self.optimizer.op_optimizer, self.config)

            start_epoch = self._setup_checkpointers(resume_from,
                                                    period=checkpoint_freq,
                                                    scheduler=self.scheduler)
        else:
            start_epoch = self._setup_checkpointers(resume_from,
                                                    period=checkpoint_freq)

        self.train_queue, self.valid_queue, _ = self.build_search_dataloaders(
            self.config)

        for e in range(start_epoch, self.epochs):
            self.optimizer.new_epoch(e)

            start_time = time.time()
            if self.optimizer.using_step_function:
                for step, (data_train, data_val) in enumerate(
                        zip(self.train_queue, self.valid_queue)):
                    data_train = (data_train[0].to(self.device),
                                  data_train[1].to(self.device,
                                                   non_blocking=True))
                    data_val = (data_val[0].to(self.device),
                                data_val[1].to(self.device, non_blocking=True))

                    stats = self.optimizer.step(data_train, data_val)
                    logits_train, logits_val, train_loss, val_loss = stats

                    self._store_accuracies(logits_train, data_train[1],
                                           'train')
                    self._store_accuracies(logits_val, data_val[1], 'val')

                    log_every_n_seconds(
                        logging.INFO,
                        "Epoch {}-{}, Train loss: {:.5f}, validation loss: {:.5f}, learning rate: {}"
                        .format(e, step, train_loss, val_loss,
                                self.scheduler.get_last_lr()),
                        n=5)

                    if torch.cuda.is_available():
                        log_first_n(logging.INFO,
                                    "cuda consumption\n {}".format(
                                        torch.cuda.memory_summary()),
                                    n=3)

                    self.train_loss.update(float(train_loss.detach().cpu()))
                    self.val_loss.update(float(val_loss.detach().cpu()))

                self.scheduler.step()

                end_time = time.time()

                self.errors_dict.train_acc.append(self.train_top1.avg)
                self.errors_dict.train_loss.append(self.train_loss.avg)
                self.errors_dict.valid_acc.append(self.val_top1.avg)
                self.errors_dict.valid_loss.append(self.val_loss.avg)
                self.errors_dict.runtime.append(end_time - start_time)
            else:
                end_time = time.time()
                train_acc, train_loss, valid_acc, valid_loss = \
                    self.optimizer.train_statistics()
                self.errors_dict.train_acc.append(train_acc)
                self.errors_dict.train_loss.append(train_loss)
                self.errors_dict.valid_acc.append(valid_acc)
                self.errors_dict.valid_loss.append(valid_loss)
                self.errors_dict.runtime.append(end_time - start_time)
                self.train_top1.avg = train_acc
                self.val_top1.avg = valid_acc

            self.periodic_checkpointer.step(e)

            anytime_results = self.optimizer.test_statistics()
            if anytime_results:
                # record anytime performance
                self.errors_dict.arch_eval.append(anytime_results)
                log_every_n_seconds(logging.INFO,
                                    "Epoch {}, Anytime results: {}".format(
                                        e, anytime_results),
                                    n=5)

            self._log_to_json()
            self._log_and_reset_accuracies(e)

        self.optimizer.after_training()
        logger.info("Training finished")
Example #3
    def evaluate(
        self,
        retrain=True,
        search_model="",
        resume_from="",
        best_arch=None,
    ):
        """
        Evaluate the final architecture as given from the optimizer.

        If the search space has an interface to a benchmark then query that.
        Otherwise train as defined in the config.

        Args:
            retrain (bool): Reset the weights from the architecture search.
            search_model (str): Path to checkpoint file that was created during
                search. If not provided, then try to load 'model_final.pth' from
                the search directory.
            resume_from (str): Resume retraining from the given checkpoint file.
            best_arch: Parsed model to evaluate directly, ignoring the final model
                from the optimizer.
        """
        logger.info("Start evaluation")
        if not best_arch:

            if not search_model:
                search_model = os.path.join(self.config.save, "search",
                                            "model_final.pth")
            self._setup_checkpointers(
                search_model)  # required to load the architecture

            best_arch = self.optimizer.get_final_architecture()
        logger.info("Final architecture:\n" + best_arch.modules_str())

        if best_arch.QUERYABLE:
            metric = Metric.TEST_ACCURACY
            result = best_arch.query(metric=metric,
                                     dataset=self.config.dataset)
            logger.info("Queried results ({}): {}".format(metric, result))
        else:
            best_arch.to(self.device)
            if retrain:
                logger.info("Starting retraining from scratch")
                best_arch.reset_weights(inplace=True)

                self.train_queue, self.valid_queue, self.test_queue = self.build_eval_dataloaders(
                    self.config)

                optim = self.build_eval_optimizer(best_arch.parameters(),
                                                  self.config)
                scheduler = self.build_eval_scheduler(optim, self.config)

                start_epoch = self._setup_checkpointers(
                    resume_from,
                    search=False,
                    period=self.config.evaluation.checkpoint_freq,
                    model=best_arch,  # checkpointables start here
                    optim=optim,
                    scheduler=scheduler)

                grad_clip = self.config.evaluation.grad_clip
                loss = torch.nn.CrossEntropyLoss()

                best_arch.train()
                self.train_top1.reset()
                self.train_top5.reset()
                self.val_top1.reset()
                self.val_top5.reset()

                # Enable drop path
                best_arch.update_edges(update_func=lambda edge: edge.data.set(
                    'op', DropPathWrapper(edge.data.op)),
                                       scope=best_arch.OPTIMIZER_SCOPE,
                                       private_edge_data=True)

                # train from scratch
                epochs = self.config.evaluation.epochs
                for e in range(start_epoch, epochs):
                    if torch.cuda.is_available():
                        log_first_n(logging.INFO,
                                    "cuda consumption\n {}".format(
                                        torch.cuda.memory_summary()),
                                    n=20)

                    # update drop path probability
                    drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
                    best_arch.update_edges(
                        update_func=lambda edge: edge.data.set(
                            'drop_path_prob', drop_path_prob),
                        scope=best_arch.OPTIMIZER_SCOPE,
                        private_edge_data=True)

                    # Train queue
                    for i, (input_train,
                            target_train) in enumerate(self.train_queue):
                        input_train = input_train.to(self.device)
                        target_train = target_train.to(self.device,
                                                       non_blocking=True)

                        optim.zero_grad()
                        logits_train = best_arch(input_train)
                        train_loss = loss(logits_train, target_train)
                        if hasattr(best_arch,
                                   'auxilary_logits'):  # DARTS-specific auxiliary head
                            log_first_n(logging.INFO,
                                        "Auxiliary is used",
                                        n=10)
                            auxiliary_loss = loss(best_arch.auxilary_logits(),
                                                  target_train)
                            train_loss += self.config.evaluation.auxiliary_weight * auxiliary_loss
                        train_loss.backward()
                        if grad_clip:
                            torch.nn.utils.clip_grad_norm_(
                                best_arch.parameters(), grad_clip)
                        optim.step()

                        self._store_accuracies(logits_train, target_train,
                                               'train')
                        log_every_n_seconds(
                            logging.INFO,
                            "Epoch {}-{}, Train loss: {:.5}, learning rate: {}"
                            .format(e, i, train_loss, scheduler.get_last_lr()),
                            n=5)

                    # Validation queue
                    if self.valid_queue:
                        for i, (input_valid,
                                target_valid) in enumerate(self.valid_queue):

                            input_valid = input_valid.to(self.device).float()
                            target_valid = target_valid.to(
                                self.device, non_blocking=True).float()

                            # just log the validation accuracy
                            with torch.no_grad():
                                logits_valid = best_arch(input_valid)
                                self._store_accuracies(logits_valid,
                                                       target_valid, 'val')

                    scheduler.step()
                    self.periodic_checkpointer.step(e)
                    self._log_and_reset_accuracies(e)

            # Disable drop path
            best_arch.update_edges(update_func=lambda edge: edge.data.set(
                'op', edge.data.op.get_embedded_ops()),
                                   scope=best_arch.OPTIMIZER_SCOPE,
                                   private_edge_data=True)

            # measure final test accuracy
            top1 = utils.AverageMeter()
            top5 = utils.AverageMeter()

            best_arch.eval()

            for i, data_test in enumerate(self.test_queue):
                input_test, target_test = data_test
                input_test = input_test.to(self.device)
                target_test = target_test.to(self.device, non_blocking=True)

                n = input_test.size(0)

                with torch.no_grad():
                    logits = best_arch(input_test)

                    prec1, prec5 = utils.accuracy(logits,
                                                  target_test,
                                                  topk=(1, 5))
                    top1.update(prec1.data.item(), n)
                    top5.update(prec5.data.item(), n)

                log_every_n_seconds(logging.INFO,
                                    "Inference batch {} of {}.".format(
                                        i, len(self.test_queue)),
                                    n=5)

            logger.info(
                "Evaluation finished. Test accuracies: top-1 = {:.5}, top-5 = {:.5}"
                .format(top1.avg, top5.avg))
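A hypothetical usage sketch matching the docstring above; the `trainer` object and the path prefix are stand-ins, not a verbatim API.

    # Hypothetical usage; queries the benchmark or retrains, depending on the search space.
    trainer.evaluate()                 # use the optimizer's final architecture
    trainer.evaluate(retrain=False)    # keep the weights found during search
    trainer.evaluate(search_model="run/search/model_final.pth")   # explicit search checkpoint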
Example #4
    def forward(self, x, *args):
        """
        Forward some data through the graph. This is done recursively
        in case there are graphs defined on nodes or as 'op' on edges.

        Args:
            x (Tensor or dict): The input. If the graph sits on a node the
                input can be a dict with {source_idx: Tensor} to be routed
                to the defined input nodes. If the graph sits on an edge,
                x is the feature tensor.
            args: This is only required to handle cases where the graph sits
                on an edge and receives an EdgeData object, which is ignored.
        """
        logger.debug("Graph {} called. Input {}.".format(
            self.name, log_formats(x)))

        # Assign x to the corresponding input nodes
        self._assign_x_to_nodes(x)

        for node_idx in lexicographical_topological_sort(self):
            node = self.nodes[node_idx]
            logger.debug(
                "Node {}-{}, current data {}, start processing...".format(
                    self.name, node_idx, log_formats(node)))

            # node internal: process input if necessary
            if ('subgraph' in node
                    and 'comb_op' not in node) or ('comb_op' in node
                                                   and 'subgraph' not in node):
                log_first_n(logging.WARN,
                            "Comb_op is ignored if subgraph is defined!",
                            n=1)
            # TODO: merge 'subgraph' and 'comb_op'. It is basically the same thing. Also in parse().
            if 'subgraph' in node:
                x = node['subgraph'].forward(node['input'])
            else:
                if len(node['input'].values()) == 1:
                    x = list(node['input'].values())[0]
                else:
                    x = node['comb_op']([
                        node['input'][k] for k in sorted(node['input'].keys())
                    ])
            node['input'] = {}  # clear the input as we have processed it

            if len(list(self.neighbors(node_idx))) == 0 and node_idx < list(
                    lexicographical_topological_sort(self))[-1]:
                # We have more than one output node. This is e.g. the case for
                # auxiliary losses. Attach them to the graph; handling must be done
                # by the user.
                logger.debug(
                    "Graph {} has more then one output node. Storing output of non-maximum index node {} at graph dict"
                    .format(self, node_idx))
                self.graph['out_from_{}'.format(node_idx)] = x
            else:
                # outgoing edges: process all outgoing edges
                for neighbor_idx in self.neighbors(node_idx):
                    edge_data = self.get_edge_data(node_idx, neighbor_idx)
                    # inject edge data only for AbstractPrimitive, not Graphs
                    if isinstance(edge_data.op, Graph):
                        edge_output = edge_data.op.forward(x)
                    elif isinstance(edge_data.op, AbstractPrimitive):
                        logger.debug("Processing op {} at edge {}-{}".format(
                            edge_data.op, node_idx, neighbor_idx))
                        edge_output = edge_data.op.forward(x,
                                                           edge_data=edge_data)
                    else:
                        raise ValueError(
                            "Unknown class as op: {}. Expected either Graph or AbstractPrimitive"
                            .format(edge_data.op))
                    self.nodes[neighbor_idx]['input'].update(
                        {node_idx: edge_output})

            logger.debug("Node {}-{}, processing done.".format(
                self.name, node_idx))

        logger.debug("Graph {} exiting. Output {}.".format(
            self.name, log_formats(x)))
        return x
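A hypothetical call sketch based on the docstring: a graph used as an 'op' on an edge receives a plain feature tensor, while a graph sitting on a node can receive a `{source_idx: Tensor}` dict that is routed to its input nodes. The `graph` instance and tensor shape are assumptions for illustration.

    import torch

    x = torch.randn(2, 16, 8, 8)        # hypothetical feature tensor
    out_edge = graph(x)                  # graph acting as an 'op' on an edge
    out_node = graph({1: x, 2: x})       # graph on a node with two incoming sources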
Example #5
    def main_worker(self, gpu, ngpus_per_node, args, search_model, best_arch):
        logger.info("Start evaluation")
        if not best_arch:
            if not search_model:
                search_model = os.path.join(self.config.save, "search", "model_final.pth")
            self._setup_checkpointers(search_model)      # required to load the architecture

            best_arch = self.optimizer.get_final_architecture()
        logger.info("Final architecture:\n" + best_arch.modules_str())

        if best_arch.QUERYABLE:
            metric = Metric.TEST_ACCURACY
            result = best_arch.query(
                metric=metric, dataset=self.config.dataset
            )
            logger.info("Queried results ({}): {}".format(metric, result))
            self.QUERYABLE = True
            return

        best_arch.reset_weights(inplace=True)
        logger.info("Starting retraining from scratch")

        args.gpu = gpu
        if gpu is not None:
            logger.info("Use GPU: {} for training".format(args.gpu))

        if args.distributed:
            if args.dist_url == "env://" and args.rank == -1:
                args.rank = int(os.environ["RANK"])
            if args.multiprocessing_distributed:
                # For multiprocessing distributed training, rank needs to be the
                # global rank among all processes
                args.rank = args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                    world_size=args.world_size, rank=args.rank)

        if not torch.cuda.is_available():
            logger.warning("Using CPU, this will be slow!")
        elif args.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                best_arch.cuda(args.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs we have
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) /
                                   ngpus_per_node)
                best_arch = \
                    torch.nn.parallel.DistributedDataParallel(best_arch,
                                                              device_ids=[args.gpu])
            else:
                best_arch.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                best_arch = torch.nn.parallel.DistributedDataParallel(best_arch)
        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            best_arch = best_arch.cuda(args.gpu)
        else:
            # DataParallel will divide and allocate batch_size to all available GPUs
            best_arch = torch.nn.DataParallel(best_arch).cuda()

        cudnn.benchmark = True

        self.train_queue, self.valid_queue, self.test_queue =\
            self.build_eval_dataloaders(self.config)

        optim = self.build_eval_optimizer(best_arch.parameters(), self.config)
        scheduler = self.build_eval_scheduler(optim, self.config)

        start_epoch = self._setup_checkpointers(args.resume_from,
            search=False,
            period=self.config.evaluation.checkpoint_freq,
            model=best_arch,    # checkpointables start here
            optim=optim,
            scheduler=scheduler
        )

        grad_clip = self.config.evaluation.grad_clip
        loss = torch.nn.CrossEntropyLoss()

        best_arch.train()
        self.train_top1.reset()
        self.train_top5.reset()
        self.val_top1.reset()
        self.val_top5.reset()

        # Enable drop path
        if isinstance(best_arch, torch.nn.DataParallel):
            best_arch.module.update_edges(
                update_func=lambda edge: edge.data.set('op', DropPathWrapper(edge.data.op)),
                scope=best_arch.module.OPTIMIZER_SCOPE,
                private_edge_data=True
            )
        else:
            best_arch.update_edges(
                update_func=lambda edge: edge.data.set('op', DropPathWrapper(edge.data.op)),
                scope=best_arch.OPTIMIZER_SCOPE,
                private_edge_data=True
            )

        # train from scratch
        epochs = self.config.evaluation.epochs
        for e in range(start_epoch, epochs):
            # update drop path probability
            drop_path_prob = self.config.evaluation.drop_path_prob * e / epochs
            if isinstance(best_arch, torch.nn.DataParallel):
                best_arch.module.update_edges(
                    update_func=lambda edge: edge.data.set('drop_path_prob', drop_path_prob),
                    scope=best_arch.module.OPTIMIZER_SCOPE,
                    private_edge_data=True
                )
            else:
                best_arch.update_edges(
                    update_func=lambda edge: edge.data.set('drop_path_prob', drop_path_prob),
                    scope=best_arch.OPTIMIZER_SCOPE,
                    private_edge_data=True
                )

            # Train queue
            for i, (input_train, target_train) in enumerate(self.train_queue):
                input_train = input_train.to(self.device)
                target_train = target_train.to(self.device, non_blocking=True)

                optim.zero_grad()
                logits_train = best_arch(input_train)
                train_loss = loss(logits_train, target_train)
                if hasattr(best_arch, 'auxilary_logits'):   # DARTS-specific auxiliary head
                    log_first_n(logging.INFO, "Auxiliary is used", n=10)
                    auxiliary_loss = loss(best_arch.auxilary_logits(), target_train)
                    train_loss += self.config.evaluation.auxiliary_weight * auxiliary_loss
                train_loss.backward()
                if grad_clip:
                    torch.nn.utils.clip_grad_norm_(best_arch.parameters(), grad_clip)
                optim.step()

                self._store_accuracies(logits_train, target_train, 'train')
                log_every_n_seconds(
                    logging.INFO,
                    "Epoch {}-{}, Train loss: {:.5}, learning rate: {}".format(
                        e, i, train_loss, scheduler.get_last_lr()
                    ), n=5
                )

                if torch.cuda.is_available():
                    log_first_n(
                        logging.INFO,
                        "cuda consumption\n {}".format(
                            torch.cuda.memory_summary()
                        ), n=3
                    )

            # Validation queue
            if self.valid_queue:
                for i, (input_valid, target_valid) in enumerate(self.valid_queue):

                    input_valid = input_valid.to(self.device).float()
                    target_valid = target_valid.to(self.device, non_blocking=True).float()

                    # just log the validation accuracy
                    logits_valid = best_arch(input_valid)
                    self._store_accuracies(logits_valid, target_valid, 'val')

            scheduler.step()
            self.periodic_checkpointer.step(e)
            self._log_and_reset_accuracies(e)
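A hypothetical launch sketch in the usual torch.multiprocessing / torch.distributed style; the `trainer` instance, `args` namespace (gpu, rank, world_size, dist_url, dist_backend, multiprocessing_distributed, ...), `search_model`, and `best_arch` are stand-ins, not a verbatim entry point of this repository.

    import torch
    import torch.multiprocessing as mp

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per GPU; world_size becomes the total number of processes
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(trainer.main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, search_model, best_arch))
    else:
        trainer.main_worker(args.gpu, ngpus_per_node, args, search_model, best_arch)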