def __prepare_model(self, train_mode=False):
    """Prepare the estimator and hyperparameters used for decoding.

    Builds the model hparams from the registered ``self.params.hparams_set``,
    applies the optional ``self.params.hparams`` override string, attaches
    ``self.problem`` and its problem hparams, and creates the estimator.

    Args:
      train_mode: If True, also build the experiment via
        ``g2p_trainer_utils.create_experiment_fn`` and store it on
        ``self.exp``.

    Returns:
      A tuple ``(estimator, decode_hp, hparams)``: the estimator created by
      ``trainer_lib.create_estimator``, the decoding hparams parsed from
      ``self.params.decode_hparams``, and the model hparams.
    """
    hparams = registry.hparams(self.params.hparams_set)
    # Apply user overrides *before* deriving the problem hparams, so the
    # problem is configured from the final model hparams (the same order
    # trainer_lib.create_hparams uses). Previously the problem hparams were
    # computed from the un-overridden values.
    if self.params.hparams:
        tf.logging.info("Overriding hparams in %s with %s",
                        self.params.hparams_set, self.params.hparams)
        hparams = hparams.parse(self.params.hparams)
    hparams.problem = self.problem
    hparams.problem_hparams = self.problem.get_hparams(hparams)

    trainer_run_config = g2p_trainer_utils.create_run_config(
        hparams, self.params)
    if train_mode:
        exp_fn = g2p_trainer_utils.create_experiment_fn(
            self.params, self.problem)
        self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)
    return estimator, decode_hp, hparams
def __prepare_model(self):
    """Prepare the experiment, estimator and hyperparameters for decoding.

    Builds the model hparams with ``trainer_lib.create_hparams`` from
    ``self.params.hparams_set`` and the ``self.params.hparams`` override
    string. It then builds the experiment (stored on ``self.exp``), sets up
    the decoding hparams with the shard settings from ``self.params``, and
    creates the estimator.

    Returns:
      A tuple ``(estimator, decode_hp, hparams)``: the estimator created by
      ``trainer_lib.create_estimator``, the decoding hparams, and the model
      hparams.
    """
    hparams = trainer_lib.create_hparams(
        hparams_set=self.params.hparams_set,
        hparams_overrides_str=self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(
        hparams, self.params)
    exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
    self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    # Some tensor2tensor versions already define `shards` and `shard_id` in
    # decode_hparams(). HParams.add_hparam raises ValueError for a name that
    # already exists, so update it in place when present and add it otherwise.
    for name, value in (("shards", self.params.decode_shards),
                        ("shard_id", self.params.worker_id)):
        if name in decode_hp.values():
            decode_hp.set_hparam(name, value)
        else:
            decode_hp.add_hparam(name, value)

    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)
    return estimator, decode_hp, hparams