def _get_estimator(self, portion):
        """Build an Estimator for one half of a model split into two portions.

        Args:
            portion: Either ``'featurizer'`` or ``'target'``. The target model
                fn, target dim and label encoder are only forwarded to
                ``get_separate_model_fns`` when ``portion == 'target'``; for the
                featurizer they are passed as ``None``.

        Returns:
            A ``tf.estimator.Estimator`` using ``self.estimator_dir`` as its
            model dir and ``self.config`` as its params.

        Raises:
            ValueError: if ``portion`` is not ``'featurizer'`` or ``'target'``.

        Side effects:
            On the first call, creates ``self.predict_hooks`` (a ``PredictHook``
            wrapping one featurizer and one target ``InitializeHook``). On later
            calls with ``portion == 'featurizer'``, sets ``need_to_refresh = True``
            on every hook it yields.
        """
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, which would let an invalid portion through silently.
        if portion not in ('featurizer', 'target'):
            raise ValueError("Can only split model into featurizer and target.")
        is_target = portion == 'target'
        config = self._get_estimator_config()

        fn = get_separate_model_fns(
            target_model_fn=self._target_model if is_target else None,
            predict_op=self._predict_op,
            predict_proba_op=self._predict_proba_op,
            build_target_model=self.input_pipeline.target_dim is not None,
            encoder=self.input_pipeline.text_encoder,
            target_dim=self.input_pipeline.target_dim if is_target else None,
            label_encoder=(self.input_pipeline.label_encoder
                           if is_target else None),
            saver=self.saver,
            portion=portion,
            build_attn=not isinstance(self.input_pipeline, ComparisonPipeline))

        estimator = tf.estimator.Estimator(model_dir=self.estimator_dir,
                                           model_fn=fn,
                                           config=config,
                                           params=self.config)

        if not hasattr(self, 'predict_hooks'):
            # First call: create the paired featurizer/target init hooks.
            feat_hook = InitializeHook(self.saver, model_portion='featurizer')
            target_hook = InitializeHook(self.saver, model_portion='target')
            self.predict_hooks = PredictHook(feat_hook, target_hook)
        elif portion == 'featurizer':
            # Rebuilding the featurizer: mark existing hooks to refresh.
            for hook in self.predict_hooks:
                hook.need_to_refresh = True
        return estimator
Exemple #2
0
    def get_estimator(self, force_build_lm=False, build_explain=False):
        """Create the combined model Estimator together with its init hooks.

        Args:
            force_build_lm: When true, ``lm_type`` is always forwarded to
                ``get_model_fn``; otherwise it is forwarded only when
                ``self.config.lm_loss_coef`` is positive (``None`` is sent
                instead).
            build_explain: Forwarded unchanged to ``get_model_fn``.

        Returns:
            ``(estimator, hooks)``, where ``hooks`` is a one-element list holding
            ``InitializeHook(self.saver)``.
        """
        wants_lm = force_build_lm or self.config.lm_loss_coef > 0.0
        run_config = self._get_estimator_config()
        pipeline = self.input_pipeline
        # At least one replica, even if no GPUs were resolved.
        replica_count = max(1, len(self.resolved_gpus))

        model_fn = get_model_fn(
            target_model_fn=self._target_model,
            pre_target_model_hook=self._pre_target_model_hook,
            predict_op=self._predict_op,
            predict_proba_op=self._predict_proba_op,
            build_target_model=pipeline.target_dim is not None,
            lm_type=self.config.lm_type if wants_lm else None,
            encoder=pipeline.text_encoder,
            target_dim=pipeline.target_dim,
            label_encoder=pipeline.label_encoder,
            build_explain=build_explain,
            n_replicas=replica_count,
        )

        init_hooks = [InitializeHook(self.saver)]
        estimator = tf.estimator.Estimator(
            model_dir=self.estimator_dir,
            model_fn=model_fn,
            config=run_config,
            params=self.config,
        )
        return estimator, init_hooks
Exemple #3
0
    def get_separate_estimators(self, force_build_lm=False):
        """Build separate featurizer and target Estimators over one model dir.

        Args:
            force_build_lm: Passed through as ``build_lm=True`` to
                ``get_separate_model_fns``; when false, ``build_lm`` is true
                only if ``self.config.lm_loss_coef`` is positive.

        Returns:
            ``(featurizer_est, target_est, hooks)``. Both Estimators use
            ``self.estimator_dir`` and share one run config. ``hooks`` is a
            one-element list holding ``InitializeHook(self.saver)``.
        """
        # Both Estimators need a run config. It was previously referenced
        # without being assigned (NameError); build it the same way as the
        # other estimator builders.
        config = self._get_estimator_config()

        fns = get_separate_model_fns(
            target_model_fn=self._target_model,
            predict_op=self._predict_op,
            predict_proba_op=self._predict_proba_op,
            build_target_model=self.input_pipeline.target_dim is not None,
            build_lm=force_build_lm or self.config.lm_loss_coef > 0.0,
            encoder=self.input_pipeline.text_encoder,
            target_dim=self.input_pipeline.target_dim,
            label_encoder=self.input_pipeline.label_encoder,
            saver=self.saver
        )

        featurizer_est = tf.estimator.Estimator(
            model_dir=self.estimator_dir,
            model_fn=fns['featurizer_model_fn'],
            config=config,
            params=self.config
        )

        target_est = tf.estimator.Estimator(
            model_dir=self.estimator_dir,
            model_fn=fns['target_model_fn'],
            config=config,
            params=self.config
        )

        hooks = [InitializeHook(self.saver)]

        return featurizer_est, target_est, hooks
Exemple #4
0
    def get_estimator(
        self, force_build_lm=False, build_explain=False, context_dim=None
    ):
        """Create the model Estimator and the hooks that initialize it.

        Args:
            force_build_lm: Forces ``build_lm=True``; otherwise ``build_lm`` is
                true only when ``self.config.lm_loss_coef`` is positive.
            build_explain: Forwarded unchanged to ``get_model_fn``.
            context_dim: Used when truthy; any falsy value (``None``, ``0``)
                falls back to ``self.input_pipeline.config.context_dim``.

        Returns:
            ``(estimator, hooks)``, where ``hooks`` is a one-element list holding
            ``InitializeHook(self.saver)``.
        """
        run_config = self._get_estimator_config()
        pipeline = self.input_pipeline
        wants_lm = force_build_lm or self.config.lm_loss_coef > 0.0
        effective_context_dim = context_dim or pipeline.config.context_dim

        model_fn = get_model_fn(
            target_model_fn=self._target_model,
            predict_op=self._predict_op,
            predict_proba_op=self._predict_proba_op,
            build_target_model=pipeline.target_dim is not None,
            build_lm=wants_lm,
            encoder=pipeline.text_encoder,
            target_dim=pipeline.target_dim,
            label_encoder=pipeline.label_encoder,
            saver=self.saver,
            build_explain=build_explain,
            context_dim=effective_context_dim,
        )
        init_hooks = [InitializeHook(self.saver)]
        estimator = tf.estimator.Estimator(
            model_dir=self.estimator_dir,
            model_fn=model_fn,
            config=run_config,
            params=self.config,
        )
        return estimator, init_hooks
Exemple #5
0
    def get_estimator(self, force_build_lm=False):
        """Create the model Estimator with an inline-built RunConfig.

        The session config takes soft/logged device placement and the
        per-process GPU memory fraction from ``self.config``. Training is
        distributed with ``self._distribute_strategy(self.config.visible_gpus)``.

        Args:
            force_build_lm: Forces ``build_lm=True``; otherwise ``build_lm`` is
                true only when ``self.config.lm_loss_coef`` is positive.

        Returns:
            ``(estimator, hooks)``, where ``hooks`` is a one-element list holding
            ``InitializeHook(self.saver)``.
        """
        cfg = self.config
        session_conf = tf.ConfigProto(
            allow_soft_placement=cfg.soft_device_placement,
            log_device_placement=cfg.log_device_placement,
        )
        session_conf.gpu_options.per_process_gpu_memory_fraction = (
            cfg.per_process_gpu_memory_fraction
        )
        strategy = self._distribute_strategy(cfg.visible_gpus)

        run_config = tf.estimator.RunConfig(
            tf_random_seed=cfg.seed,
            save_summary_steps=cfg.val_interval,
            # Both checkpoint triggers set to None: RunConfig performs no
            # automatic checkpointing. Summaries are still saved every
            # val_interval steps.
            save_checkpoints_secs=None,
            save_checkpoints_steps=None,
            session_config=session_conf,
            log_step_count_steps=100,
            train_distribute=strategy,
            keep_checkpoint_max=1,
        )

        pipeline = self.input_pipeline
        model_fn = get_model_fn(
            target_model_fn=self._target_model,
            predict_op=self._predict_op,
            predict_proba_op=self._predict_proba_op,
            build_target_model=pipeline.target_dim is not None,
            build_lm=force_build_lm or cfg.lm_loss_coef > 0.0,
            encoder=pipeline.text_encoder,
            target_dim=pipeline.target_dim,
            label_encoder=pipeline.label_encoder,
            saver=self.saver,
        )
        init_hooks = [InitializeHook(self.saver)]
        estimator = tf.estimator.Estimator(
            model_dir=self.estimator_dir,
            model_fn=model_fn,
            config=run_config,
            params=self.config,
        )
        return estimator, init_hooks