def _get_estimator(self, portion):
    """Build an Estimator for one half of a featurizer/target split model.

    Args:
        portion: ``'featurizer'`` or ``'target'``. The target-only arguments
            (``target_model_fn``, ``target_dim``, ``label_encoder``) are
            forwarded to ``get_separate_model_fns`` only for ``'target'``;
            for ``'featurizer'`` they are passed as ``None``.

    Returns:
        A ``tf.estimator.Estimator`` rooted at ``self.estimator_dir`` whose
        ``model_fn`` is the function returned by ``get_separate_model_fns``.

    Side effects:
        If ``self.predict_hooks`` does not exist yet, it is created as a
        ``PredictHook`` wrapping one featurizer and one target
        ``InitializeHook``. If it already exists and ``portion`` is
        ``'featurizer'``, every hook obtained by iterating it gets
        ``need_to_refresh = True``.

    Raises:
        AssertionError: if ``portion`` is not one of the two allowed values
            (the check is skipped when Python runs with ``-O``).
    """
    assert portion in [
        'featurizer', 'target'
    ], "Can only split model into featurizer and target."
    is_target = portion == 'target'

    run_config = self._get_estimator_config()
    split_fn = get_separate_model_fns(
        target_model_fn=self._target_model if is_target else None,
        predict_op=self._predict_op,
        predict_proba_op=self._predict_proba_op,
        build_target_model=self.input_pipeline.target_dim is not None,
        encoder=self.input_pipeline.text_encoder,
        target_dim=self.input_pipeline.target_dim if is_target else None,
        label_encoder=(
            self.input_pipeline.label_encoder if is_target else None
        ),
        saver=self.saver,
        portion=portion,
        # build_attn is False exactly when the pipeline is a ComparisonPipeline.
        build_attn=not isinstance(self.input_pipeline, ComparisonPipeline),
    )
    estimator = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=split_fn,
        config=run_config,
        params=self.config,
    )

    if not hasattr(self, 'predict_hooks'):
        # First time through: create the shared hooks, one per model portion.
        featurizer_init = InitializeHook(self.saver, model_portion='featurizer')
        target_init = InitializeHook(self.saver, model_portion='target')
        self.predict_hooks = PredictHook(featurizer_init, target_init)
    elif portion == 'featurizer':
        # Hooks already exist and a featurizer estimator was just rebuilt:
        # flag each hook for refresh.
        for hook in self.predict_hooks:
            hook.need_to_refresh = True
    return estimator
def get_estimator(self, force_build_lm=False, build_explain=False):
    """Create the combined estimator together with its initialization hooks.

    Args:
        force_build_lm: When truthy, request the LM (``lm_type`` is taken from
            ``self.config.lm_type``) even if ``self.config.lm_loss_coef`` is
            not positive.
        build_explain: Forwarded unchanged to ``get_model_fn``.

    Returns:
        ``(est, hooks)``: a ``tf.estimator.Estimator`` rooted at
        ``self.estimator_dir`` and a one-element list holding
        ``InitializeHook(self.saver)``.
    """
    want_lm = force_build_lm or self.config.lm_loss_coef > 0.0
    run_config = self._get_estimator_config()

    pipeline = self.input_pipeline
    model_fn = get_model_fn(
        target_model_fn=self._target_model,
        pre_target_model_hook=self._pre_target_model_hook,
        predict_op=self._predict_op,
        predict_proba_op=self._predict_proba_op,
        build_target_model=pipeline.target_dim is not None,
        # No LM type is passed unless the LM is actually wanted.
        lm_type=self.config.lm_type if want_lm else None,
        encoder=pipeline.text_encoder,
        target_dim=pipeline.target_dim,
        label_encoder=pipeline.label_encoder,
        build_explain=build_explain,
        # At least one replica, even when self.resolved_gpus is empty.
        n_replicas=max(1, len(self.resolved_gpus)),
    )

    init_hooks = [InitializeHook(self.saver)]
    estimator = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=model_fn,
        config=run_config,
        params=self.config,
    )
    return estimator, init_hooks
def get_separate_estimators(self, force_build_lm=False):
    """Build separate featurizer and target estimators plus their init hooks.

    Both estimators share ``self.estimator_dir``, the run config returned by
    ``self._get_estimator_config()`` and ``params=self.config``. They differ
    only in the ``model_fn`` taken from the mapping returned by
    ``get_separate_model_fns``.

    Args:
        force_build_lm: When truthy, ``build_lm`` is requested even if
            ``self.config.lm_loss_coef`` is not positive.

    Returns:
        ``(featurizer_est, target_est, hooks)``, where ``hooks`` is a
        one-element list containing ``InitializeHook(self.saver)``.
    """
    # Both Estimators below need a run config. It must be bound here;
    # referencing an unbound ``config`` raises NameError.
    config = self._get_estimator_config()
    fns = get_separate_model_fns(
        target_model_fn=self._target_model,
        predict_op=self._predict_op,
        predict_proba_op=self._predict_proba_op,
        build_target_model=self.input_pipeline.target_dim is not None,
        build_lm=force_build_lm or self.config.lm_loss_coef > 0.0,
        encoder=self.input_pipeline.text_encoder,
        target_dim=self.input_pipeline.target_dim,
        label_encoder=self.input_pipeline.label_encoder,
        saver=self.saver,
    )
    featurizer_est = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=fns['featurizer_model_fn'],
        config=config,
        params=self.config,
    )
    target_est = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=fns['target_model_fn'],
        config=config,
        params=self.config,
    )
    hooks = [InitializeHook(self.saver)]
    return featurizer_est, target_est, hooks
def get_estimator(
    self, force_build_lm=False, build_explain=False, context_dim=None
):
    """Create the combined estimator together with its initialization hooks.

    Args:
        force_build_lm: When truthy, ``build_lm`` is requested even if
            ``self.config.lm_loss_coef`` is not positive.
        build_explain: Forwarded unchanged to ``get_model_fn``.
        context_dim: Context dimension for ``get_model_fn``. Any falsy value
            (``None`` or ``0``) falls back to
            ``self.input_pipeline.config.context_dim``.

    Returns:
        ``(est, hooks)``: a ``tf.estimator.Estimator`` rooted at
        ``self.estimator_dir`` and a one-element list holding
        ``InitializeHook(self.saver)``.
    """
    run_config = self._get_estimator_config()

    pipeline = self.input_pipeline
    model_fn = get_model_fn(
        target_model_fn=self._target_model,
        predict_op=self._predict_op,
        predict_proba_op=self._predict_proba_op,
        build_target_model=pipeline.target_dim is not None,
        build_lm=force_build_lm or self.config.lm_loss_coef > 0.0,
        encoder=pipeline.text_encoder,
        target_dim=pipeline.target_dim,
        label_encoder=pipeline.label_encoder,
        saver=self.saver,
        build_explain=build_explain,
        context_dim=context_dim or pipeline.config.context_dim,
    )

    init_hooks = [InitializeHook(self.saver)]
    estimator = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=model_fn,
        config=run_config,
        params=self.config,
    )
    return estimator, init_hooks
def get_estimator(self, force_build_lm=False):
    """Create the estimator (with an explicit RunConfig) and its init hooks.

    Builds a TF session config from the device-placement and GPU-memory
    settings in ``self.config``. It then builds a ``tf.estimator.RunConfig``
    that uses the distribution strategy from ``self._distribute_strategy``
    for training, and wires both into a ``tf.estimator.Estimator``.

    Args:
        force_build_lm: When truthy, ``build_lm`` is requested even if
            ``self.config.lm_loss_coef`` is not positive.

    Returns:
        ``(est, hooks)``: the estimator rooted at ``self.estimator_dir`` and a
        one-element list holding ``InitializeHook(self.saver)``.
    """
    settings = self.config

    session_conf = tf.ConfigProto(
        allow_soft_placement=settings.soft_device_placement,
        log_device_placement=settings.log_device_placement,
    )
    session_conf.gpu_options.per_process_gpu_memory_fraction = (
        settings.per_process_gpu_memory_fraction
    )

    strategy = self._distribute_strategy(settings.visible_gpus)
    run_config = tf.estimator.RunConfig(
        tf_random_seed=settings.seed,
        # Summaries are written on the validation interval.
        save_summary_steps=settings.val_interval,
        # With both checkpoint intervals set to None, RunConfig performs no
        # automatic checkpointing.
        save_checkpoints_secs=None,
        save_checkpoints_steps=None,
        session_config=session_conf,
        log_step_count_steps=100,
        train_distribute=strategy,
        keep_checkpoint_max=1,
    )

    model_fn = get_model_fn(
        target_model_fn=self._target_model,
        predict_op=self._predict_op,
        predict_proba_op=self._predict_proba_op,
        build_target_model=self.input_pipeline.target_dim is not None,
        build_lm=force_build_lm or settings.lm_loss_coef > 0.0,
        encoder=self.input_pipeline.text_encoder,
        target_dim=self.input_pipeline.target_dim,
        label_encoder=self.input_pipeline.label_encoder,
        saver=self.saver,
    )

    init_hooks = [InitializeHook(self.saver)]
    estimator = tf.estimator.Estimator(
        model_dir=self.estimator_dir,
        model_fn=model_fn,
        config=run_config,
        params=settings,
    )
    return estimator, init_hooks