예제 #1
0
    def get_model(cls, model_desc=None, pretrained_model_file=None):
        """Build a model from its description, optionally loading pretrained weights.

        :param model_desc: the description of the network, handed to ``NetworkDesc``.
        :type model_desc: object accepted by ``NetworkDesc``, or None.
        :param pretrained_model_file: path of a pretrained weight file. It is only
            loaded when running on the torch or MindSpore backend.
        :type pretrained_model_file: str or None.
        :return: the created (and possibly weight-initialized) model.
        :rtype: model.
        :raises Exception: re-raises whatever error occurs while building the network,
            after logging it.
        """
        try:
            network = NetworkDesc(model_desc)
            model = network.to_model()
        except Exception as e:
            logging.error("Failed to get model, model_desc={}, msg={}".format(
                model_desc, str(e)))
            # Bare raise keeps the original exception and traceback intact.
            raise
        logging.info("Model was created.")
        # Both supported backends load weights the same way; on any other backend
        # the pretrained file is ignored.
        if pretrained_model_file and (zeus.is_torch_backend() or zeus.is_ms_backend()):
            model = cls._load_pretrained_model(model, pretrained_model_file)
        return model
예제 #2
0
 def new_model(self):
     """Create a fresh network from the search space, carrying over architecture weights.

     A new model is built from ``self.search_space`` and moved to the GPU with
     ``.cuda()``. Each of its architecture parameters is then overwritten in place
     with the value of the matching parameter of ``self.model``. Parameters are
     paired positionally, and ``zip`` stops at the shorter of the two sequences.

     :return: the newly built model.
     """
     description = NetworkDesc(self.search_space)
     fresh_model = description.to_model().cuda()
     paired_params = zip(fresh_model.arch_parameters(), self.model.arch_parameters())
     for target, source in paired_params:
         # Copy through detached views so the copy is not tracked by autograd.
         target.detach().copy_(source.detach())
     return fresh_model
예제 #3
0
    def _init_model(self):
        """Initialize the model architecture for full train step.

        :return: train model
        :rtype: class
        """
        logging.info('Initializing model')
        if self.cfg.model_desc:
            logging.debug("model_desc: {}".format(self.cfg.model_desc))
            _file = FileOps.join_path(self.worker_path, "model_desc_{}.json".format(self._worker_id))
            with open(_file, "w") as f:
                json.dump(self.cfg.model_desc, f)
            if self.cfg.distributed:
                hvd.join()
            model_desc = self.cfg.model_desc
            net_desc = NetworkDesc(model_desc)
            model = net_desc.to_model()
            return model
        else:
            return None
예제 #4
0
    def get_model(cls, model_desc=None, pretrained_model_file=None, exclude_weight_prefix=None):
        """Get model from model zoo.

        :param model_desc: dict-like description of the network. If it contains a
            ``'deformation'`` key, the remaining description is wrapped as
            ``{"type": <deformation>, "desc": <rest>, "weight_file": pretrained_model_file}``
            and built with ``is_deformation=True``. In that case the pretrained file is
            passed inside the description and is not loaded again afterwards. The
            caller's object is never modified.
        :type model_desc: dict or None.
        :param pretrained_model_file: path of a pretrained weight file.
        :type pretrained_model_file: str or None.
        :param exclude_weight_prefix: if truthy, assigned to
            ``model.exclude_weight_prefix`` before the pretrained weights are loaded.
        :return: model.
        :rtype: model.
        :raises ValueError: if no model could be obtained (e.g. ``model_desc`` is None).
        """
        import copy

        model = None
        if model_desc is not None:
            try:
                is_deformation = False
                if 'deformation' in model_desc:
                    # Shallow-copy so pop() does not strip 'deformation' from the
                    # caller's description (a second call would otherwise differ).
                    rest_desc = copy.copy(model_desc)
                    model_desc = {"type": rest_desc.pop('deformation'), 'desc': rest_desc,
                                  'weight_file': pretrained_model_file}
                    # The weight file now travels inside the description; skip the
                    # generic pretrained-weight loading below.
                    pretrained_model_file = None
                    is_deformation = True
                network = NetworkDesc(model_desc, is_deformation)
                model = network.to_model()
            except Exception as e:
                logging.error("Failed to get model, model_desc={}, msg={}".format(model_desc, str(e)))
                raise
        # Fail before handing None to to_module / weight loading.
        if model is None:
            raise ValueError("Failed to get model, model is None.")
        logging.info("Model was created.")
        if not isinstance(model, Module):
            model = cls.to_module(model)
        if pretrained_model_file is not None:
            if exclude_weight_prefix:
                model.exclude_weight_prefix = exclude_weight_prefix
            model = cls._load_pretrained_model(model, pretrained_model_file)
        # Conversion or weight loading may still yield None.
        if model is None:
            raise ValueError("Failed to get model, model is None.")
        return model