Exemplo n.º 1
0
def create_model(conf_eval: Config) -> nn.Module:
    """Create the model to evaluate from a factory function or a model desc.

    If ``final_model_factory`` is set in the config it names a function,
    optionally qualified with its module (``package.module.function``).
    That function is called without arguments and its result is the model.
    When the name is not qualified, the module name is obtained from
    ``_default_module_name``; an empty module name means the function is
    looked up in this module.

    If no factory is configured, the model is built by
    ``nas_utils.model_from_conf`` using the description loaded from
    ``final_desc_filename`` as the template.

    Args:
        conf_eval: evaluation config. Keys read: ``loader.dataset.name``,
            ``final_desc_filename``, ``final_model_factory``,
            ``full_desc_filename`` and ``model_desc``.

    Returns:
        The created model.
    """
    # region conf vars
    # all keys are read up front so a missing key fails before any work is done
    dataset = conf_eval['loader']['dataset']['name']
    template_path = conf_eval['final_desc_filename']
    factory_spec = conf_eval['final_model_factory']
    full_desc_path = conf_eval['full_desc_filename']
    model_desc_conf = conf_eval['model_desc']
    # endregion

    if not factory_spec:
        # no factory: the saved model desc file is the template for the model
        template = ModelDesc.load(template_path)
        model = nas_utils.model_from_conf(full_desc_path,
                                          model_desc_conf,
                                          affine=True,
                                          droppath=True,
                                          template_model_desc=template)

        logger.info({
            'model_factory': False,
            'cells_len': len(model.desc.cell_descs()),
            'init_node_ch': model_desc_conf['init_node_ch'],
            'n_cells': model_desc_conf['n_cells'],
            'n_reductions': model_desc_conf['n_reductions'],
            'n_nodes': model_desc_conf['n_nodes']
        })
        return model

    # split "pkg.module.func" at the last dot; `dot` is empty when the
    # factory is given as a bare function name
    module_name, dot, function_name = factory_spec.rpartition('.')
    if not dot:
        module_name = _default_module_name(dataset, function_name)

    if module_name:
        module = importlib.import_module(module_name)
    else:
        # empty module name: resolve the function in the current module
        module = sys.modules[__name__]

    factory = getattr(module, function_name)
    model = factory()

    logger.info({
        'model_factory': True,
        'module_name': module_name,
        'function_name': function_name,
        'params': ml_utils.param_size(model)
    })
    return model
Exemplo n.º 2
0
def main():
    """Load a model description file and pass it to ``draw_model_desc``.

    The file is chosen with ``-f/--model-desc-file`` (default
    ``models/final_model_desc5.yaml``). The output path handed to
    ``draw_model_desc`` is the resolved input path with its last suffix
    removed.
    """
    arg_parser = argparse.ArgumentParser(
        description='Visualize model description')
    arg_parser.add_argument('-f', '--model-desc-file',
                            type=str,
                            default='models/final_model_desc5.yaml',
                            help='Model desc file')
    # unrecognized command line arguments are accepted and ignored
    parsed, _ = arg_parser.parse_known_args()

    desc_path = utils.full_path(parsed.model_desc_file)
    desc = ModelDesc.load(desc_path)

    # same location and name as the input, minus the file extension
    target = pathlib.Path(desc_path).with_suffix('')

    draw_model_desc(desc, str(target))
Exemplo n.º 3
0
    def create_model(self,
                     conf_eval: Config,
                     model_desc_builder: ModelDescBuilder,
                     final_desc_filename=None,
                     full_desc_filename=None) -> nn.Module:
        """Build the evaluation model from a saved model description.

        The description stored in ``final_desc_filename`` is used as the
        template for ``model_desc_builder.build``. The resulting description
        is saved to ``full_desc_filename`` and turned into a model with
        ``self.model_from_desc``.

        Args:
            conf_eval: evaluation config; ``model_desc`` is always read,
                the two filenames only when not passed explicitly.
            model_desc_builder: builder for the model description; must
                not be None.
            final_desc_filename: template description file. When falsy,
                BOTH filenames are taken from ``conf_eval``.
            full_desc_filename: file the built description is saved to.
                Ignored when ``final_desc_filename`` is falsy.

        Returns:
            The model created from the built description.
        """
        assert model_desc_builder is not None, 'Default evaluater requires model_desc_builder'

        # region conf vars
        # the two filenames are treated as a pair: an explicitly passed
        # final_desc_filename disables the config lookup for both
        if not final_desc_filename:
            final_desc_filename, full_desc_filename = (
                conf_eval['final_desc_filename'],
                conf_eval['full_desc_filename'])
        model_desc_conf = conf_eval['model_desc']
        # endregion

        # the saved desc file serves as the template for the new desc
        template = ModelDesc.load(final_desc_filename)
        built_desc = model_desc_builder.build(model_desc_conf,
                                              template=template)

        # keep the built desc on disk for reference
        built_desc.save(full_desc_filename)

        model = self.model_from_desc(built_desc)

        summary_record = {
            'model_factory': False,
            'cells_len': len(model.desc.cell_descs()),
            'init_node_ch': model_desc_conf['model_stems']['init_node_ch'],
            'n_cells': model_desc_conf['n_cells'],
            'n_reductions': model_desc_conf['n_reductions'],
            'n_nodes': model_desc_conf['cell']['n_nodes']
        }
        logger.info(summary_record)

        return model
Exemplo n.º 4
0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from archai.nas.model_desc import ModelDesc
from archai.common.common import common_init
from archai.nas.model import Model
from archai.algos.petridish.petridish_model_desc_builder import PetridishModelBuilder

from archai.common.model_summary import summary

# Stand-alone script: builds a model description from a saved template with
# the Petridish builder, instantiates the model and prints its summary.
conf = common_init(config_filepath='confs/petridish_cifar.yaml',
                   param_args=['--common.experiment_name', 'petridish_run2_seed42_eval'])

conf_eval = conf['nas']['eval']
conf_model_desc = conf_eval['model_desc']

# override the configured number of cells for this run
conf_model_desc['n_cells'] = 14
template_model_desc = ModelDesc.load('$expdir/final_model_desc.yaml')

model_builder = PetridishModelBuilder()

model_desc = model_builder.build(conf_model_desc, template=template_model_desc)

# NOTE(review): this second builder is never used afterwards. It is kept in
# case the constructor has side effects the model relies on - confirm and
# remove if it does not.
mb = PetridishModelBuilder()
model = Model(model_desc, droppath=False, affine=False)

# presumably the input size is (batch, channels, height, width) - TODO confirm
summary(model, [64, 3, 32, 32])


# `exit()` is injected by the `site` module for interactive sessions and is
# not guaranteed to exist (e.g. under `python -S`); raising SystemExit ends
# the script with the same exit status without relying on it.
raise SystemExit(0)
Exemplo n.º 5
0
    def _train_dist(evaluater: Evaluater, conf_eval: Config,
                    model_desc_builder: ModelDescBuilder,
                    model_desc_filename: str, common_state) -> ConvexHullPoint:
        """Train the model described by ``model_desc_filename``.

        The description in ``model_desc_filename`` is used as a template. The
        number of cells is scaled by ``petridish.cell_count_scale`` (capped
        at the configured ``model_desc.n_cells``) while the number of
        reduction cells is kept as in the template. The built description,
        the training metrics, the model stats and the model itself are
        written to files whose names are derived from ``model_desc_filename``.

        Args:
            evaluater: provides ``model_from_desc`` and ``train_model``.
            conf_eval: evaluation config; keys read: ``model_desc``,
                ``checkpoint``, ``resume`` and ``petridish``.
            model_desc_builder: builds the model description to train.
            model_desc_filename: file holding the template description.
            common_state: state passed to ``common.init_from``.

        Returns:
            The hull point for the trained model. When resuming from a
            checkpoint that already contains ``metrics_stats``, that stored
            value is returned without training again.
        """
        # kept local so the block does not rely on a module-level import
        import os.path

        common.init_from(common_state)

        # region config vars
        conf_model_desc = conf_eval['model_desc']
        max_cells = conf_model_desc['n_cells']

        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']

        conf_petridish = conf_eval['petridish']
        cell_count_scale = conf_petridish['cell_count_scale']
        #endregion

        #register ops as we are in different process now
        model_desc_builder.pre_build(conf_model_desc)

        model_filename = utils.append_to_filename(model_desc_filename,
                                                  '_model', '.pt')
        full_desc_filename = utils.append_to_filename(model_desc_filename,
                                                      '_full', '.yaml')
        metrics_filename = utils.append_to_filename(model_desc_filename,
                                                    '_metrics', '.yaml')
        model_stats_filename = utils.append_to_filename(
            model_desc_filename, '_model_stats', '.yaml')

        # create checkpoint for this specific model desc by changing the config
        checkpoint = None
        if conf_checkpoint is not None:
            # strip only the extension: splitting at the first '.' would cut
            # the path short whenever a directory name (or './') contains a dot
            checkpoint_stem = os.path.splitext(model_filename)[0]
            conf_checkpoint['filename'] = checkpoint_stem + '_checkpoint.pth'
            checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

            if checkpoint is not None and resume:
                if 'metrics_stats' in checkpoint:
                    # return the output we had recorded in the checkpoint
                    convex_hull_point = checkpoint['metrics_stats']
                    return convex_hull_point

        # template model is what we used during the search
        template_model_desc = ModelDesc.load(model_desc_filename)

        # we first scale this model by number of cells, keeping reductions same as in search
        n_cells = math.ceil(
            len(template_model_desc.cell_descs()) * cell_count_scale)
        n_cells = min(n_cells, max_cells)
        n_reductions = template_model_desc.cell_type_count(CellType.Reduction)

        # work on a copy so the caller's config is not modified
        conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_cells'] = n_cells
        conf_model_desc['n_reductions'] = n_reductions

        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)
        # save desc for reference
        model_desc.save(full_desc_filename)

        model = evaluater.model_from_desc(model_desc)

        train_metrics = evaluater.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metrics_filename)

        # get metrics_stats
        model_stats = nas_utils.get_model_stats(model)
        # save metrics_stats
        with open(model_stats_filename, 'w') as f:
            yaml.dump(model_stats, f)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)
            # TODO: Causes logging error at random times. Commenting out as stop-gap fix.
            # logger.info({'model_save_path': model_filename})

        hull_point = ConvexHullPoint(
            JobStage.EVAL_TRAINED,
            0,
            0,
            model_desc,
            (n_cells, n_reductions, len(model_desc.cell_descs()[0].nodes())),
            metrics=train_metrics,
            model_stats=model_stats)

        # explicit None test, consistent with the checks above: a checkpoint
        # object that evaluates as falsy (e.g. an empty mapping) must still
        # record the result so a later resume can return it
        if checkpoint is not None:
            checkpoint.new()
            checkpoint['metrics_stats'] = hull_point
            checkpoint.commit()

        return hull_point
Exemplo n.º 6
0
from archai.nas.model_desc import ModelDesc
from archai.common.common import common_init
from archai.nas.model import Model
from archai.petridish.petridish_micro_builder import PetridishMicroBuilder
from archai.nas.nas_utils import create_macro_desc

from archai.common.model_summary import summary

# Stand-alone script: creates a macro model description from a saved
# template, instantiates the model and prints its summary.
conf = common_init(
    config_filepath='confs/petridish_cifar.yaml',
    param_args=['--common.experiment_name', 'petridish_run2_seed42_eval'])

conf_eval = conf['nas']['eval']
conf_model_desc = conf_eval['model_desc']

# override the configured number of cells for this run
conf_model_desc['n_cells'] = 14
template_model_desc = ModelDesc.load('final_model_desc.yaml')
model_desc = create_macro_desc(conf_model_desc, True, template_model_desc)

# ops are registered before the model is instantiated
mb = PetridishMicroBuilder()
mb.register_ops()
model = Model(model_desc, droppath=False, affine=False)
# GPU placement was left disabled in the original script:
#model.cuda()
# presumably the input size is (batch, channels, height, width) - TODO confirm
summary(model, [64, 3, 32, 32])

# `exit()` is injected by the `site` module for interactive sessions and is
# not guaranteed to exist (e.g. under `python -S`); raising SystemExit ends
# the script with the same exit status without relying on it.
raise SystemExit(0)