예제 #1
0
    def create_hooks(
        self,
        t2r_model,
        estimator,
    ):
        """Creates a checkpoint-saver hook that exports the model on save.

        Exporting is disabled, and an empty list is returned, when neither
        `self._export_dir` nor `self._lagged_export_dir` is set.

        Args:
          t2r_model: Model whose specification configures the export generator
            and whose PREDICT-mode feature/label packing specifications are
            written into the exported T2R assets.
          estimator: Estimator to export. Its `model_dir` is used as the
            checkpoint directory and as the location where the warmup requests
            file is created.

        Returns:
          A list containing one `AsyncCheckpointSaverHook` that saves every
          `self._save_secs` seconds and notifies a `LaggedCheckpointListener`
          that calls the export closure; or `[]` if exporting is disabled.
        """
        import shutil  # Local import: only needed for temp-dir cleanup below.

        if not self._export_dir and not self._lagged_export_dir:
            return []
        self._export_generator.set_specification_from_model(t2r_model)
        # The warmup requests file is created once here and attached to every
        # export produced by the closure below.
        warmup_requests_file = self._export_generator.create_warmup_requests_numpy(
            batch_sizes=self._batch_sizes_for_export,
            export_dir=estimator.model_dir)

        in_feature_spec = t2r_model.get_feature_specification_for_packing(
            mode=tf.estimator.ModeKeys.PREDICT)
        in_label_spec = t2r_model.get_label_specification_for_packing(
            mode=tf.estimator.ModeKeys.PREDICT)
        # A single assets proto is shared by all exports; only its global_step
        # field is rewritten per export.
        t2r_assets = t2r_pb2.T2RAssets()
        t2r_assets.feature_spec.CopyFrom(in_feature_spec.to_proto())
        t2r_assets.label_spec.CopyFrom(in_label_spec.to_proto())

        def _export_fn(export_dir, global_step):
            """The actual closure function creating the exported model and assets.

            Args:
              export_dir: Base directory passed to `export_saved_model`.
              global_step: Step recorded in the exported T2R assets.

            Returns:
              The value returned by `estimator.export_saved_model`.
            """
            t2r_assets.global_step = global_step
            tmpdir = tempfile.mkdtemp()
            try:
                t2r_assets_filename = os.path.join(
                    tmpdir, tensorspec_utils.T2R_ASSETS_FILENAME)
                tensorspec_utils.write_t2r_assets_to_file(t2r_assets,
                                                          t2r_assets_filename)
                return estimator.export_saved_model(
                    export_dir_base=export_dir,
                    serving_input_receiver_fn=self._export_generator.
                    create_serving_input_receiver_numpy_fn(),
                    assets_extra={
                        'tf_serving_warmup_requests': warmup_requests_file,
                        tensorspec_utils.T2R_ASSETS_FILENAME: t2r_assets_filename
                    })
            finally:
                # export_saved_model copies assets_extra into the export, so
                # the staging directory can be dropped. Previously one temp
                # dir leaked per export. Cleanup is best-effort so it never
                # masks an export error.
                shutil.rmtree(tmpdir, ignore_errors=True)

        return [
            contrib_tpu.AsyncCheckpointSaverHook(
                save_secs=self._save_secs,
                checkpoint_dir=estimator.model_dir,
                listeners=[
                    checkpoint_hooks.LaggedCheckpointListener(
                        export_fn=_export_fn,
                        num_versions=self._num_versions,
                        export_dir=self._export_dir,
                        lagged_export_dir=self._lagged_export_dir)
                ])
        ]
 def create_hooks(
     self,
     t2r_model,
     estimator,
 ):
     """Builds the hook list that periodically checkpoints and exports.

     Args:
       t2r_model: Model used to configure `self._export_generator` and passed
         to `self._create_export_fn`.
       estimator: Estimator whose `model_dir` is the checkpoint directory; it
         is also passed to `self._create_export_fn`.

     Returns:
       A one-element list with an `AsyncCheckpointSaverHook` that saves every
       `self._save_secs` seconds and notifies a `CheckpointExportListener`.
     """
     self._export_generator.set_specification_from_model(t2r_model)
     export_fn = self._create_export_fn(
         t2r_model, estimator, self._export_generator)
     # The listener is notified by the saver hook; it calls export_fn with
     # the export settings held on this object.
     export_listener = checkpoint_hooks.CheckpointExportListener(
         export_fn=export_fn,
         num_versions=self._num_versions,
         export_dir=self._export_dir)
     saver_hook = contrib_tpu.AsyncCheckpointSaverHook(
         save_secs=self._save_secs,
         checkpoint_dir=estimator.model_dir,
         listeners=[export_listener])
     return [saver_hook]