Example #1
    def _train_desc(self, model_desc:ModelDesc, conf_train:Config)->MetricsStats:
        """Train given description"""
        # region conf vars
        conf_trainer = conf_train['trainer']
        conf_loader = conf_train['loader']
        trainer_title = conf_trainer['title']
        epochs = conf_trainer['epochs']
        drop_path_prob = conf_trainer['drop_path_prob']
        # endregion

        logger.pushd(trainer_title)

        if epochs == 0:
            # nothing to pretrain, save time
            metrics_stats = MetricsStats(model_desc, None, None)
        else:
            model = nas_utils.model_from_desc(model_desc,
                                              droppath=drop_path_prob>0.0,
                                              affine=True)

            # get data
            train_dl, val_dl = self.get_data(conf_loader)
            assert train_dl is not None

            trainer = Trainer(conf_trainer, model, checkpoint=None)
            train_metrics = trainer.fit(train_dl, val_dl)

            metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)

        logger.popd()
        return metrics_stats
Example #2
    def _search_desc(self, model_desc:ModelDesc, search_iter:int)->ModelDesc:
        logger.pushd('arch_search')

        nas_utils.build_cell(model_desc, self.cell_builder, search_iter)

        if self.trainer_class:
            model = nas_utils.model_from_desc(model_desc,
                                              droppath=False,
                                              affine=False)

            # get data
            train_dl, val_dl = self.get_data(self.conf_loader)
            assert train_dl is not None

            # search arch
            arch_trainer = self.trainer_class(self.conf_train, model, checkpoint=None)
            train_metrics = arch_trainer.fit(train_dl, val_dl)

            metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)
            found_desc = metrics_stats.model_desc
        else: # if no trainer needed, for example, for random search
            found_desc = model_desc

        logger.popd()
        return found_desc
Example #3
    def generate_pareto(self)->ModelDesc:
        macro_combinations = list(self._macro_combinations())
        start_macro, best_result = self._restore_checkpoint(macro_combinations)

        for macro_comb_i in range(start_macro, len(macro_combinations)):
            reductions, cells, nodes = macro_combinations[macro_comb_i]
            logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

            model_desc = self._build_macro(reductions, cells, nodes)

            # prep seed model and train it
            model_desc = self._seed_model(model_desc, reductions, cells, nodes)

            model_desc, best_result = self._search_iters(model_desc, best_result,
                                                         reductions, cells, nodes)

            assert best_result is not None
            self._record_checkpoint(macro_comb_i, best_result)
            logger.popd() # reductions, cells, nodes

        assert best_result is not None
        best_result.model_desc().clear_trainables()
        logger.info({'best_macro_params':best_result.macro_params,
                     'best_metric':best_result.metrics_stats})
        best_result.model_desc().save(self.final_desc_filename)

        return best_result.model_desc()
Example #4
    def _train_epoch(self, train_dl: DataLoader) -> None:
        steps = len(train_dl)
        self.model.train()

        logger.pushd('steps')
        for step, (x, y) in enumerate(train_dl):
            logger.pushd(step)
            assert self.model.training  # derived class might alter the mode

            # TODO: please check that no algorithm is invalidated by swapping prestep with zero grad
            self._multi_optim.zero_grad()

            self.pre_step(x, y)

            # divide batch into chunks if needed so it fits in GPU RAM
            if self.batch_chunks > 1:
                x_chunks, y_chunks = torch.chunk(
                    x, self.batch_chunks), torch.chunk(y, self.batch_chunks)
            else:
                x_chunks, y_chunks = (x, ), (y, )

            logits_chunks = []
            loss_sum, loss_count = 0.0, 0
            for xc, yc in zip(x_chunks, y_chunks):
                xc, yc = xc.to(self.get_device(),
                               non_blocking=True), yc.to(self.get_device(),
                                                         non_blocking=True)

                logits_c, aux_logits = self.model(xc), None
                tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >= 2
                # if self._aux_weight: # TODO: some other way to validate?
                #     assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
                if tupled_out:  # then we are using model created by desc
                    logits_c, aux_logits = logits_c[0], logits_c[1]
                loss_c = self.compute_loss(self._lossfn, yc, logits_c,
                                           self._aux_weight, aux_logits)

                self._apex.backward(loss_c, self._multi_optim)

                loss_sum += loss_c.item() * len(logits_c)
                loss_count += len(logits_c)
                logits_chunks.append(logits_c.detach().cpu())

            # TODO: original darts clips alphas as well but pt.darts doesn't
            self._apex.clip_grad(self._grad_clip, self.model,
                                 self._multi_optim)

            self._multi_optim.step()

            # TODO: we possibly need to sync so all replicas are up to date
            self._apex.sync_devices()

            self.post_step(x, y, ml_utils.join_chunks(logits_chunks),
                           torch.tensor(loss_sum / loss_count), steps)
            logger.popd()

            # end of step

        self._multi_optim.epoch()
        logger.popd()
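
The chunking logic above splits an oversized batch with torch.chunk and recombines the per-chunk losses as a size-weighted average before post_step. Below is a minimal, self-contained sketch of just that pattern; the model and loss are stand-ins, not the trainer's own:

import torch
import torch.nn.functional as F

x = torch.randn(10, 3)                    # a batch of 10 samples
y = torch.randint(0, 2, (10,))            # labels
batch_chunks = 3                          # hypothetical chunk count

x_chunks, y_chunks = torch.chunk(x, batch_chunks), torch.chunk(y, batch_chunks)

loss_sum, loss_count = 0.0, 0
for xc, yc in zip(x_chunks, y_chunks):
    logits_c = torch.zeros(len(xc), 2)    # stand-in for self.model(xc)
    loss_c = F.cross_entropy(logits_c, yc)
    loss_sum += loss_c.item() * len(logits_c)   # weight by chunk size
    loss_count += len(logits_c)

print(loss_sum / loss_count)              # the value post_step receives as the step loss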
Example #5
def eval_arch(conf_eval: Config, cell_builder: Optional[CellBuilder]):
    logger.pushd('eval_arch')

    # region conf vars
    conf_loader = conf_eval['loader']
    model_filename = conf_eval['model_filename']
    metric_filename = conf_eval['metric_filename']
    conf_checkpoint = conf_eval['checkpoint']
    resume = conf_eval['resume']
    conf_train = conf_eval['trainer']
    # endregion

    if cell_builder:
        cell_builder.register_ops()

    model = create_model(conf_eval)

    # get data
    train_dl, _, test_dl = data.get_data(conf_loader)
    assert train_dl is not None and test_dl is not None

    checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
    trainer = Trainer(conf_train, model, checkpoint)
    train_metrics = trainer.fit(train_dl, test_dl)
    train_metrics.save(metric_filename)

    # save model
    if model_filename:
        model_filename = utils.full_path(model_filename)
        ml_utils.save_model(model, model_filename)

    logger.info({'model_save_path': model_filename})

    logger.popd()
Example #6
    def train_model_desc(self, model_desc:ModelDesc, conf_train:Config)\
            ->Optional[ModelMetrics]:
        """Train given description"""

        # region conf vars
        conf_trainer = conf_train['trainer']
        conf_loader = conf_train['loader']
        trainer_title = conf_trainer['title']
        epochs = conf_trainer['epochs']
        drop_path_prob = conf_trainer['drop_path_prob']
        # endregion

        # if epochs <= 0 there is nothing to train, so save time
        if epochs <= 0:
            return None

        logger.pushd(trainer_title)

        model = Model(model_desc, droppath=drop_path_prob > 0.0, affine=True)

        # get data
        data_loaders = self.get_data(conf_loader)

        trainer = Trainer(conf_trainer, model, checkpoint=None)
        train_metrics = trainer.fit(data_loaders)

        logger.popd()

        return ModelMetrics(model, train_metrics)
Example #7
    def evaluate(self, conf_eval: Config,
                 model_desc_builder: ModelDescBuilder) -> EvalResult:
        logger.pushd('eval_arch')

        # region conf vars
        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']

        model_filename = conf_eval['model_filename']
        metric_filename = conf_eval['metric_filename']
        # endregion

        model = self.create_model(conf_eval, model_desc_builder)

        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
        train_metrics = self.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metric_filename)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)

        logger.info({'model_save_path': model_filename})

        logger.popd()

        return EvalResult(train_metrics)
Example #8
    def search_model_desc(self, conf_search:Config, model_desc:ModelDesc,
                          trainer_class:TArchTrainer, finalizers:Finalizers)\
                              ->Tuple[ModelDesc, Optional[Metrics]]:

        # if trainer is not specified for algos like random search we return same desc
        if trainer_class is None:
            return model_desc, None

        logger.pushd('arch_search')

        conf_trainer = conf_search['trainer']
        conf_loader = conf_search['loader']

        model = Model(model_desc, droppath=False, affine=False)

        # get data
        data_loaders = self.get_data(conf_loader)

        # search arch
        arch_trainer = trainer_class(conf_trainer, model, checkpoint=None)
        search_metrics = arch_trainer.fit(data_loaders)

        # finalize
        found_desc = self.finalize_model(model, finalizers)

        logger.popd()

        return found_desc, search_metrics
Example #9
    def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
               trainer_class: TArchTrainer,
               finalizers: Finalizers) -> SearchResult:

        # region config vars
        conf_model_desc = conf_search['model_desc']
        conf_post_train = conf_search['post_train']

        conf_checkpoint = conf_search['checkpoint']
        resume = conf_search['resume']
        # endregion

        self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        macro_combinations = list(self.get_combinations(conf_search))
        start_macro_i, best_search_result = self.restore_checkpoint(
            conf_search, macro_combinations)
        best_macro_comb = -1, -1, -1  # reductions, cells, nodes

        for macro_comb_i in range(start_macro_i, len(macro_combinations)):
            reductions, cells, nodes = macro_combinations[macro_comb_i]
            logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

            # build model description that we will search on
            model_desc = self.build_model_desc(model_desc_builder,
                                               conf_model_desc, reductions,
                                               cells, nodes)

            # perform search on model description
            model_desc, search_metrics = self.search_model_desc(
                conf_search, model_desc, trainer_class, finalizers)

            # train searched model for few epochs to get some perf metrics
            model_metrics = self.train_model_desc(model_desc, conf_post_train)

            assert model_metrics is not None, "'post_train' section in yaml should have non-zero epochs if running combinations search"

            # save result
            self.save_trained(conf_search, reductions, cells, nodes,
                              model_metrics)

            # update the best result so far
            if (best_search_result is None
                    or self.is_better_metrics(best_search_result.search_metrics,
                                              model_metrics.metrics)):
                best_search_result = SearchResult(model_desc, search_metrics,
                                                  model_metrics.metrics)
                best_macro_comb = reductions, cells, nodes

            # checkpoint
            assert best_search_result is not None
            self.record_checkpoint(macro_comb_i, best_search_result)
            logger.popd()  # reductions, cells, nodes

        assert best_search_result is not None
        self.clean_log_result(conf_search, best_search_result)
        logger.info({'best_macro_comb': best_macro_comb})

        return best_search_result
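
The loop above walks a precomputed list of (reductions, cells, nodes) macro combinations. The sketch below illustrates how such a list could be enumerated; get_combinations itself is not shown in these examples, and the candidate values here are made up:

import itertools

reductions_choices, cells_choices, nodes_choices = [1, 2], [5, 8], [2, 4]
macro_combinations = list(itertools.product(reductions_choices, cells_choices,
                                            nodes_choices))
for reductions, cells, nodes in macro_combinations:
    print(f'r{reductions}.c{cells}.n{nodes}')   # same tag format used by logger.pushd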
Example #10
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:

        logger.pushd('finalize')

        # get config and train data loader
        # TODO: confirm this is correct in case you get silent bugs
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        train_dl, val_dl, test_dl = get_data(conf_loader)

        # wrap all cells in the model
        self._divnas_cells: Dict[int, Divnas_Cell] = {}
        for _, cell in enumerate(model.cells):
            divnas_cell = Divnas_Cell(cell)
            self._divnas_cells[id(cell)] = divnas_cell

        # go through all edges in the DAG and if they are of divop
        # type then set them to collect activations
        sigma = conf['nas']['search']['divnas']['sigma']
        for _, dcell in enumerate(self._divnas_cells.values()):
            dcell.collect_activations(DivOp, sigma)

        # now we need to run one evaluation epoch to collect activations
        # we do it on cpu otherwise we might run into memory issues
        # later we can redo the whole logic in pytorch itself
        # at the end of this each node in a cell will have the covariance
        # matrix of all incoming edges' ops
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for _ in range(1):
                for _, (x, _) in enumerate(train_dl):
                    _, _ = model(x), None
                    # now you can go through and update the
                    # node covariances in every cell
                    for dcell in self._divnas_cells.values():
                        dcell.update_covs()

        logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)
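
Divnas_Cell's activation collection is not shown above; the sketch below illustrates the general pattern the method relies on, assuming forward hooks and a no-grad pass. The model and the covariance statistic are stand-ins, not the DivNAS implementation:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        activations.setdefault(name, []).append(output.detach())
    return hook

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(make_hook(name))

model.eval()
with torch.no_grad():
    for _ in range(2):                           # stand-in for iterating train_dl
        _ = model(torch.randn(16, 8))

# a covariance-style statistic per hooked module, akin to update_covs()
covs = {name: torch.cov(torch.cat(acts).T) for name, acts in activations.items()}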
Example #11
    def evaluate(self, conf_eval: Config,
                 model_desc_builder: ModelDescBuilder) -> EvalResult:
        """Takes a folder of model descriptions output by search process and
        trains them in a distributed manner using ray with 1 gpu"""

        logger.pushd('evaluate')

        final_desc_foldername: str = conf_eval['final_desc_foldername']

        # get list of model descs in the gallery folder
        final_desc_folderpath = utils.full_path(final_desc_foldername)
        files = [f for f in glob.glob(os.path.join(final_desc_folderpath, 'model_desc_*.yaml'))
                 if os.path.isfile(f)]
        logger.info({'model_desc_files': len(files)})

        # to avoid all workers downloading datasets individually, let's do it beforehand
        self._ensure_dataset_download(conf_eval)

        future_ids = []
        for model_desc_filename in files:
            future_id = EvaluaterPetridish._train_dist.remote(
                self, conf_eval, model_desc_builder, model_desc_filename,
                common.get_state())
            future_ids.append(future_id)

        # wait for all eval jobs to be finished
        ready_refs, remaining_refs = ray.wait(future_ids,
                                              num_returns=len(future_ids))

        # plot pareto curve of gallery of models
        hull_points = [ray.get(ready_ref) for ready_ref in ready_refs]
        save_hull(hull_points, common.get_expdir())
        plot_pool(hull_points, common.get_expdir())

        best_point = max(hull_points, key=lambda p: p.metrics.best_val_top1())
        logger.info({
            'best_val_top1': best_point.metrics.best_val_top1(),
            'best_MAdd': best_point.model_stats.MAdd
        })

        logger.popd()

        return EvalResult(best_point.metrics)
Example #12
    def __init__(self,
                 conf_train: Config,
                 model: nn.Module,
                 checkpoint: Optional[CheckPoint] = None) -> None:
        # region config vars
        self.conf_train = conf_train
        conf_lossfn = conf_train['lossfn']
        self._aux_weight = conf_train['aux_weight']
        self._grad_clip = conf_train['grad_clip']
        self._drop_path_prob = conf_train['drop_path_prob']
        self._logger_freq = conf_train['logger_freq']
        self._title = conf_train['title']
        self._epochs = conf_train['epochs']
        self.conf_optim = conf_train['optimizer']
        self.conf_sched = conf_train['lr_schedule']
        self.batch_chunks = conf_train['batch_chunks']
        conf_validation = conf_train['validation']
        conf_apex = conf_train['apex']
        self._validation_freq = 0 if conf_validation is None else conf_validation[
            'freq']
        # endregion

        logger.pushd(self._title + '__init__')

        self._apex = ApexUtils(conf_apex, logger)

        self._checkpoint = checkpoint
        self.model = model

        self._lossfn = ml_utils.get_lossfn(conf_lossfn)
        # using separate apex for Tester is not possible because we must use
        # same distributed model as Trainer and hence they must share apex
        self._tester = Tester(conf_validation, model, self._apex) \
                        if conf_validation else None
        self._metrics: Optional[Metrics] = None

        self._droppath_module = self._get_droppath_module()
        if self._droppath_module is None and self._drop_path_prob > 0.0:
            logger.warn({'droppath_module': None})

        self._start_epoch = -1  # nothing is started yet

        logger.popd()
Example #13
    def _search_iters(self, model_desc:ModelDesc, best_result:Optional[SearchResult],
                      reductions:int, cells:int, nodes:int)->\
                          Tuple[ModelDesc, Optional[SearchResult]]:
        for search_iter in range(self.search_iters):
            logger.pushd(f'{search_iter}')

            # execute search iteration followed by training the model
            model_desc = self._search_desc(model_desc, search_iter)
            metrics_stats = self._train_desc(model_desc, self.conf_postsearch_train)
            model_desc = metrics_stats.model_desc

            # save result
            self._save_trained(reductions, cells, nodes, search_iter, metrics_stats)
            if metrics_stats.is_better(best_result.metrics_stats \
                                       if best_result is not None else None):
                best_result = SearchResult(metrics_stats,(reductions, cells, nodes))

            logger.popd() # search_iter

        return model_desc, best_result
Example #14
    def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
        logger.pushd(self._title)

        self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

        # create optimizers and schedulers
        self._multi_optim = self.create_multi_optim(len(train_dl))
        # before checkpoint restore, convert to amp
        self.model = self._apex.to_amp(self.model, self._multi_optim,
                                       batch_size=train_dl.batch_size)

        self._lossfn = self._lossfn.to(self.get_device())

        self.pre_fit(train_dl, val_dl)

        # we need to restore checkpoint after all objects are created because
        # restoring checkpoint requires load_state_dict calls on these objects
        self._start_epoch = 0
        # do we have a checkpoint
        checkpoint_avail = self._checkpoint is not None
        checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
        resumed = False
        if checkpoint_val:
            # restore checkpoint
            resumed = True
            self.restore_checkpoint()
        elif checkpoint_avail: # TODO: bad checkpoint?
            self._checkpoint.clear()
        logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                     'checkpoint_val': checkpoint_val,
                     'start_epoch': self._start_epoch,
                     'total_epochs': self._epochs})
        logger.info({'aux_weight': self._aux_weight,
                     'grad_clip': self._grad_clip,
                     'drop_path_prob': self._drop_path_prob,
                     'validation_freq': self._validation_freq,
                     'batch_chunks': self.batch_chunks})

        if self._start_epoch >= self._epochs:
            logger.warn(f'fit done because start_epoch {self._start_epoch}>={self._epochs}')
            return self.get_metrics() # we already finished the run, we might be checkpointed

        logger.pushd('epochs')
        for epoch in range(self._start_epoch, self._epochs):
            logger.pushd(epoch)
            self._set_epoch(epoch, train_dl, val_dl)
            self.pre_epoch(train_dl, val_dl)
            self._train_epoch(train_dl)
            self.post_epoch(train_dl, val_dl)
            logger.popd()
        logger.popd()
        self.post_fit(train_dl, val_dl)

        # make sure we don't keep references to the graph
        del self._multi_optim

        logger.popd()
        return self.get_metrics()
Example #15
    def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
               trainer_class: TArchTrainer,
               finalizers: Finalizers) -> SearchResult:

        logger.pushd('search')

        # region config vars
        self.conf_search = conf_search
        conf_checkpoint = conf_search['checkpoint']
        resume = conf_search['resume']

        conf_post_train = conf_search['post_train']
        final_desc_foldername = conf_search['final_desc_foldername']

        conf_petridish = conf_search['petridish']

        # petridish distributed search related parameters
        self._convex_hull_eps = conf_petridish['convex_hull_eps']
        self._sampling_max_try = conf_petridish['sampling_max_try']
        self._max_madd = conf_petridish['max_madd']
        self._max_hull_points = conf_petridish['max_hull_points']
        self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
        # endregion

        self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        # parent models list
        self._hull_points: List[ConvexHullPoint] = []

        self._ensure_dataset_download(conf_search)

        # checkpoint will restore the hull we had
        is_restored = self._restore_checkpoint()

        # seed the pool with many seed models of different
        # macro parameters like number of cells, reductions etc if parent pool
        # could not be restored and/or this is the first time this job has been run.
        future_ids = [] if is_restored else self._create_seed_jobs(
            conf_search, model_desc_builder)

        while not self._is_search_done():
            logger.info(f'Ray jobs running: {len(future_ids)}')

            if future_ids:
                # get first completed job
                job_id_done, future_ids = ray.wait(future_ids)

                hull_point = ray.get(job_id_done[0])

                logger.info(
                    f'Hull point id {hull_point.id} with stage {hull_point.job_stage.name} completed'
                )

                if hull_point.is_trained_stage():
                    self._update_convex_hull(hull_point)

                    # sample a point and search
                    sampled_point = sample_from_hull(self._hull_points,
                                                     self._convex_hull_eps,
                                                     self._sampling_max_try)

                    future_id = SearcherPetridish.search_model_desc_dist.remote(
                        self, conf_search, sampled_point, model_desc_builder,
                        trainer_class, finalizers, common.get_state())
                    future_ids.append(future_id)
                    logger.info(
                        f'Added sampled point {sampled_point.id} for search')
                elif hull_point.job_stage == JobStage.SEARCH:
                    # create the job to train the searched model
                    future_id = SearcherPetridish.train_model_desc_dist.remote(
                        self, conf_post_train, hull_point, common.get_state())
                    future_ids.append(future_id)
                    logger.info(
                        f'Added sampled point {hull_point.id} for post-search training'
                    )
                else:
                    raise RuntimeError(
                        f'Job stage "{hull_point.job_stage}" is not expected in search loop'
                    )

        # cancel any remaining jobs to free up gpus for the eval phase
        for future_id in future_ids:
            ray.cancel(future_id,
                       force=True)  # without force, main process stops
            ray.wait([future_id])

        # plot and save the hull
        expdir = common.get_expdir()
        assert expdir
        plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
        best_point = save_hull_frontier(self._hull_points,
                                        self._convex_hull_eps,
                                        final_desc_foldername, expdir)
        save_hull(self._hull_points, expdir)
        plot_pool(self._hull_points, expdir)

        # return best point as search result
        search_result = SearchResult(best_point.model_desc,
                                     search_metrics=None,
                                     train_metrics=best_point.metrics)
        self.clean_log_result(conf_search, search_result)

        logger.popd()

        return search_result
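
The search loop above is built on ray's submit-and-wait pattern: jobs are launched with .remote(), and ray.wait returns whichever job finishes first so a follow-up job can be scheduled immediately. A minimal self-contained sketch of that pattern, assuming ray is installed; the remote function is a placeholder, not one of the *_dist methods:

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def train_job(i):                 # placeholder for the distributed search/train jobs
    return i * i

future_ids = [train_job.remote(i) for i in range(4)]

while future_ids:
    done_ids, future_ids = ray.wait(future_ids)   # first completed, still pending
    result = ray.get(done_ids[0])
    print('job finished with result', result)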