Example #1
0
    def search_model_desc_dist(
            searcher:'SearcherPetridish',
            conf_search:Config,
            hull_point:ConvexHullPoint,
            model_desc_builder:ModelDescBuilder,
            trainer_class:TArchTrainer,
            finalizers:Finalizers,
            common_state:CommonState)->ConvexHullPoint:
        """Add one node to a trained hull point's model and search the result.

        Args:
            searcher: object providing ``_add_node`` and ``search_model_desc``.
            conf_search: search config; its ``'model_desc'`` entry is handed
                to ``model_desc_builder.pre_build``.
            hull_point: starting point; must report ``is_trained_stage()``.
            model_desc_builder: builder passed through to ``_add_node``.
            trainer_class: forwarded to ``searcher.search_model_desc``.
            finalizers: forwarded to ``searcher.search_model_desc``.
            common_state: state handed to ``common.init_from``.

        Returns:
            A new ``ConvexHullPoint`` in ``JobStage.SEARCH`` that reuses the
            id and sampling count of ``hull_point``, carries the searched
            model description and metrics, and has its node count
            incremented by one.
        """
        # this is meant to run in a different process, so globals are
        # initialized from the state handed over by the caller
        common.init_from(common_state)

        # ops have to be registered again for the same reason
        model_desc_builder.pre_build(conf_search['model_desc'])

        assert hull_point.is_trained_stage()

        # the clone is not strictly required, but it avoids surprises if this
        # function ever runs in the same process as its caller
        grown_desc = hull_point.model_desc.clone()
        searcher._add_node(grown_desc, model_desc_builder)

        searched_desc, metrics = searcher.search_model_desc(
            conf_search, grown_desc, trainer_class, finalizers)

        n_cells, n_reductions, n_nodes = hull_point.cells_reductions_nodes
        return ConvexHullPoint(
            JobStage.SEARCH,
            hull_point.id,
            hull_point.sampling_count,
            searched_desc,
            (n_cells, n_reductions, n_nodes + 1),  # one node was added above
            metrics=metrics)
Example #2
0
def test_darts_zero_model():
    """Build the model described by the DARTS search config and run it once.

    The forward pass on a single random 3x32x32 input must return a tensor of
    shape (1, 10) together with ``None`` as the second output.
    """
    conf = common_init(config_filepath='confs/algos/darts.yaml')
    conf_model_desc = conf['nas']['search']['model_desc']

    built_desc = ModelDescBuilder().build(conf_model_desc)
    model = Model(built_desc, False, True)

    sample = torch.rand((1, 3, 32, 32))
    out, aux_out = model(sample)

    assert isinstance(out, torch.Tensor)
    assert out.shape == (1, 10)
    assert aux_out is None
Example #3
0
    def build_model_desc(self, model_desc_builder: ModelDescBuilder,
                         conf_model_desc: Config, reductions: int, cells: int,
                         nodes: int) -> ModelDesc:
        """Build a model description with overridden macro parameters.

        Args:
            model_desc_builder: builder whose ``build`` method is called.
            conf_model_desc: model config; it is deep-copied, so the caller's
                object is not modified.
            reductions: value stored under ``'n_reductions'`` in the copy.
            cells: value stored under ``'n_cells'`` in the copy.
            nodes: accepted for interface compatibility; not used in this body.

        Returns:
            The description returned by ``model_desc_builder.build`` for the
            adjusted config, built without a template.
        """
        # macro params are overridden on a private copy of the config
        macro_conf = copy.deepcopy(conf_model_desc)
        macro_conf['n_reductions'] = reductions
        macro_conf['n_cells'] = cells

        # no template is passed to the builder here
        return model_desc_builder.build(macro_conf, template=None)
Example #4
0
    def create_model(self,
                     conf_eval: Config,
                     model_desc_builder: ModelDescBuilder,
                     final_desc_filename=None,
                     full_desc_filename=None) -> nn.Module:
        """Create the evaluation model from a saved template description.

        Args:
            conf_eval: evaluation config; provides ``'model_desc'`` and, when
                ``final_desc_filename`` is not passed, the fallback values
                ``'final_desc_filename'`` and ``'full_desc_filename'``.
            model_desc_builder: builder used to expand the template; must not
                be ``None``.
            final_desc_filename: file the template description is loaded
                from. Read from ``conf_eval`` when falsy.
            full_desc_filename: file the built description is saved to. Read
                from ``conf_eval`` only when both this argument and
                ``final_desc_filename`` are falsy; otherwise it is passed to
                ``save`` as given.

        Returns:
            The model produced by ``self.model_from_desc`` for the built
            description.

        Raises:
            AssertionError: if ``model_desc_builder`` is ``None``.
        """
        assert model_desc_builder is not None, 'Default evaluater requires model_desc_builder'

        # region conf vars
        # if explicitly passed in then don't get from conf
        if not final_desc_filename:
            final_desc_filename = conf_eval['final_desc_filename']
            # previously an explicitly passed full_desc_filename was silently
            # overwritten here; the config value is now only a fallback
            if not full_desc_filename:
                full_desc_filename = conf_eval['full_desc_filename']
        conf_model_desc = conf_eval['model_desc']
        # endregion

        # load model desc file to get template model
        template_model_desc = ModelDesc.load(final_desc_filename)
        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)

        # save desc for reference
        model_desc.save(full_desc_filename)

        model = self.model_from_desc(model_desc)

        logger.info({
            'model_factory': False,
            'cells_len': len(model.desc.cell_descs()),
            'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
            'n_cells': conf_model_desc['n_cells'],
            'n_reductions': conf_model_desc['n_reductions'],
            'n_nodes': conf_model_desc['cell']['n_nodes']
        })

        return model
Example #5
0
    def _add_node(self, model_desc: ModelDesc,
                  model_desc_builder: ModelDescBuilder) -> None:
        """Insert one new petridish node into every cell of ``model_desc``.

        For each cell a node with a single edge is inserted just before the
        cell's last node. The edge's op takes ``len(nodes) + 1`` inputs, where
        ``nodes`` is the node list before insertion. The cell's node shapes
        are extended to match and its post op is rebuilt through
        ``model_desc_builder.build_cell_post_op``. ``model_desc`` is modified
        in place.

        Args:
            model_desc: model description whose cells are extended.
            model_desc_builder: builder used to rebuild each cell's post op.
        """
        for cell_index, cell in enumerate(model_desc.cell_descs()):
            is_reduction = (cell.cell_type == CellType.Reduction)
            cell_nodes = cell.nodes()

            # petridish must seed with one node
            assert len(cell_nodes) > 0
            # input/output channels for all nodes are same, so the first
            # node's conv params are reused for the new node
            conv_params = cell_nodes[0].conv_params

            # input IDs for the new node: s0 and s1 have IDs 0 and 1; the
            # count is len(nodes) + 1 because the new node goes in before
            # the last one
            input_ids = list(range(len(cell_nodes) + 1))
            assert len(input_ids) >= 2  # 2 stem inputs

            if is_reduction:
                op_name = 'petridish_reduction_op'
            else:
                op_name = 'petridish_normal_op'

            # one stride per input, later handed to each primitive; only the
            # first two inputs of a reduction cell get stride 2
            strides = [2 if is_reduction and input_id < 2 else 1
                       for input_id in input_ids]
            op_params = {'conv': conv_params, '_strides': strides}
            op_desc = OpDesc(op_name, params=op_params,
                             in_len=len(input_ids), trainables=None,
                             children=None)

            added_node = NodeDesc(
                edges=[EdgeDesc(op_desc, input_ids=input_ids)],
                conv_params=conv_params)
            before_last_node = len(cell_nodes) - 1
            cell_nodes.insert(before_last_node, added_node)

            # output shape of all nodes are same, so the last shape is copied
            shapes = cell.node_shapes
            added_shape = copy.deepcopy(shapes[-1])
            before_last_shape = len(shapes) - 1
            shapes.insert(before_last_shape, added_shape)

            # the post op now has a different number of inputs, so its
            # input/output channels may differ and it has to be rebuilt
            post_op_shape, post_op_desc = model_desc_builder.build_cell_post_op(
                cell.stem_shapes, shapes, cell.conf_cell, cell_index)
            cell.reset_nodes(cell_nodes, shapes, post_op_desc, post_op_shape)
Example #6
0
    def _train_dist(evaluater: Evaluater, conf_eval: Config,
                    model_desc_builder: ModelDescBuilder,
                    model_desc_filename: str, common_state) -> ConvexHullPoint:
        """Train the model built from a saved description and report the result.

        The template description is loaded from ``model_desc_filename`` and
        scaled to ``ceil(len(template cells) * cell_count_scale)`` cells,
        capped at the configured ``n_cells``. The model is then built and
        trained, and its description, metrics, stats and weights are written
        to files whose names are derived from ``model_desc_filename``.

        Args:
            evaluater: provides ``model_from_desc`` and ``train_model``.
            conf_eval: evaluation config; the keys ``'model_desc'``,
                ``'checkpoint'``, ``'resume'`` and ``'petridish'`` are read.
                When the ``'checkpoint'`` entry is not ``None`` its
                ``'filename'`` is overwritten in place.
            model_desc_builder: builder used for ``pre_build`` and ``build``.
            model_desc_filename: path of the template model description.
            common_state: state handed to ``common.init_from``.

        Returns:
            A ``ConvexHullPoint`` in ``JobStage.EVAL_TRAINED``. When resuming
            and the checkpoint already holds ``'metrics_stats'``, that stored
            value is returned without any training.
        """
        import os  # local import: only needed to derive the checkpoint name

        common.init_from(common_state)

        # region config vars
        conf_model_desc = conf_eval['model_desc']
        max_cells = conf_model_desc['n_cells']

        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']

        conf_petridish = conf_eval['petridish']
        cell_count_scale = conf_petridish['cell_count_scale']
        #endregion

        #register ops as we are in different process now
        model_desc_builder.pre_build(conf_model_desc)

        model_filename = utils.append_to_filename(model_desc_filename,
                                                  '_model', '.pt')
        full_desc_filename = utils.append_to_filename(model_desc_filename,
                                                      '_full', '.yaml')
        metrics_filename = utils.append_to_filename(model_desc_filename,
                                                    '_metrics', '.yaml')
        model_stats_filename = utils.append_to_filename(
            model_desc_filename, '_model_stats', '.yaml')

        # create checkpoint for this specific model desc by changing the config
        checkpoint = None
        if conf_checkpoint is not None:
            # Strip only the final extension. Splitting on the first '.' would
            # also cut at dots in directory names or in the file stem, which
            # makes different model descs share one checkpoint file.
            checkpoint_stem = os.path.splitext(model_filename)[0]
            conf_checkpoint['filename'] = checkpoint_stem + '_checkpoint.pth'
            checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

            if checkpoint is not None and resume:
                if 'metrics_stats' in checkpoint:
                    # return the output we had recorded in the checkpoint
                    return checkpoint['metrics_stats']

        # template model is what we used during the search
        template_model_desc = ModelDesc.load(model_desc_filename)

        # we first scale this model by number of cells, keeping reductions same as in search
        n_cells = math.ceil(
            len(template_model_desc.cell_descs()) * cell_count_scale)
        n_cells = min(n_cells, max_cells)
        n_reductions = template_model_desc.cell_type_count(CellType.Reduction)

        # the overrides go into a private copy of the config
        conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_cells'] = n_cells
        conf_model_desc['n_reductions'] = n_reductions

        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)
        # save desc for reference
        model_desc.save(full_desc_filename)

        model = evaluater.model_from_desc(model_desc)

        train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metrics_filename)

        # get metrics_stats
        model_stats = nas_utils.get_model_stats(model)
        # save metrics_stats
        with open(model_stats_filename, 'w') as f:
            yaml.dump(model_stats, f)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)
            # TODO: Causes logging error at random times. Commenting out as stop-gap fix.
            # logger.info({'model_save_path': model_filename})

        # node count is taken from the first cell of the built description
        n_nodes = len(model_desc.cell_descs()[0].nodes())
        hull_point = ConvexHullPoint(
            JobStage.EVAL_TRAINED,
            0,
            0,
            model_desc,
            (n_cells, n_reductions, n_nodes),
            metrics=train_metrics,
            model_stats=model_stats)

        # record the result so that a resumed run can return it directly
        if checkpoint:
            checkpoint.new()
            checkpoint['metrics_stats'] = hull_point
            checkpoint.commit()

        return hull_point
Example #7
0
 def model_desc_builder(self)->Optional[ModelDescBuilder]:
     """Return a new default ``ModelDescBuilder`` instance."""
     # default model desc builder puts nodes with no edges
     default_builder = ModelDescBuilder()
     return default_builder