def test_export_tpu_savedmodel_export_to_cpu_false(self):
        """With `export_to_cpu=False`, only the TPU metagraph is exported.

        Trains a TPUEstimator for one step, exports a SavedModel and checks
        that it holds exactly one MetaGraphDef, tagged {SERVING, TPU}.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            model_fn = get_model_fn(export_tpu_tensor=True,
                                    export_cpu_tensor=True)
            run_config = create_run_config(iterations_per_loop=4)

            def _input_fn(params):
                return dummy_input_fn(params['batch_size'])

            est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                             config=run_config,
                                             train_batch_size=16,
                                             export_to_tpu=True,
                                             export_to_cpu=False)
            est.train(_input_fn, steps=1)

            export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                           compat.as_bytes('export_no_cpu'))
            export_dir = est.export_saved_model(
                export_dir_base, self._serving_input_receiver_fn)
            saved_model = loader_impl.parse_saved_model(export_dir)
            # A single metagraph: the TPU serving graph, and no CPU graph.
            self.assertLen(saved_model.meta_graphs, 1)
            tags = set(saved_model.meta_graphs[0].meta_info_def.tags)
            self.assertEqual(tags, {tag_constants.SERVING, tag_constants.TPU})
        finally:
            # Remove the temp dir even when an assertion above fails.
            gfile.DeleteRecursively(tmpdir)
    def test_export_tpu_savedmodel_export_to_tpu_false(self):
        """With `export_to_tpu=False`, the TPU metagraph is not exported.

        Loading the export with tags {SERVING, TPU} must raise RuntimeError.
        Loading with {SERVING} alone must succeed.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            model_fn = get_model_fn(export_tpu_tensor=True,
                                    export_cpu_tensor=True)
            run_config = create_run_config(iterations_per_loop=4)

            def _input_fn(params):
                return dummy_input_fn(params['batch_size'])

            est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                             config=run_config,
                                             train_batch_size=16,
                                             export_to_tpu=False)
            est.train(_input_fn, steps=1)

            export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                           compat.as_bytes('export_no_tpu'))
            export_dir = est.export_saved_model(
                export_dir_base, self._serving_input_receiver_fn)
            with ops.Graph().as_default() as graph:
                with session.Session(graph=graph) as sess:
                    with self.assertRaisesRegex(
                            RuntimeError,
                            'MetaGraphDef associated with tags \'serve\', \'tpu\' could not be '
                            'found in SavedModel.'):
                        loader.load(sess,
                                    [tag_constants.SERVING, tag_constants.TPU],
                                    export_dir)
                    loader.load(sess, [tag_constants.SERVING], export_dir)
        finally:
            # Remove the temp dir even when an assertion above fails.
            gfile.DeleteRecursively(tmpdir)
    def test_export_tpu_savedmodel_e2e(self):
        """Trains, exports with the V2 SavedModel API and validates the export."""
        tmpdir = tempfile.mkdtemp()
        try:
            def _input_fn(params):
                return dummy_input_fn(params['batch_size'])

            model_fn = get_model_fn_v2()
            run_config = create_run_config(iterations_per_loop=4)
            est = tpu_estimator.TPUEstimator(
                model_fn=model_fn,
                config=run_config,
                train_batch_size=16,
                export_to_tpu=True,
                export_saved_model_api_version=(
                    tpu_estimator.ExportSavedModelApiVersion.V2))
            est.train(_input_fn, steps=1)

            # Perform the export.
            export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                           compat.as_bytes('export'))
            export_dir = est.export_saved_model(
                export_dir_base, self._serving_input_receiver_fn)

            self._validate_export(export_dir_base, export_dir)
        finally:
            # Remove the temp dir even when validation fails.
            gfile.DeleteRecursively(tmpdir)
Exemple #4
0
    def test_evaluate_mode(self):
        """evaluate() invokes input_fn exactly once.

        The run config uses two cores per replica, one shard and partitioned
        inputs.
        """
        calls = [0]
        eval_batch_size = 128
        pipeline_mode = tpu_config.InputPipelineConfig.PER_HOST_V2
        run_config = create_run_config(
            iterations_per_loop=4,
            num_cores_per_replica=2,
            num_shards=1,
            input_partition_dims=[[1, 2], None],
            per_host_input_for_training=pipeline_mode)

        def _input_fn(params):
            calls[0] = calls[0] + 1
            return dummy_input_fn_with_dataset(
                batch_size=params['batch_size'], fea_len=2)

        est = tpu_estimator.TPUEstimator(
            model_fn=model_fn_global_step_incrementer,
            config=run_config,
            train_batch_size=128,
            eval_batch_size=eval_batch_size)

        # Constructing the estimator alone must not call input_fn.
        self.assertEqual(0, calls[0])
        est.evaluate(_input_fn, steps=1)
        self.assertEqual(1, calls[0])
Exemple #5
0
  def test_complete_flow_with_eval_on_tpu(self):
    """Runs train, evaluate, predict and export on one TPUEstimator.

    Uses a `TPUEstimatorSpec` model_fn. All three phases use per-host input
    sharding.
    """
    # All batch sizes are divisible by the shard counts common in the test
    # env (2 and 8).
    train_batch_size = 16
    eval_batch_size = 8
    predict_batch_size = 8

    run_config = create_run_config(iterations_per_loop=4)
    num_shards = run_config.tpu_config.num_shards

    (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,
     expected_called_count_for_input_fn) = (
         self._generate_expected_batch_size_and_called_count(
             num_shards,
             train_batch_size,
             eval_batch_size,
             predict_batch_size,
             train_sharding_policy=_PER_HOST,
             eval_sharding_policy=_PER_HOST,
             predict_sharding_policy=_PER_HOST))

    est = tpu_estimator.TPUEstimator(
        model_fn=self._make_model_fn(
            expected_batch_size_for_model_fn, use_tpu_estimator_spec=True),
        config=run_config,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        predict_batch_size=predict_batch_size)

    # TRAIN
    # learn y = x
    # Note: Gradients are all zero. Just testing execution.
    train_input_fn = self._make_input_fn(mode=_TRAIN, repeat=True)
    est.train(train_input_fn, steps=7)

    # EVALUATE
    eval_input_fn = self._make_input_fn(mode=_EVAL, repeat=False)
    scores = est.evaluate(eval_input_fn, steps=2)
    self.assertEqual(7, scores['global_step'])
    self.assertGreater(0.1, scores['absolute_error'])

    # PREDICT
    predict_input_fn = self._make_input_fn(mode=_PREDICT, take=2)
    predictions = [x['predictions'] for x in est.predict(predict_input_fn)]
    self.assertAllClose(
        self._data[:predict_batch_size * 2], predictions, atol=0.01)

    # Verify that every input_fn invocation recorded its metadata.
    self.assertInputFnCalledCountAndBatch(
        expected_called_count_for_input_fn, expected_batch_size_for_input_fn)

    # EXPORT
    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.float32)}
    serving_input_receiver_fn = (
        export.build_parsing_serving_input_receiver_fn(feature_spec))
    with self.export_mode():
      export_dir = est.export_saved_model(
          tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)
    self.assertTrue(tf.gfile.Exists(export_dir))
    self._test_identity_savedmodel(export_dir)
 def test_error_out_if_steps_is_invalid(self):
     """evaluate() with a negative `steps` raises ValueError."""
     run_config = create_run_config(iterations_per_loop=2)
     est = tpu_estimator.TPUEstimator(
         model_fn=self._model_fn_with_eval_dict,
         config=run_config,
         train_batch_size=16,
         eval_batch_size=16,
         use_tpu=True)
     # Only evaluate() is expected to raise. Keeping setup outside the
     # context stops an unrelated setup ValueError from passing the test.
     with self.assertRaisesRegex(ValueError, 'must be positive'):
         est.evaluate(self._create_input_fn(), steps=-321)
    def test_eval_metrics_with_dict(self):
        """Trains, then evaluates, a model whose eval metrics come as a dict."""
        est = tpu_estimator.TPUEstimator(
            model_fn=self._model_fn_with_eval_dict,
            config=create_run_config(iterations_per_loop=2),
            train_batch_size=16,
            eval_batch_size=16)

        # One training step followed by one evaluation step.
        for run in (est.train, est.evaluate):
            run(self._create_input_fn(), steps=1)
Exemple #8
0
  def test_eval_metrics_with_tensor_list(self):
    """Trains and evaluates a tensor-list eval-metrics model.

    Uses two cores per replica and a single shard.
    """
    est = tpu_estimator.TPUEstimator(
        model_fn=self._model_fn_with_eval_tensor_list,
        config=create_run_config(
            iterations_per_loop=2, num_shards=1, num_cores_per_replica=2),
        train_batch_size=16,
        eval_batch_size=16)

    for run in (est.train, est.evaluate):
      run(self._create_input_fn(), steps=1)
Exemple #9
0
 def test_fail_model_parallelism_for_per_core_input(self):
   """Per-core input (`per_host_input_for_training=False`) is rejected.

   The config uses two cores per replica (model parallelism).
   """
   config_kwargs = dict(
       iterations_per_loop=4,
       num_shards=1,
       num_cores_per_replica=2,
       per_host_input_for_training=False)
   run_config = create_run_config(**config_kwargs)
   with self.assertRaisesRegex(ValueError, 'Model parallelism only supports'):
     tpu_estimator.TPUEstimator(
         model_fn=model_fn_global_step_incrementer,
         config=run_config,
         train_batch_size=128)
Exemple #10
0
  def test_fail_with_wrong_num_shards(self):
    """train() raises ValueError for num_shards=2 with 2 cores per replica."""
    est = tpu_estimator.TPUEstimator(
        model_fn=self._model_fn_with_eval_tensor_list,
        config=create_run_config(
            iterations_per_loop=2, num_shards=2, num_cores_per_replica=2),
        train_batch_size=16,
        eval_batch_size=16)

    expected_message = 'num_shards is not set correctly'
    with self.assertRaisesRegex(ValueError, expected_message):
      est.train(self._create_input_fn(), steps=1)
    def test_eval_metrics_ops_tpu_training(self):
        """Trains on TPU, then evaluates with `eval_on_tpu=False`.

        The model_fn supplies eval metric ops directly.
        """
        est = tpu_estimator.TPUEstimator(
            model_fn=self._model_fn_with_eval_metric_ops,
            config=create_run_config(iterations_per_loop=2),
            train_batch_size=16,
            eval_batch_size=16,
            use_tpu=True,
            eval_on_tpu=False)

        for run in (est.train, est.evaluate):
            run(self._create_input_fn(), steps=1)
    def test_eval_batch_size_with_non_divisible_num_shards_broadcast_mode(
            self):
        """Batch size 7 trains and evaluates under BROADCAST input mode."""
        broadcast = tpu_config.InputPipelineConfig.BROADCAST
        run_config = create_run_config(iterations_per_loop=2,
                                       per_host_input_for_training=broadcast)
        est = tpu_estimator.TPUEstimator(
            model_fn=self._model_fn_with_eval_tensor_list,
            config=run_config,
            train_batch_size=7,
            eval_batch_size=7)

        for run in (est.train, est.evaluate):
            run(self._create_input_fn(), steps=1)
Exemple #13
0
  def test_export_tpu_savedmodel_e2e(self, export_tpu_tensor, export_cpu_tensor,
                                     use_export_mode_v2):
    """Trains, exports and validates a SavedModel for the given options.

    Args:
      export_tpu_tensor: Forwarded to `get_model_fn`.
      export_cpu_tensor: Forwarded to `get_model_fn`.
      use_export_mode_v2: If true, export with `ExportSavedModelApiVersion.V2`.
        PREDICT on TPU is then wrapped in `model_fn_inference_on_tpu` with a
        batching config. Otherwise, export with V1.
    """
    tmpdir = tempfile.mkdtemp()
    try:
      def _input_fn(params):
        return dummy_input_fn(params['batch_size'])

      model_fn = get_model_fn(export_tpu_tensor, export_cpu_tensor)
      run_config = create_run_config(iterations_per_loop=4)
      if use_export_mode_v2:
        export_api_version = tpu_estimator.ExportSavedModelApiVersion.V2

        batch_config = tpu_estimator.BatchConfig(
            num_batch_threads=1,
            max_batch_size=1,
            batch_timeout_micros=100,
            allowed_batch_sizes=[1])

        def tpu_model_fn(features, labels, mode, params):
          if mode == _PREDICT and params['use_tpu']:
            return tpu_estimator.model_fn_inference_on_tpu(
                model_fn, features, labels, mode, params, batch_config)
          else:
            return model_fn(features, labels, mode, params)

        est_model_fn = tpu_model_fn
      else:
        export_api_version = tpu_estimator.ExportSavedModelApiVersion.V1
        est_model_fn = model_fn
      est = tpu_estimator.TPUEstimator(
          model_fn=est_model_fn,
          config=run_config,
          train_batch_size=16,
          export_to_tpu=True,
          export_saved_model_api_version=export_api_version)
      est.train(_input_fn, steps=1)

      # Perform the export.
      export_dir_base = os.path.join(
          compat.as_bytes(tmpdir), compat.as_bytes('export'))
      export_dir = est.export_saved_model(export_dir_base,
                                          self._serving_input_receiver_fn)

      self._validate_export(export_dir_base, export_dir, export_tpu_tensor,
                            export_cpu_tensor)
    finally:
      # Remove the temp dir even when validation fails.
      gfile.DeleteRecursively(tmpdir)
    def test_eval_metrics_ops_tpu_training_failure(self):
        """Evaluating on TPU with plain eval metric ops raises RuntimeError."""
        est = tpu_estimator.TPUEstimator(
            model_fn=self._model_fn_with_eval_metric_ops,
            config=create_run_config(iterations_per_loop=2),
            train_batch_size=16,
            eval_batch_size=16,
            use_tpu=True,
            # Eval fails because model_fn(mode=EVAL) has not been split
            # into an eval_metrics_fn.
            eval_on_tpu=True)

        est.train(self._create_input_fn(), steps=1)
        expected_message = 'TPU evaluation must have type`TPUEstimatorSpec`'
        with self.assertRaisesRegex(RuntimeError, expected_message):
            est.evaluate(self._create_input_fn(), steps=1)
Exemple #15
0
  def _train_and_return_global_steps(self,
                                     iterations_per_loop,
                                     steps=None,
                                     max_steps=None,
                                     pre_train_steps=None,
                                     **kwargs):
    """Trains the model and returns the global step read after each run.

    Args:
      iterations_per_loop: Passed to `create_run_config`.
      steps: `steps` for the monitored `est.train()` call.
      max_steps: `max_steps` for the monitored `est.train()` call.
      pre_train_steps: If truthy, first train this many steps without the
        recording hook.
      **kwargs: Extra keyword arguments forwarded to `create_run_config`.

    Returns:
      A list of global step values, one per `session.run` of the monitored
      training call.
    """

    def input_fn(params):
      return dummy_input_fn(params['batch_size'])

    def _model_fn(features, labels, mode, params):
      return model_fn_global_step_incrementer(features, labels, mode, params)

    run_config = create_run_config(
        iterations_per_loop=iterations_per_loop,
        num_shards=1,
        num_cores_per_replica=2,
        **kwargs)
    est = tpu_estimator.TPUEstimator(
        model_fn=_model_fn,
        config=run_config,
        train_batch_size=16,
        eval_batch_size=16)

    class _TrainStepCheckHook(session_run_hook.SessionRunHook):
      """Records the global step after every training session.run."""

      def __init__(self):
        """Constructs the run hook with an empty record."""
        self._global_steps = []

      @property
      def global_steps(self):
        """The global step values recorded so far, in run order."""
        return self._global_steps

      def after_run(self, run_context, run_values):
        global_step = run_context.session.run(training.get_global_step())
        self._global_steps.append(global_step)

    if pre_train_steps:
      est.train(input_fn, steps=pre_train_steps)

    hook = _TrainStepCheckHook()
    est.train(input_fn, steps=steps, max_steps=max_steps, hooks=[hook])
    return hook.global_steps
  def test_batch_size(self, num_cores_per_replica, num_shards):
    """input_fn gets the global batch of 128 split evenly across shards.

    It must also be called exactly once during training.
    """
    calls = [0]
    pipeline_mode = tpu_config.InputPipelineConfig.PER_HOST_V2
    run_config = create_run_config(
        iterations_per_loop=4,
        num_cores_per_replica=num_cores_per_replica,
        num_shards=num_shards,
        per_host_input_for_training=pipeline_mode)

    def _input_fn(params):
      calls[0] = calls[0] + 1
      per_host_batch = params['batch_size']
      self.assertEqual(128 // num_shards, per_host_batch)
      return dummy_input_fn_with_dataset(batch_size=per_host_batch)

    est = tpu_estimator.TPUEstimator(
        model_fn=model_fn_global_step_incrementer,
        config=run_config,
        train_batch_size=128)
    self.assertEqual(0, calls[0])
    est.train(_input_fn, steps=1)
    self.assertEqual(1, calls[0])
Exemple #17
0
  def _test_eval_steps(self, expected_eval_steps, iterations):
    """Trains one step, then checks the eval step counter after evaluation.

    Args:
      expected_eval_steps: `steps` passed to `evaluate()`. After the single
        eval `session.run`, the eval step counter must equal this value.
      iterations: `iterations_per_loop` for the run config.
    """

    run_config = create_run_config(
        iterations_per_loop=iterations, num_shards=1, num_cores_per_replica=2)
    est = tpu_estimator.TPUEstimator(
        model_fn=self._model_fn_with_eval_tensor_list,
        config=run_config,
        train_batch_size=16,
        eval_batch_size=16)

    est.train(self._create_input_fn(), steps=1)

    class _EvalStepCheckHook(session_run_hook.SessionRunHook):
      """Check eval step counter after one session.run.

      As the evaluation sets the eval iterations as the eval steps, the
      after_run should be invoked only once.
      """

      def __init__(self, iterations_per_loop, test_case):
        """Constructs the run hook."""
        self._iterations = iterations_per_loop
        self._invoked = False
        self._test_case = test_case

      def after_run(self, run_context, run_values):
        del run_values  # The eval step is read explicitly below instead.
        # Read the eval step in a separate session.run after the eval run.
        # Fetching it alongside the eval ops can race with the increment in
        # the evaluation graph.
        eval_steps = run_context.session.run(
            evaluation._get_or_create_eval_step())
        self._test_case.assertEqual(expected_eval_steps, eval_steps)
        self._test_case.assertFalse(self._invoked)
        self._invoked = True

    est.evaluate(
        self._create_input_fn(),
        steps=expected_eval_steps,
        hooks=[_EvalStepCheckHook(iterations, self)])
    def _test_eval_steps(self, model_fn, expected_eval_steps, iterations):
        """Trains `model_fn` one step, then verifies a single eval run.

        A hook asserts that evaluation performs exactly one `session.run`.
        It also asserts that the eval step counter then equals
        `expected_eval_steps`.
        """
        est = tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            config=create_run_config(iterations_per_loop=iterations),
            train_batch_size=16,
            eval_batch_size=16)

        est.train(self._create_input_fn(), steps=1)

        class _EvalStepCheckHook(tf.compat.v1.train.SessionRunHook):
            """Checks the eval step counter after the only session.run.

            Evaluation uses the eval steps as its iteration count, so
            `after_run` should fire exactly once.
            """

            def __init__(self, iterations_per_loop, test_case):
                """Stores the loop size and the test case used for asserts."""
                self._iterations = iterations_per_loop
                self._test_case = test_case
                self._invoked = False

            def before_run(self, run_context):
                del run_context
                # On TPU the whole evaluation is a single run.
                self._test_case.assertFalse(self._invoked)

            def after_run(self, run_context, run_values):
                # Read the eval step explicitly after the run so the read
                # cannot race with its increment inside the eval graph.
                step_op = evaluation._get_or_create_eval_step()
                observed_steps = run_context.session.run(step_op)
                self._test_case.assertEqual(expected_eval_steps,
                                            observed_steps)
                self._test_case.assertFalse(self._invoked)
                self._invoked = True

        check_hook = _EvalStepCheckHook(iterations, self)
        est.evaluate(self._create_input_fn(),
                     steps=expected_eval_steps,
                     hooks=[check_hook])
Exemple #19
0
    def test_predict_mode(self):
        """Checks predict() in batched and per-example modes.

        Both modes use a model-parallel config. Each predict() call
        invokes input_fn once.
        """
        calls = [0]
        predict_batch_size = 128
        pipeline_mode = tpu_config.InputPipelineConfig.PER_HOST_V2
        run_config = create_run_config(
            iterations_per_loop=4,
            num_cores_per_replica=2,
            num_shards=1,
            input_partition_dims=[[1, 2], None],
            per_host_input_for_training=pipeline_mode)

        def _input_fn(params):
            calls[0] = calls[0] + 1
            return dummy_input_fn_with_dataset(
                batch_size=params['batch_size'], fea_len=2)

        est = tpu_estimator.TPUEstimator(
            model_fn=model_fn_global_step_incrementer,
            config=run_config,
            train_batch_size=128,
            predict_batch_size=predict_batch_size)

        self.assertEqual(0, calls[0])

        # (yield_single_examples, cumulative input_fn calls, output shape)
        cases = (
            (False, 1, (predict_batch_size, 1)),
            (True, 2, (1, )),
        )
        for single_examples, expected_calls, expected_shape in cases:
            predictor = est.predict(_input_fn,
                                    yield_single_examples=single_examples)
            prediction = six.next(predictor)

            self.assertEqual(expected_calls, calls[0])
            self.assertIn('predictions', prediction)
            self.assertEqual(expected_shape, prediction['predictions'].shape)
Exemple #20
0
def get_estimator(use_tpu,
                  output_dir,
                  feature_columns,
                  batch_size,
                  optimizer_type='adagrad',
                  grad_multiplier_fn=None):
  """Builds a TPUEstimator configured with a TPU embedding spec.

  Args:
    use_tpu: Forwarded to `TPUEstimator(use_tpu=...)`.
    output_dir: `model_dir` of the run config.
    feature_columns: Used by the model_fn and by the embedding config.
    batch_size: Used as both the train and the eval batch size.
    optimizer_type: Embedding optimizer, either 'adagrad' or 'sgd'.
    grad_multiplier_fn: Optional `experimental_gradient_multiplier_fn` for
      the embedding config.

  Returns:
    A configured `tpu_estimator.TPUEstimator`.

  Raises:
    ValueError: If `optimizer_type` is neither 'adagrad' nor 'sgd'.
  """
  # Resolve the optimizer first. An unsupported value then fails fast with a
  # clear error instead of an UnboundLocalError further down.
  if optimizer_type == 'adagrad':
    optimization_parameters = tpu_estimator.AdagradParameters(
        LEARNING_RATE,
        ADADGRAD_INIT_VALUE,
        use_gradient_accumulation=False)
  elif optimizer_type == 'sgd':
    optimization_parameters = tpu_estimator.StochasticGradientDescentParameters(
        LEARNING_RATE)
  else:
    raise ValueError(
        'Unsupported optimizer_type %r; expected "adagrad" or "sgd".' %
        (optimizer_type,))

  run_config = tpu_config.RunConfig(
      master='',
      model_dir=output_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=1,
          num_shards=FLAGS.test_num_shards,
          per_host_input_for_training=(
              tpu_config.InputPipelineConfig.PER_HOST_V2)),
      save_checkpoints_steps=1)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=create_model_fn(feature_columns, optimizer_type),
      use_tpu=use_tpu,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      embedding_config_spec=tpu_estimator.EmbeddingConfigSpec(
          feature_columns=feature_columns,
          optimization_parameters=optimization_parameters,
          experimental_gradient_multiplier_fn=grad_multiplier_fn))
  return estimator
Exemple #21
0
  def test_export_tpu_savedmodel_export_to_tpu_false_eval(self):
    """Exports the CPU evaluation graph when `export_to_tpu` is `False`.

    The export is reloaded with the EVAL tag set. The test checks that the
    graph contains the `dense/kernel` op.
    """
    tmpdir = tempfile.mkdtemp()
    try:
      mode = model_fn_lib.ModeKeys.EVAL

      model_fn = get_model_fn(export_tpu_tensor=True, export_cpu_tensor=True)
      run_config = create_run_config(iterations_per_loop=4)

      def _input_fn(params):
        return dummy_input_fn(params['batch_size'])

      est = tpu_estimator.TPUEstimator(
          model_fn=model_fn,
          config=run_config,
          train_batch_size=16,
          export_to_tpu=False)
      est.train(_input_fn, steps=1)

      export_dir_base = os.path.join(
          compat.as_bytes(tmpdir), compat.as_bytes('export_no_tpu_eval'))
      export_dir = est.export_saved_model(
          export_dir_base, self._supervised_input_receiver_fn,
          experimental_mode=mode)

      # Check that the export base directory was created.
      self.assertTrue(gfile.Exists(export_dir_base))

      # Restore, to validate that the export was well-formed.
      tag_set = export_lib.EXPORT_TAG_MAP[mode]
      with ops.Graph().as_default() as graph:
        with session.Session(graph=graph) as sess:
          loader.load(sess, tag_set, export_dir)
          graph_ops = [x.name for x in graph.get_operations()]
          self.assertIn('dense/kernel', graph_ops)
    finally:
      # Remove the temp dir even when an assertion above fails.
      gfile.DeleteRecursively(tmpdir)
Exemple #22
0
  def get_activations_and_sequence_lengths(
      self,
      embedding_weights: List[List[float]],
      sparse_ids: tf.SparseTensorValue,
      batch_size: int,
      max_sequence_length: int,
      dimension: int,
      combiner: Text = 'mean',
  ) -> Tuple[tf.Tensor, tf.Tensor]:
    """Gets the activations and seq lengths for a batch of sparse IDs.

    This method uses TPUEstimator and the Feature Column API to get embedding
    activations for a batch of sparse IDs using a specified set of embedding
    weights. It trains for a single step on TPU. A host call writes the
    activations and sequence lengths as summaries to `self._summary_dir`.
    Those summaries are then read back and decoded.

    Args:
      embedding_weights: The embedding weights as a 2D list of floats.  The
        outer list length is the vocabulary size of the embedding table.  The
        inner list length is the dimension of the embedding weights.
      sparse_ids: The embedding IDs to lookup. This is a 2D SparseTensorValue of
        shape [batch_size, max_sequence_length].
      batch_size: The size of the first dimension of sparse_ids.
      max_sequence_length:  The size of the second dimension of sparse_ids.
      dimension: The embedding dimension size (number of floats for each
        embedding ID).
      combiner: The embedding column combiner (used for multivalent features).

    Returns:
      A tuple containing:
        activations:  The activations for the specified sparse_ids.
          type=float32, shape=[batch_size, max_sequence_length, dimension]
        sequence_lengths: The sequence length of each example.
          type=int64. shape=[batch_size].

    Raises:
      KeyError: If either summary tag is missing from the event files in
        `self._summary_dir`.
    """

    vocab_size = len(embedding_weights)
    categorical_column = (
        tf.feature_column.sequence_categorical_column_with_identity(
            key=self._KEY,
            num_buckets=vocab_size,
        ))

    # Create embedding column initialized with weights provided by caller.
    embedding_column = tf.tpu.experimental.embedding_column(
        categorical_column,
        dimension=dimension,
        max_sequence_length=max_sequence_length,
        initializer=tf.constant_initializer(embedding_weights),
        combiner=combiner,
    )

    # Add an SGD optimizer. This choice is arbitrary for computing activations.
    # It's only required to avoid an undefined gradients error.
    embedding_opt = tf.tpu.experimental.StochasticGradientDescentParameters(.1)
    embedding_config_spec = tpu_estimator.EmbeddingConfigSpec(
        feature_columns=[embedding_column],
        optimization_parameters=embedding_opt,
    )

    def _input_fn(params: Dict[Text, int]) -> tf.data.Dataset:
      """Creates a batched dataset containing the sparse_ids as a feature."""
      # Convert sparse IDs to batched dataset.
      sparse_ids_dataset = tf.data.Dataset.range(1).map(
          lambda x: {self._KEY: tf.SparseTensor.from_value(sparse_ids)})

      # Unbatch and rebatch the dataset based on the batch_size param from
      # TPUEstimator. This is necessary for shape validation performed internal
      # to TPUEstimator.
      return sparse_ids_dataset.unbatch().repeat().batch(params['batch_size'])

    def _host_call(
        concat_activations: tf.Tensor,
        concat_sequence_lengths: tf.Tensor,
    ) -> List[tf.Operation]:
      """Stores the activations and sequence lengths into a summary.

      TPUEstimator will concat the activations and sequence lengths from the
      minibatches on each core along axis=0 and pass them to this host call.
      This host call writes them to a file using the TF summary APIs.

      Args:
        concat_activations: The activations for the global batch, expected
          to be float32 with shape [batch_size, max_sequence_length,
          dimension] (as described in the method's Returns section).
        concat_sequence_lengths: The sequence lengths for the global batch,
          expected to be int64 with shape [batch_size].

      Returns:
        A list of summary ops for TPUEstimator to run on the host.
      """
      with contrib_summary.create_file_writer(self._summary_dir).as_default():
        with contrib_summary.always_record_summaries():
          contrib_summary.generic(
              self._SUMMARY_ACTIVATIONS,
              concat_activations,
          )
          contrib_summary.generic(self._SUMMARY_SEQUENCE_LENGTHS,
                                  concat_sequence_lengths)
          return contrib_summary.all_summary_ops()

    def _model_fn(
        features: Dict[Text, tf.Tensor],
        params: Dict[Text, int],
        mode: model_fn_lib.ModeKeys,
    ) -> tpu_estimator.TPUEstimatorSpec:
      """A model which writes activations and sequence lengths to a file.

      This method creates a model to extract the activations and sequence
      lengths on each TPU core and pass them to a host call which writes them
      to a file.

      The model also applies an optimizer to the activations simply to avoid an
      undefined gradients error.

      Args:
        features: A dictionary mapping keys to tensor inputs.
        params: Parameters passed by TPUEstimator.
        mode: Mode can be (TRAIN, EVAL, PREDICT).

      Returns:
        A TPUEstimatorSpec which holds the training_op that TPUEstimator will
        run on TPU and the host_call that TPUEstimator will run on the host.
      """
      del params
      input_layer = tf.keras.experimental.SequenceFeatures([embedding_column])
      activations, sequence_lengths = input_layer(features)
      opt = tf.tpu.CrossShardOptimizer(tf.train.GradientDescentOptimizer(0.1))
      loss = tf.reduce_sum(activations)
      train_op = opt.minimize(loss, global_step=tf.train.get_global_step())

      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=train_op,
          host_call=(_host_call, [activations, sequence_lengths]),
      )

    tpu_config = tpu_config_lib.TPUConfig(
        per_host_input_for_training=(
            tpu_config_lib.InputPipelineConfig.PER_HOST_V2),)
    run_config = tpu_config_lib.RunConfig(
        session_config=tf.ConfigProto(isolate_session_state=True),
        tpu_config=tpu_config,
    )
    estimator = tpu_estimator.TPUEstimator(
        model_fn=_model_fn,
        model_dir=self._model_dir,
        use_tpu=True,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        config=run_config,
        embedding_config_spec=embedding_config_spec,
    )

    # Train for 1 step and store the activations as summaries.
    estimator.train(_input_fn, steps=1)

    # Read the event summaries and decode the activation tensors.
    output = {}
    for filename in tf.io.gfile.listdir(self._summary_dir):
      filepath = os.path.join(self._summary_dir, filename)
      for event in tf.train.summary_iterator(filepath):
        for v in event.summary.value:
          # Each summary value stores raw tensor bytes. Decode them with the
          # recorded dtype and restore the recorded shape. A later event with
          # the same tag overwrites an earlier one.
          decoded = tf.io.decode_raw(v.tensor.tensor_content, v.tensor.dtype)
          shape = tf.TensorShape(v.tensor.tensor_shape)
          output[v.tag] = tf.reshape(decoded, shape)
    return (output[self._SUMMARY_ACTIVATIONS],
            output[self._SUMMARY_SEQUENCE_LENGTHS])