def get_model(self, stage: str) -> _Model:
    """Build the model for ``stage`` from the experiment config.

    The model is created by forwarding ``self._config["model_params"]`` as
    keyword arguments to ``MODELS.get_from_params``. It is then passed
    through ``_preprocess_model_for_stage`` and
    ``_postprocess_model_for_stage``, in that order.

    Args:
        stage: name of the stage the model is built for.

    Returns:
        Whatever ``_postprocess_model_for_stage`` returns.
    """
    raw_model = MODELS.get_from_params(**self._config["model_params"])
    prepared_model = self._preprocess_model_for_stage(stage, raw_model)
    return self._postprocess_model_for_stage(stage, prepared_model)
def _get_model(**params): key_value_flag = params.pop("_key_value", False) if key_value_flag: model = {} for key, params_ in params.items(): model[key] = ConfigExperiment._get_model(**params_) else: model = MODELS.get_from_params(**params) return model
def _get_model(**params): key_value_flag = params.pop("_key_value", False) if key_value_flag: model = {} for model_key, model_params in params.items(): model[model_key] = ConfigExperiment._get_model( # noqa: WPS437 **model_params) model = nn.ModuleDict(model) else: model = MODELS.get_from_params(**params) return model
def get_model(self, stage: str) -> _Model:
    """Build the model for ``stage`` from the experiment config.

    Reads ``self._config["model_params"]``. The optional ``fp16`` key is
    consumed here and not forwarded to the registry. When it is truthy,
    ``utils.assert_fp16_available()`` is called and the model is wrapped in
    ``Fp16Wrap``. All remaining keys are forwarded to
    ``MODELS.get_from_params``. The result is then passed through
    ``_preprocess_model_for_stage`` and ``_postprocess_model_for_stage``.

    The stored config is left unmodified, so repeated calls (e.g. one per
    stage) all see the same ``fp16`` setting.

    Args:
        stage: name of the stage the model is built for.

    Returns:
        Whatever ``_postprocess_model_for_stage`` returns.
    """
    # Pop from a shallow copy. Popping from the config dict itself would
    # remove ``fp16`` permanently, so a later call (e.g. for the next stage)
    # would silently build a model without the fp16 wrapper.
    model_params = dict(self._config["model_params"])
    fp16 = model_params.pop("fp16", False)
    model = MODELS.get_from_params(**model_params)
    if fp16:
        utils.assert_fp16_available()
        model = Fp16Wrap(model)
    model = self._preprocess_model_for_stage(stage, model)
    model = self._postprocess_model_for_stage(stage, model)
    return model
def _get_model(**params): key_value_flag = params.pop("_key_value", False) # todo registry init dc_init_flag = params.pop("_dcgan_initialize", False) if key_value_flag: model = {} for key, params_ in params.items(): model[key] = Experiment._get_model(**params_) model = nn.ModuleDict(model) else: model = MODELS.get_from_params(**params) if dc_init_flag: model.apply(dcgan_weights_init) return model