def __init__(self, model_params=None, json_path=None):
    """Resolve the model configuration and build a ``tf.estimator.Estimator``.

    The configuration is resolved in two steps:

    1. ``get_from_dicts(model_params, default_parameters)``. This presumably
       overlays ``model_params`` on the module-level defaults; confirm against
       the helper.
    2. ``get_from_json(json_path, ...)``. This takes values from the JSON file
       over the result of step 1.

    Args:
        model_params: Optional dict of parameter overrides.
        json_path: Optional path to a JSON file of further overrides.

    After ``_sanity_checks()`` runs, the resolved dict is logged. The dict's
    ``output.save_model_dir`` entry becomes the Estimator's ``model_dir``. Its
    run config comes from ``get_tf_config``.
    """
    super().__init__()
    layered = get_from_dicts(model_params, default_parameters)
    self.model_params = get_from_json(json_path, layered)
    self._sanity_checks()
    logging.info("Model parameters : %s", self.model_params)

    # Bind only after validation, in case the check replaces the dict.
    cfg = self.model_params
    self.input_type = cfg['input']['type']
    self.model_dir = cfg['output']['save_model_dir']
    self.config = get_tf_config(cfg)
    self.model = tf.estimator.Estimator(
        model_fn=self.model_fn,
        model_dir=self.model_dir,
        params=cfg,
        config=self.config,
    )
def __init__(self, params=None, json_path=None):
    """Resolve parameters and set up a DistributedDataParallel training wrapper.

    Args:
        params: Optional dict of parameter overrides, combined with
            ``default_parameters`` via ``get_from_dicts``.
        json_path: Optional JSON file whose values are applied last via
            ``get_from_json``.

    Side effects, in order:

    - Validates and logs the parameters.
    - Calls ``init_process_group()``. This presumably initialises
      torch.distributed; confirm.
    - Builds the model, moves it to ``get_device(...)``, and wraps it in DDP.
    - Creates the optimizer and loss from the registries, using the configured
      names and kwargs.
    - Looks up the metric function. Unlike the loss, it is not instantiated.
    """
    super().__init__()
    self.params = get_from_json(json_path, get_from_dicts(params, default_parameters))
    self._sanity_check()
    logging.info("Model parameters : %s", self.params)
    self.input_type = self.params['input']['type']
    self.init_process_group()

    cfg = self.params
    # Build the network before querying the device, matching the original call order.
    network = self.model_fn(cfg)
    network = network.to(get_device(cfg))
    self.model = DistributedDataParallel(network)

    opt_cfg = cfg['optimizer']
    optimizer_factory = get_optimizer(key=opt_cfg['name'])
    self.optimizer = optimizer_factory(params=self.model.parameters(), **opt_cfg['kwargs'])

    loss_cfg = cfg['loss']
    loss_factory = get_loss_fn(key=loss_cfg['name'])
    self.loss = loss_factory(**loss_cfg['kwargs'])

    self.metric = get_metric_fn(key=cfg['output']['metric'])
def __init__(self, params=None, json_path=None):
    """Resolve parameters, ensure the save directory exists, and build the DDP wrapper.

    Args:
        params: Optional dict of parameter overrides, combined with
            ``default_parameters`` via ``get_from_dicts``.
        json_path: Optional JSON file whose values are applied last via
            ``get_from_json``.

    After validation, ``output.save_model_dir`` is expanded (``~``), resolved
    to an absolute path, and created with any missing parents. It is not an
    error if it already exists.

    Steps after logging, in order:

    - Calls ``init_process_group()``. This presumably initialises
      torch.distributed; confirm.
    - Wraps the device-placed model in ``DistributedDataParallel``.
    - Instantiates the optimizer and loss from the registries.
    - Looks up the metric function, which is not instantiated.
    """
    super().__init__()
    self.params = get_from_json(json_path, get_from_dicts(params, default_parameters))
    self._sanity_check()

    save_dir = Path(self.params["output"]["save_model_dir"]).expanduser().resolve()
    save_dir.mkdir(parents=True, exist_ok=True)

    logging.info("Model parameters : %s", self.params)
    self.input_type = self.params["input"]["type"]
    self.init_process_group()

    cfg = self.params
    # Build the network before querying the device, matching the original call order.
    network = self.model_fn(cfg)
    network = network.to(get_device(cfg))
    self.model = DistributedDataParallel(network)

    opt_cfg = cfg["optimizer"]
    optimizer_factory = get_optimizer(key=opt_cfg["name"])
    self.optimizer = optimizer_factory(params=self.model.parameters(), **opt_cfg["kwargs"])

    loss_cfg = cfg["loss"]
    loss_factory = get_loss_fn(key=loss_cfg["name"])
    self.loss = loss_factory(**loss_cfg["kwargs"])

    self.metric = get_metric_fn(key=cfg["output"]["metric"])
def test_merge_json(output_json_filepath):
    """JSON values should override defaults; absent keys keep their defaults.

    ``output_json_filepath`` is a fixture defined elsewhere. Presumably it
    points to a JSON file that sets ``learning_rate`` to 0.05 and leaves
    ``embedding_size`` unset; confirm against the fixture.
    """
    defaults = {"learning_rate": 0.08, "embedding_size": 256}
    merged = get_from_json(output_json_filepath, defaults)

    # Overridden by the file.
    assert merged['learning_rate'] == 0.05
    # Not present in the file, so the default survives.
    assert merged['embedding_size'] == 256