Beispiel #1
0
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:
        """Finalize every op and cell of ``model`` and return the resulting description.

        Args:
            model: The model to finalize.
            to_cpu: When true, the model is moved to CPU before finalizing.
                The model must be on CPU at finalize time either way; this
                is asserted below, so callers passing ``False`` must have
                moved the model themselves.
            restore_device: When true, the model is moved back to the device
                it was on when this method was called, after all ops have
                been finalized.

        Returns:
            A ``ModelDesc`` assembled from the model's original
            ``conf_model_desc`` and ``aux_tower_descs`` plus the finalized
            stems, pool op, cells and logits op.

        Raises:
            AssertionError: If the model is not on CPU when finalizing starts.
        """
        # move model to CPU before finalize because each op will serialize
        # its parameters and we don't want copy of these parameters hanging on GPU
        original = model.device_type()
        if to_cpu:
            model.cpu()

        # finalize will create copy of state and this can overflow GPU RAM
        assert model.device_type() == 'cpu'

        # Finalize *everything* while the model is still on CPU. Doing the
        # stem/pool/logits finalize after restoring the device would serialize
        # their parameters on the original (possibly GPU) device.
        cell_descs = self.finalize_cells(model)
        # finalize() returns a sequence; only its first element is kept here
        model_stems = [op.finalize()[0] for op in model.model_stems]
        pool_op = model.pool_op.finalize()[0]
        logits_op = model.logits_op.finalize()[0]

        if restore_device:
            model.to(original, non_blocking=True)

        return ModelDesc(
            conf_model_desc=model.desc.conf_model_desc,
            model_stems=model_stems,
            pool_op=pool_op,
            cell_descs=cell_descs,
            aux_tower_descs=model.desc.aux_tower_descs,
            logits_op=logits_op)
Beispiel #2
0
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:
        """Collect activation covariances on CPU, then finalize via the parent class.

        Steps:
          1. Wrap every cell of ``model`` in a ``Divnas_Cell`` (kept in
             ``self._divnas_cells``, keyed by ``id(cell)``).
          2. Ask each wrapper to collect activations for ``DivOp`` edges using
             the configured ``sigma``.
          3. Run one pass over the training loader on CPU, in eval mode and
             without gradients, updating the node covariances after each batch.
          4. Delegate the actual finalization to ``super().finalize_model``.

        Args:
            model: The model to finalize. It is moved to CPU and put in eval
                mode by this method; the eval mode is not reverted here.
            to_cpu: Forwarded unchanged to the parent implementation.
            restore_device: When true, the model is moved back to the device
                it was on when this method was called, once finalization is done.

        Returns:
            The ``ModelDesc`` produced by the parent implementation.
        """
        # Remember where the model lives *before* we force it onto CPU below;
        # otherwise the parent would only ever see 'cpu' as the device to restore.
        original = model.device_type()

        logger.pushd('finalize')
        try:
            # get config and train data loader
            # TODO: confirm this is correct in case you get silent bugs
            conf = get_conf()
            conf_loader = conf['nas']['search']['loader']
            # only the training loader is needed for the activation pass
            train_dl, _, _ = get_data(conf_loader)

            # wrap all cells in the model
            self._divnas_cells: Dict[int, Divnas_Cell] = {
                id(cell): Divnas_Cell(cell) for cell in model.cells
            }

            # go through all edges in the DAG and if they are of divop
            # type then set them to collect activations
            sigma = conf['nas']['search']['divnas']['sigma']
            for dcell in self._divnas_cells.values():
                dcell.collect_activations(DivOp, sigma)

            # now we need to run one evaluation epoch to collect activations
            # we do it on cpu otherwise we might run into memory issues
            # later we can redo the whole logic in pytorch itself
            # at the end of this each node in a cell will have the covariance
            # matrix of all incoming edges' ops
            model = model.cpu()
            model.eval()
            with torch.no_grad():
                for x, _ in train_dl:
                    # forward pass is run only for its side effect of
                    # recording activations; the output is discarded
                    model(x)
                    # now you can go through and update the
                    # node covariances in every cell
                    for dcell in self._divnas_cells.values():
                        dcell.update_covs()
        finally:
            # keep the logger scope balanced even if anything above raises
            logger.popd()

        # The model is already on CPU, so the parent cannot know the original
        # device; suppress its restore step and do the restore here instead.
        model_desc = super().finalize_model(model, to_cpu, restore_device=False)

        if restore_device:
            model.to(original, non_blocking=True)

        return model_desc