Code example #1
0
    def search_model_desc(self, conf_search:Config, model_desc:ModelDesc,
                          trainer_class:TArchTrainer, finalizers:Finalizers)\
                              ->Tuple[ModelDesc, Optional[Metrics]]:
        """Run architecture search starting from the given model description.

        Args:
            conf_search: search config; 'trainer' and 'loader' sections are used.
            model_desc: starting description to build the search model from.
            trainer_class: architecture trainer; if None (e.g. for algos like
                random search that don't train during search) the input desc
                is returned unchanged with no metrics.
            finalizers: finalizers used to extract the found description.

        Returns:
            Tuple of (finalized description, search metrics) or
            (model_desc, None) when trainer_class is None.
        """
        # if trainer is not specified for algos like random search we return same desc
        if trainer_class is None:
            return model_desc, None

        logger.pushd('arch_search')
        try:
            conf_trainer = conf_search['trainer']
            conf_loader = conf_search['loader']

            # search model is built with droppath and affine disabled
            model = Model(model_desc, droppath=False, affine=False)

            # get data
            data_loaders = self.get_data(conf_loader)

            # search arch
            arch_trainer = trainer_class(conf_trainer, model, checkpoint=None)
            search_metrics = arch_trainer.fit(data_loaders)

            # finalize: extract the discrete description from the trained model
            found_desc = self.finalize_model(model, finalizers)
        finally:
            # previously popd() was skipped if fit()/finalize raised,
            # leaving the logger scope unbalanced
            logger.popd()

        return found_desc, search_metrics
Code example #2
0
    def train_model_desc(self, model_desc:ModelDesc, conf_train:Config)\
            ->Optional[ModelMetrics]:
        """Train the model for the given description.

        Args:
            model_desc: description to instantiate and train.
            conf_train: config with 'trainer' and 'loader' sections.

        Returns:
            ModelMetrics for the trained model, or None when the configured
            epoch count is zero or negative (nothing to train).
        """
        # region conf vars
        conf_trainer = conf_train['trainer']
        conf_loader = conf_train['loader']
        trainer_title = conf_trainer['title']
        epochs = conf_trainer['epochs']
        drop_path_prob = conf_trainer['drop_path_prob']
        # endregion

        # if epochs ==0 then nothing to train, so save time
        if epochs <= 0:
            return None

        logger.pushd(trainer_title)
        try:
            # enable droppath only when a non-zero probability is configured
            model = Model(model_desc, droppath=drop_path_prob > 0.0, affine=True)

            # get data
            data_loaders = self.get_data(conf_loader)

            trainer = Trainer(conf_trainer, model, checkpoint=None)
            train_metrics = trainer.fit(data_loaders)
        finally:
            # previously popd() was skipped if fit() raised, leaving the
            # logger scope unbalanced
            logger.popd()

        return ModelMetrics(model, train_metrics)
Code example #3
0
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:
        """Collect per-node activation covariances, then finalize via base class.

        Wraps every cell in a Divnas_Cell, hooks DivOp edges to record
        activations, and runs one no-grad pass over the search train set on
        CPU so each node accumulates the covariance matrix of its incoming
        edges' ops. Finally delegates to ``super().finalize_model``.
        """
        logger.pushd('finalize')
        try:
            # get config and train data loader
            # TODO: confirm this is correct in case you get silent bugs
            conf = get_conf()
            conf_loader = conf['nas']['search']['loader']
            # only the train loader is used below
            train_dl, _, _ = get_data(conf_loader)

            # wrap all cells in the model
            self._divnas_cells: Dict[int, Divnas_Cell] = {}
            for cell in model.cells:
                self._divnas_cells[id(cell)] = Divnas_Cell(cell)

            # go through all edges in the DAG and if they are of divop
            # type then set them to collect activations
            sigma = conf['nas']['search']['divnas']['sigma']
            for dcell in self._divnas_cells.values():
                dcell.collect_activations(DivOp, sigma)

            # now we need to run one evaluation epoch to collect activations
            # we do it on cpu otherwise we might run into memory issues
            # later we can redo the whole logic in pytorch itself
            # at the end of this each node in a cell will have the covariance
            # matrix of all incoming edges' ops
            model = model.cpu()
            model.eval()
            with torch.no_grad():
                # note: original code wrapped this in `for _ in range(1)`,
                # a dead single-iteration loop, and discarded the forward
                # result via `_, _ = model(x), None`
                for x, _ in train_dl:
                    model(x)
                    # now you can go through and update the
                    # node covariances in every cell
                    for dcell in self._divnas_cells.values():
                        dcell.update_covs()
        finally:
            # balance the pushd even if data loading or the forward pass raises
            logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)
Code example #4
0
def test_darts_zero_model():
    """Smoke test: a freshly built DARTS CIFAR model runs a forward pass."""
    conf = common_init(config_filepath='confs/darts_cifar.yaml')
    model_desc = conf['nas']['search']['model_desc']

    built_desc = MacroBuilder(model_desc, aux_tower=False).build()
    net = Model(built_desc, False, True)

    logits, aux_logits = net(torch.rand((1, 3, 32, 32)))
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == (1, 10)
    assert aux_logits is None
Code example #5
0
def test_petridish_zero_model():
    """Smoke test: a freshly built petridish CIFAR model runs a forward pass."""
    conf = common_init(config_filepath='confs/petridish_cifar.yaml')
    model_desc = conf['nas']['search']['model_desc']

    built_desc = ModelDescBuilder().build(model_desc)
    net = Model(built_desc, False, True)

    logits, aux_logits = net(torch.rand((1, 3, 32, 32)))
    assert isinstance(logits, torch.Tensor)
    assert logits.shape == (1, 10)
    assert aux_logits is None
Code example #6
0
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:
        """Build the final ModelDesc from a trained model.

        The model is moved to CPU first (when ``to_cpu``) because each op
        serializes its parameters during finalize and we don't want copies
        of those parameters hanging on the GPU; finalize also creates a copy
        of state which could overflow GPU RAM. When ``restore_device`` the
        model is moved back to its original device afterwards.
        """
        original_device = model.device_type()
        if to_cpu:
            model.cpu()

        # finalize requires the model to be on CPU (see docstring); callers
        # passing to_cpu=False must have moved it already
        assert model.device_type() == 'cpu'

        finalized_cells = self.finalize_cells(model)

        if restore_device:
            model.to(original_device, non_blocking=True)

        stems = [stem.finalize()[0] for stem in model.model_stems]
        return ModelDesc(conf_model_desc=model.desc.conf_model_desc,
                         model_stems=stems,
                         pool_op=model.pool_op.finalize()[0],
                         cell_descs=finalized_cells,
                         aux_tower_descs=model.desc.aux_tower_descs,
                         logits_op=model.logits_op.finalize()[0])
Code example #7
0
    def _create_seed_jobs(self, conf_search:Config, model_desc_builder:ModelDescBuilder)->list:
        """Launch distributed pre-training jobs for all seed models.

        For every macro combination that can satisfy the N R N R N R
        pattern, builds a seed model description, submits a remote training
        job and records model stats which are saved to a plot/tsv for
        visualization. Returns the list of ray job IDs.
        """
        conf_model_desc = conf_search['model_desc']
        conf_seed_train = conf_search['seed_train']

        job_ids = []          # ray job IDs
        model_stats_list = [] # seed model stats for visualization and debugging
        macro_combinations = list(self.get_combinations(conf_search))
        for reductions, cells, nodes in macro_combinations:
            # if N R N R N R cannot be satisfied, ignore combination
            if cells < reductions * 2 + 1:
                continue

            # create seed model description
            desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                         reductions, cells, nodes)

            point = ConvexHullPoint(JobStage.SEED, 0, 0, desc,
                                    (cells, reductions, nodes))

            # kick off remote pre-training of this seed model
            job_ids.append(SearcherPetridish.train_model_desc_dist.remote(
                self, conf_seed_train, point, common.get_state()))

            # instantiate a model so we can compute its stats
            seed_model = Model(desc, droppath=True, affine=True)
            model_stats_list.append(nas_utils.get_model_stats(seed_model))

        # save the model stats in a plot and tsv file so we can
        # visualize the spread on the x-axis
        expdir = common.get_expdir()
        assert expdir
        plot_seed_model_stats(model_stats_list, expdir)

        return job_ids
Code example #8
0
def _get_alphas(model:Model)->Iterator[nn.Parameter]:
    """Yield the architecture ('alphas') parameters owned by the model."""
    owned_params = model.all_owned()
    return owned_params.param_by_kind('alphas')
Code example #9
0
File: model_size.py  Project: wayne9qiu/archai
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Script: print a parameter/shape summary for the petridish eval model
# built from a previously saved final_model_desc.yaml.

from archai.nas.model_desc import ModelDesc
from archai.common.common import common_init
from archai.nas.model import Model
from archai.algos.petridish.petridish_model_desc_builder import PetridishModelBuilder

from archai.common.model_summary import summary

conf = common_init(config_filepath='confs/petridish_cifar.yaml',
                   param_args=['--common.experiment_name', 'petridish_run2_seed42_eval'])

conf_eval = conf['nas']['eval']
conf_model_desc = conf_eval['model_desc']

# override cell count for the eval model before building from the template
conf_model_desc['n_cells'] = 14
template_model_desc = ModelDesc.load('$expdir/final_model_desc.yaml')

model_builder = PetridishModelBuilder()
model_desc = model_builder.build(conf_model_desc, template=template_model_desc)

# NOTE: removed a second, unused PetridishModelBuilder instance ('mb') and a
# redundant exit(0) at the end of the script (a no-op at that point)
model = Model(model_desc, droppath=False, affine=False)

summary(model, [64, 3, 32, 32])
Code example #10
0
 def model_from_desc(self, model_desc) -> Model:
     """Instantiate a trainable Model (droppath and affine enabled) from desc."""
     return Model(model_desc, droppath=True, affine=True)