Esempio n. 1
0
    def finalize_node(self, node:nn.ModuleList, node_index:int,
                      node_desc:NodeDesc, max_final_edges:int,
                      cov:np.ndarray, cell: Cell, node_id: int,
                      *args, **kwargs)->NodeDesc:
        """Finalize a searched node by selecting ``max_final_edges`` of its ops.

        All ops on the node's incoming edges except ``Zero`` are ranked by
        alpha, highest first. The top half of them, but at least
        ``max_final_edges``, are kept as candidates. ``compute_brute_force_sol``
        is then run on the covariance sub-matrix of those candidates to choose
        ``max_final_edges`` of them. Each chosen op becomes an edge of the
        returned ``NodeDesc``. A heatmap of the candidate covariance matrix is
        saved to the experiment directory for diagnostics.

        Args:
            node: the edges feeding this node.
            node_index: not used by this implementation.
            node_desc: descriptor of the node being finalized; its
                ``conv_params`` are carried over to the result.
            max_final_edges: number of edges to keep.
            cov: square 2-D covariance matrix, indexed by the op covariance
                indices stored for ``cell``/``node_id``.
            cell: cell that owns the node; used to look up covariance indices
                and to name the diagnostic plot.
            node_id: id of the node; used the same way as ``cell``.

        Returns:
            A new ``NodeDesc`` holding the selected edges.
        """
        # node is a list of edges
        assert len(node) >= max_final_edges

        # covariance matrix must be square and 2-D
        assert len(cov.shape) == 2
        assert cov.shape[0] == cov.shape[1]

        # the number of primitive operators has to be greater than or equal
        # to the maximum number of final edges allowed
        assert cov.shape[0] >= max_final_edges

        # (edge, op, alpha, edge index) for every op other than 'none' (Zero)
        in_ops = [(edge, op, alpha, i) for i, edge in enumerate(node)
                  for op, alpha in edge._op.ops()
                  if not isinstance(op, Zero)]
        assert len(in_ops) >= max_final_edges

        # order all the ops by alpha, highest first
        in_ops_sorted = sorted(in_ops, key=lambda in_op: in_op[2], reverse=True)

        # keep the top half of the ops under consideration
        # (but never fewer than max_final_edges)
        num_to_keep = max(max_final_edges, len(in_ops_sorted) // 2)
        top_ops = in_ops_sorted[:num_to_keep]

        # get the covariance submatrix of the top ops only
        op_to_cov_ind = self._divnas_cells[cell].node_num_to_node_op_to_cov_ind[node_id]
        cov_inds = [op_to_cov_ind[op] for _, op, _, _ in top_ops]
        cov_top_ops = cov[np.ix_(cov_inds, cov_inds)]

        assert len(cov_inds) == len(top_ops)
        assert len(top_ops) >= max_final_edges
        assert len(cov_top_ops.shape) == 2
        assert cov_top_ops.shape[0] == cov_top_ops.shape[1]

        # run the brute force set selection algorithm on the top ops only
        max_subset, max_mi = compute_brute_force_sol(cov_top_ops, max_final_edges)

        # note that elements of max_subset are indices into top_ops only
        selected_edges = []
        for ind in max_subset:
            edge, op, _, edge_num = top_ops[ind]
            op_desc, _ = op.finalize()
            new_edge = EdgeDesc(op_desc, edge.input_ids)
            logger.info(f'selected edge: {edge_num}, op: {op_desc.name}')
            selected_edges.append(new_edge)

        # save diagnostic information to disk. Draw on a dedicated figure and
        # always close it: plotting on pyplot's implicit current figure would
        # stack heatmaps/colorbars across calls and leak figure objects.
        expdir = get_expdir()
        fig, ax = plt.subplots()
        try:
            sns.heatmap(cov_top_ops, annot=True, fmt='.1g', cmap='coolwarm', ax=ax)
            savename = os.path.join(
                expdir, f'cell_{cell.desc.id}_node_{node_id}_cov.png')
            fig.savefig(savename)
        finally:
            plt.close(fig)

        logger.info('')
        return NodeDesc(selected_edges, node_desc.conv_params)
Esempio n. 2
0
    def search(self, conf_search: Config, model_desc_builder: ModelDescBuilder,
               trainer_class: TArchTrainer,
               finalizers: Finalizers) -> SearchResult:
        """Run the distributed Petridish search and return the best model.

        Unless the hull is restored from a checkpoint, seed jobs are created.
        Completed Ray jobs are then processed one at a time until
        ``self._is_search_done()`` returns True:

        * a finished trained point updates the convex hull, and a new point
          is sampled from the hull and submitted for search;
        * a finished search point is submitted for post-search training.

        If no job is in flight, e.g. right after a checkpoint restore, a point
        is sampled from the hull and submitted for search so the loop can make
        progress. Afterwards, remaining jobs are cancelled, the hull is
        plotted/saved in the experiment directory, and the point returned by
        ``save_hull_frontier`` is wrapped in a ``SearchResult``.

        Raises:
            RuntimeError: if a completed job reports an unexpected stage, or if
                no job is running and the hull is empty, so nothing can be
                sampled.
        """
        logger.pushd('search')

        # region config vars
        self.conf_search = conf_search
        conf_checkpoint = conf_search['checkpoint']
        resume = conf_search['resume']

        conf_post_train = conf_search['post_train']
        final_desc_foldername = conf_search['final_desc_foldername']

        conf_petridish = conf_search['petridish']

        # petridish distributed search related parameters
        self._convex_hull_eps = conf_petridish['convex_hull_eps']
        self._sampling_max_try = conf_petridish['sampling_max_try']
        self._max_madd = conf_petridish['max_madd']
        self._max_hull_points = conf_petridish['max_hull_points']
        self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
        # endregion

        self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        # parent models list
        self._hull_points: List[ConvexHullPoint] = []

        self._ensure_dataset_download(conf_search)

        # checkpoint will restore the hull we had
        is_restored = self._restore_checkpoint()

        # seed the pool with many seed models of different
        # macro parameters like number of cells, reductions etc if parent pool
        # could not be restored and/or this is the first time this job has been run.
        future_ids = [] if is_restored else self._create_seed_jobs(
            conf_search, model_desc_builder)

        while not self._is_search_done():
            logger.info(f'Ray jobs running: {len(future_ids)}')

            if not future_ids:
                # Nothing is in flight, e.g. the hull was restored from a
                # checkpoint and therefore no seed jobs were created. Without
                # this branch the loop would busy-spin, since no job could
                # change the state _is_search_done() presumably inspects.
                # Resume by sampling from the hull, as after a trained point.
                if not self._hull_points:
                    raise RuntimeError(
                        'No Ray jobs are running and the hull is empty; '
                        'cannot continue search')
                future_ids.append(self._launch_sampled_search(
                    conf_search, model_desc_builder, trainer_class,
                    finalizers))
                continue

            # get first completed job
            job_id_done, future_ids = ray.wait(future_ids)

            hull_point = ray.get(job_id_done[0])

            logger.info(
                f'Hull point id {hull_point.id} with stage {hull_point.job_stage.name} completed'
            )

            if hull_point.is_trained_stage():
                self._update_convex_hull(hull_point)

                # sample a point and search
                future_ids.append(self._launch_sampled_search(
                    conf_search, model_desc_builder, trainer_class,
                    finalizers))
            elif hull_point.job_stage == JobStage.SEARCH:
                # create the job to train the searched model
                future_id = SearcherPetridish.train_model_desc_dist.remote(
                    self, conf_post_train, hull_point, common.get_state())
                future_ids.append(future_id)
                logger.info(
                    f'Added sampled point {hull_point.id} for post-search training'
                )
            else:
                raise RuntimeError(
                    f'Job stage "{hull_point.job_stage}" is not expected in search loop'
                )

        # cancel any remaining jobs to free up gpus for the eval phase
        for future_id in future_ids:
            ray.cancel(future_id,
                       force=True)  # without force, main process stops
            ray.wait([future_id])

        # plot and save the hull
        expdir = common.get_expdir()
        assert expdir
        plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
        best_point = save_hull_frontier(self._hull_points,
                                        self._convex_hull_eps,
                                        final_desc_foldername, expdir)
        save_hull(self._hull_points, expdir)
        plot_pool(self._hull_points, expdir)

        # return best point as search result
        search_result = SearchResult(best_point.model_desc,
                                     search_metrics=None,
                                     train_metrics=best_point.metrics)
        self.clean_log_result(conf_search, search_result)

        logger.popd()

        return search_result

    def _launch_sampled_search(self, conf_search: Config,
                               model_desc_builder: ModelDescBuilder,
                               trainer_class: TArchTrainer,
                               finalizers: Finalizers):
        """Sample a point from the current hull and submit a remote search job.

        Returns the Ray object ref of the submitted job.
        """
        sampled_point = sample_from_hull(self._hull_points,
                                         self._convex_hull_eps,
                                         self._sampling_max_try)

        future_id = SearcherPetridish.search_model_desc_dist.remote(
            self, conf_search, sampled_point, model_desc_builder,
            trainer_class, finalizers, common.get_state())
        logger.info(
            f'Added sampled point {sampled_point.id} for search')
        return future_id
Esempio n. 3
0
    def fit(self, data_loaders: data.DataLoaders) -> Metrics:
        """Train the model for the configured number of epochs.

        Creates the metrics object, optimizers/schedulers and the amp-wrapped
        model, then restores trainer state from ``self._checkpoint`` if it
        contains a ``'trainer'`` entry. If the checkpoint exists but lacks
        that entry, it is cleared. Training then runs from
        ``self._start_epoch`` up to ``self._epochs``, calling the
        ``pre_*``/``post_*`` hooks around the run and around each epoch.

        Args:
            data_loaders: must provide a non-None ``train_dl``.

        Returns:
            The metrics returned by ``self.get_metrics()``. If the restored
            start epoch is already at or past the last epoch, this is returned
            immediately without training.
        """
        logger.pushd(self._title)

        assert data_loaders.train_dl is not None

        self._metrics = Metrics(self._title,
                                self._apex,
                                logger_freq=self._logger_freq)

        # create optimizers and schedulers
        self._multi_optim = self.create_multi_optim(len(data_loaders.train_dl))
        # before checkpoint restore, convert to amp
        self.model = self._apex.to_amp(
            self.model,
            self._multi_optim,
            batch_size=data_loaders.train_dl.batch_size)

        self._lossfn = self._lossfn.to(self.get_device())

        self.pre_fit(data_loaders)

        # we need to restore checkpoint after all objects are created because
        # restoring checkpoint requires load_state_dict calls on these objects
        self._start_epoch = 0
        # do we have a checkpoint
        checkpoint_avail = self._checkpoint is not None
        checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
        resumed = False
        if checkpoint_val:
            # restore checkpoint
            resumed = True
            self.restore_checkpoint()
        elif checkpoint_avail:  # TODO: bad checkpoint?
            self._checkpoint.clear()
        logger.warn({
            'resumed': resumed,
            'checkpoint_avail': checkpoint_avail,
            'checkpoint_val': checkpoint_val,
            'start_epoch': self._start_epoch,
            'total_epochs': self._epochs
        })
        logger.info({
            'aux_weight': self._aux_weight,
            'grad_clip': self._grad_clip,
            'drop_path_prob': self._drop_path_prob,
            'validation_freq': self._validation_freq,
            'batch_chunks': self.batch_chunks
        })

        if self._start_epoch >= self._epochs:
            logger.warn(
                f'fit done because start_epoch {self._start_epoch}>={self._epochs}'
            )
            # balance logger.pushd(self._title) above; previously this early
            # exit left the logger path stack one level too deep
            logger.popd()
            # we already finished the run, we might be checkpointed
            return self.get_metrics()

        logger.pushd('epochs')
        for epoch in range(self._start_epoch, self._epochs):
            logger.pushd(epoch)
            self._set_epoch(epoch, data_loaders)
            self.pre_epoch(data_loaders)
            self._train_epoch(data_loaders.train_dl)
            self.post_epoch(data_loaders)
            logger.popd()
        logger.popd()
        self.post_fit(data_loaders)

        # make sure we don't keep references to the graph
        del self._multi_optim

        logger.popd()
        return self.get_metrics()
Esempio n. 4
0
    def _train_dist(evaluater: Evaluater, conf_eval: Config,
                    model_desc_builder: ModelDescBuilder,
                    model_desc_filename: str, common_state) -> ConvexHullPoint:
        """Train given a model.

        Runs in a separate process, so it first re-initializes ``common`` from
        ``common_state`` and re-registers ops via
        ``model_desc_builder.pre_build``.

        If checkpointing is configured, resuming is requested, and the
        checkpoint already holds ``'metrics_stats'``, the stored
        ``ConvexHullPoint`` is returned without training. Otherwise the model
        description in ``model_desc_filename`` is used as a template. Its cell
        count is multiplied by ``cell_count_scale`` (rounded up and capped at
        ``model_desc.n_cells``), and its number of reduction cells is kept.
        The resulting model is built, trained and saved. Its full description,
        metrics, model stats and weights are written to files derived from
        ``model_desc_filename``.

        Returns:
            A ``ConvexHullPoint`` in stage ``JobStage.EVAL_TRAINED`` with the
            training metrics and model stats. When checkpointing is enabled,
            this point is also stored in the checkpoint.
        """
        import os  # needed only for robust checkpoint path handling

        common.init_from(common_state)

        # region config vars
        conf_model_desc = conf_eval['model_desc']
        max_cells = conf_model_desc['n_cells']

        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']

        conf_petridish = conf_eval['petridish']
        cell_count_scale = conf_petridish['cell_count_scale']
        #endregion

        #register ops as we are in different process now
        model_desc_builder.pre_build(conf_model_desc)

        model_filename = utils.append_to_filename(model_desc_filename,
                                                  '_model', '.pt')
        full_desc_filename = utils.append_to_filename(model_desc_filename,
                                                      '_full', '.yaml')
        metrics_filename = utils.append_to_filename(model_desc_filename,
                                                    '_metrics', '.yaml')
        model_stats_filename = utils.append_to_filename(
            model_desc_filename, '_model_stats', '.yaml')

        # DEBUG
        print(f'received {model_desc_filename}')

        # create checkpoint for this specific model desc by changing the config
        checkpoint = None
        if conf_checkpoint is not None:
            # strip only the extension: splitting on the first '.' would
            # truncate paths whose directories contain dots (e.g. './exp/...')
            conf_checkpoint['filename'] = os.path.splitext(
                model_filename)[0] + '_checkpoint.pth'
            checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

            if checkpoint is not None and resume:
                if 'metrics_stats' in checkpoint:
                    # return the output we had recorded in the checkpoint
                    convex_hull_point = checkpoint['metrics_stats']
                    return convex_hull_point

        # template model is what we used during the search
        template_model_desc = ModelDesc.load(model_desc_filename)

        # we first scale this model by number of cells, keeping reductions same as in search
        n_cells = math.ceil(
            len(template_model_desc.cell_descs()) * cell_count_scale)
        n_cells = min(n_cells, max_cells)

        # DEBUG
        print(
            f'{model_desc_filename} has {len(template_model_desc.cell_descs())} cells, scaling to {n_cells} cells via {cell_count_scale} factor'
        )

        conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_cells'] = n_cells
        conf_model_desc[
            'n_reductions'] = n_reductions = template_model_desc.cell_type_count(
                CellType.Reduction)

        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)
        # save desc for reference
        model_desc.save(full_desc_filename)

        model = evaluater.model_from_desc(model_desc)

        train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metrics_filename)

        # get metrics_stats
        model_stats = nas_utils.get_model_stats(model)
        # save metrics_stats
        with open(model_stats_filename, 'w') as f:
            yaml.dump(model_stats, f)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)
            logger.info({'model_save_path': model_filename})

        hull_point = ConvexHullPoint(
            JobStage.EVAL_TRAINED,
            0,
            0,
            model_desc,
            (n_cells, n_reductions, len(model_desc.cell_descs()[0].nodes())),
            metrics=train_metrics,
            model_stats=model_stats)

        # explicit None check (as above): a dict-like checkpoint that happens
        # to be empty is falsy and would silently skip recording the result
        if checkpoint is not None:
            checkpoint.new()
            checkpoint['metrics_stats'] = hull_point
            checkpoint.commit()

        return hull_point