Example #1
    def search_model_desc_dist(searcher:'SearcherPetridish', conf_search:Config,
        hull_point:ConvexHullPoint, model_desc_builder:ModelDescBuilder,
        trainer_class:TArchTrainer, finalizers:Finalizers, common_state:CommonState)\
            ->ConvexHullPoint:

        # as this runs in a different process, initialize globals
        common.init_from(common_state)

        # register ops as we are in a different process now
        conf_model_desc = conf_search['model_desc']
        model_desc_builder.pre_build(conf_model_desc)

        assert hull_point.is_trained_stage()

        # cloning is strictly not needed, but if we ever run this function
        # in the same process, it avoids surprising side effects
        model_desc = hull_point.model_desc.clone()
        searcher._add_node(model_desc, model_desc_builder)

        model_desc, search_metrics = searcher.search_model_desc(
            conf_search, model_desc, trainer_class, finalizers)

        cells, reductions, nodes = hull_point.cells_reductions_nodes
        new_point = ConvexHullPoint(
            JobStage.SEARCH,
            hull_point.id,
            hull_point.sampling_count,
            model_desc,
            (cells, reductions, nodes + 1),  # we added a node
            metrics=search_metrics)
        return new_point
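
The `.remote(...)` call in Example #2 suggests these distributed entry points are exposed as Ray remote functions. A minimal dispatch sketch under that assumption, reusing the names from the signature above; `ray.get` blocking on the result and the surrounding setup are assumptions, not shown in the source:

    import ray

    # hypothetical dispatch of a search job to a Ray worker, mirroring the
    # train_model_desc_dist.remote(...) pattern shown in Example #2
    future_id = SearcherPetridish.search_model_desc_dist.remote(
        searcher, conf_search, hull_point, model_desc_builder,
        trainer_class, finalizers, common.get_state())

    # block until the worker returns the new ConvexHullPoint
    new_point = ray.get(future_id)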
Example #2
    def _create_seed_jobs(self, conf_search: Config,
                          model_desc_builder: ModelDescBuilder) -> list:
        conf_model_desc = conf_search['model_desc']
        conf_seed_train = conf_search['seed_train']

        future_ids = []  # ray job IDs
        macro_combinations = list(self.get_combinations(conf_search))
        for reductions, cells, nodes in macro_combinations:
            # if the N R N R N R cell pattern cannot be satisfied, skip this combination
            if cells < reductions * 2 + 1:
                continue

            # create seed model
            model_desc = self.build_model_desc(model_desc_builder,
                                               conf_model_desc, reductions,
                                               cells, nodes)

            hull_point = ConvexHullPoint(JobStage.SEED, 0, 0, model_desc,
                                         (cells, reductions, nodes))

            # pre-train the seed model
            future_id = SearcherPetridish.train_model_desc_dist.remote(
                self, conf_seed_train, hull_point, common.get_state())

            future_ids.append(future_id)

        return future_ids
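
A sketch of how the returned future IDs might be drained, assuming Ray's standard ray.wait/ray.get API; the driver-loop method name `_wait_for_seed_jobs` is hypothetical:

    import ray

    def _wait_for_seed_jobs(self, future_ids: list) -> None:
        # hypothetical driver loop: consume seed-training futures as they
        # complete and feed each resulting ConvexHullPoint into the hull
        pending = list(future_ids)
        while pending:
            done, pending = ray.wait(pending, num_returns=1)
            self._update_convex_hull(ray.get(done[0]))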
Example #3
    def train_model_desc_dist(searcher:'SearcherPetridish', conf_train:Config,
                              hull_point:ConvexHullPoint, common_state:CommonState)\
            ->ConvexHullPoint:
        # as this runs in a different process, initialize globals
        common.init_from(common_state)

        assert not hull_point.is_trained_stage()

        model_metrics = searcher.train_model_desc(hull_point.model_desc, conf_train)
        model_stats = nas_utils.get_model_stats(model_metrics.model)

        new_point = ConvexHullPoint(hull_point.next_stage(), hull_point.id,
                                    hull_point.sampling_count, hull_point.model_desc,
                                    hull_point.cells_reductions_nodes,
                                    model_metrics.metrics,
                                    model_stats)

        return new_point
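
The stage checks in these examples (`is_trained_stage()`, `next_stage()`) imply a linear job-stage pipeline behind `ConvexHullPoint`. A minimal sketch of what such an enum could look like; this is an illustrative assumption, not the library's actual definition, and `hull_point.next_stage()` presumably does the equivalent on its stored stage:

    from enum import IntEnum

    class JobStage(IntEnum):
        # hypothetical linear pipeline; a point counts as "trained" once
        # its stage carries metrics and model stats
        SEED = 1
        SEED_TRAINED = 2
        SEARCH = 3
        SEARCH_TRAINED = 4

    def next_stage(stage: JobStage) -> JobStage:
        # advance one step along the pipeline (illustrative only)
        return JobStage(stage.value + 1)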
Example #4
    def _update_convex_hull(self, new_point:ConvexHullPoint)->None:
        assert new_point.is_trained_stage() # only add models for which we have metrics and stats
        self._hull_points.append(new_point)

        if self._checkpoint is not None:
            self._checkpoint.new()
            self._checkpoint['convex_hull_points'] = self._hull_points
            self._checkpoint.commit()

        logger.info(f'Added to convex hull points: MAdd {new_point.model_stats.MAdd}, '
                    f'num cells {len(new_point.model_desc.cell_descs())}, '
                    f'num nodes in cell {len(new_point.model_desc.cell_descs()[0].nodes())}')
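
The checkpoint writes here pair naturally with a restore step at startup. A hedged counterpart sketch, assuming the same dict-style checkpoint object used above (membership testing as in Example #6's `'metrics_stats' in checkpoint`); the method name `_restore_checkpoint` is hypothetical:

    def _restore_checkpoint(self) -> bool:
        # rehydrate hull points saved by _update_convex_hull, if any
        if self._checkpoint is not None and 'convex_hull_points' in self._checkpoint:
            self._hull_points = self._checkpoint['convex_hull_points']
            return True
        return False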
Example #5
    def _create_seed_jobs(self, conf_search:Config, model_desc_builder:ModelDescBuilder)->list:
        conf_model_desc = conf_search['model_desc']
        conf_seed_train = conf_search['seed_train']

        future_ids = [] # ray job IDs
        seed_model_stats = [] # seed model stats for visualization and debugging 
        macro_combinations = list(self.get_combinations(conf_search))
        for reductions, cells, nodes in macro_combinations:
            # if the N R N R N R cell pattern cannot be satisfied, skip this combination
            if cells < reductions * 2 + 1:
                continue

            # create seed model
            model_desc = self.build_model_desc(model_desc_builder,
                                               conf_model_desc,
                                               reductions, cells, nodes)

            hull_point = ConvexHullPoint(JobStage.SEED, 0, 0, model_desc,
                                         (cells, reductions, nodes))

            # pre-train the seed model
            future_id = SearcherPetridish.train_model_desc_dist.remote(self,
                conf_seed_train, hull_point, common.get_state())

            future_ids.append(future_id)

            # build a model so we can get its model stats
            temp_model = Model(model_desc, droppath=True, affine=True)
            seed_model_stats.append(nas_utils.get_model_stats(temp_model))
        
        # save the model stats to a plot and a tsv file so we can
        # visualize the spread on the x-axis
        expdir = common.get_expdir()
        assert expdir
        plot_seed_model_stats(seed_model_stats, expdir)

        return future_ids
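
`plot_seed_model_stats` itself is not shown in these examples. A hypothetical sketch of what it might do, assuming each stats object exposes `.MAdd` as Example #4 suggests; the matplotlib usage and output filename are assumptions:

    import os
    import matplotlib.pyplot as plt

    def plot_seed_model_stats(seed_model_stats: list, expdir: str) -> None:
        # scatter the seed models' MAdds so the spread along the x-axis
        # of the convex hull can be eyeballed (illustrative only)
        madds = [stats.MAdd for stats in seed_model_stats]
        plt.scatter(madds, range(len(madds)))
        plt.xlabel('MAdd')
        plt.ylabel('seed model index')
        plt.savefig(os.path.join(expdir, 'seed_model_stats.png'))
        plt.close()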
Example #6
    def _train_dist(evaluater: Evaluater, conf_eval: Config,
                    model_desc_builder: ModelDescBuilder,
                    model_desc_filename: str, common_state) -> ConvexHullPoint:
        """Train given a model"""

        common.init_from(common_state)

        # region config vars
        conf_model_desc = conf_eval['model_desc']
        max_cells = conf_model_desc['n_cells']

        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']

        conf_petridish = conf_eval['petridish']
        cell_count_scale = conf_petridish['cell_count_scale']
        #endregion

        # register ops as we are in a different process now
        model_desc_builder.pre_build(conf_model_desc)

        model_filename = utils.append_to_filename(model_desc_filename,
                                                  '_model', '.pt')
        full_desc_filename = utils.append_to_filename(model_desc_filename,
                                                      '_full', '.yaml')
        metrics_filename = utils.append_to_filename(model_desc_filename,
                                                    '_metrics', '.yaml')
        model_stats_filename = utils.append_to_filename(
            model_desc_filename, '_model_stats', '.yaml')

        # create checkpoint for this specific model desc by changing the config
        checkpoint = None
        if conf_checkpoint is not None:
            conf_checkpoint['filename'] = \
                model_filename.split('.')[0] + '_checkpoint.pth'
            checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

            if checkpoint is not None and resume:
                if 'metrics_stats' in checkpoint:
                    # return the output we had recorded in the checkpoint
                    convex_hull_point = checkpoint['metrics_stats']
                    return convex_hull_point

        # template model is what we used during the search
        template_model_desc = ModelDesc.load(model_desc_filename)

        # we first scale this model by the number of cells, keeping reductions the same as in search
        n_cells = math.ceil(
            len(template_model_desc.cell_descs()) * cell_count_scale)
        n_cells = min(n_cells, max_cells)

        conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_cells'] = n_cells
        n_reductions = template_model_desc.cell_type_count(CellType.Reduction)
        conf_model_desc['n_reductions'] = n_reductions

        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)
        # save desc for reference
        model_desc.save(full_desc_filename)

        model = evaluater.model_from_desc(model_desc)

        train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metrics_filename)

        # get model stats
        model_stats = nas_utils.get_model_stats(model)
        # save model stats
        with open(model_stats_filename, 'w') as f:
            yaml.dump(model_stats, f)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)
            # TODO: Causes logging error at random times. Commenting out as stop-gap fix.
            # logger.info({'model_save_path': model_filename})

        hull_point = ConvexHullPoint(
            JobStage.EVAL_TRAINED,
            0,
            0,
            model_desc,
            (n_cells, n_reductions, len(model_desc.cell_descs()[0].nodes())),
            metrics=train_metrics,
            model_stats=model_stats)

        if checkpoint:
            checkpoint.new()
            checkpoint['metrics_stats'] = hull_point
            checkpoint.commit()

        return hull_point
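
The cell-scaling step is worth a worked example. With a hypothetical 8-cell template, `cell_count_scale=2.0`, and a config cap of `n_cells=20`, the evaluation model gets min(ceil(8 * 2.0), 20) = 16 cells, while the reduction count found during search is carried over unchanged:

    import math

    # worked example of the scaling arithmetic above (values are hypothetical)
    template_cells = 8
    cell_count_scale = 2.0
    max_cells = 20

    n_cells = min(math.ceil(template_cells * cell_count_scale), max_cells)
    assert n_cells == 16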