Exemple #1
0
  def test_scaffold_is_used_for_saver(self):
    """Checks that export restores via the saver supplied in the Scaffold.

    The model_fn wraps a real `Saver` in a mock and passes it to the
    `Scaffold`. After `export_savedmodel` runs, the mock must have recorded a
    call to `restore`.
    """
    # Imported locally so the cleanup below needs nothing from the file's
    # import block.
    import shutil

    tmpdir = tempfile.mkdtemp()
    # Remove the scratch directory whether or not the assertion passes.
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      variables.Variable(1., name='weight')
      real_saver = saver.Saver()
      # wraps= forwards calls to the real saver while recording them.
      self.mock_saver = test.mock.Mock(
          wraps=real_saver, saver_def=real_saver.saver_def)
      scores = constant_op.constant([3.])
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(saver=self.mock_saver),
          export_outputs={'test': export.ClassificationOutput(scores)})

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    est.export_savedmodel(export_dir_base, serving_input_receiver_fn)

    self.assertTrue(self.mock_saver.restore.called)
Exemple #2
0
  def test_scaffold_is_used_for_saver(self):
    """Checks that export restores via the saver supplied in the Scaffold.

    The model_fn wraps a real `Saver` in a mock and passes it to the
    `Scaffold`. After `export_savedmodel` runs, the mock must have recorded a
    call to `restore`.
    """
    # Imported locally so the cleanup below needs nothing from the file's
    # import block.
    import shutil

    tmpdir = tempfile.mkdtemp()
    # Remove the scratch directory whether or not the assertion passes.
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      variables.Variable(1., name='weight')
      real_saver = saver.Saver()
      # wraps= forwards calls to the real saver while recording them.
      self.mock_saver = test.mock.Mock(
          wraps=real_saver, saver_def=real_saver.saver_def)
      scores = constant_op.constant([3.])
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(saver=self.mock_saver),
          export_outputs={'test': export_output.ClassificationOutput(scores)})

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    est.export_savedmodel(export_dir_base, serving_input_receiver_fn)

    self.assertTrue(self.mock_saver.restore.called)
Exemple #3
0
  def test_complete_flow_with_a_simple_linear_model(self):
    """Runs train, evaluate, predict and export on a one-unit dense model."""

    def _model_fn(features, labels, mode):
      predictions = layers.dense(
          features['x'], 1, kernel_initializer=init_ops.zeros_initializer())
      export_outputs = {
          'predictions': export.RegressionOutput(predictions)
      }

      # Prediction needs neither a loss nor a train op.
      if mode == model_fn_lib.ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

      loss = losses.mean_squared_error(labels, predictions)
      train_op = training.GradientDescentOptimizer(learning_rate=0.5).minimize(
          loss, training.get_global_step())
      eval_metric_ops = {
          'absolute_error': metrics_lib.mean_absolute_error(
              labels, predictions)
      }

      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)

    est = estimator.Estimator(model_fn=_model_fn)
    data = np.linspace(0., 1., 100, dtype=np.float32).reshape(-1, 1)

    # TRAIN
    # learn y = x (the same array is used as both feature and label)
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, y=data, batch_size=50, num_epochs=None, shuffle=True)
    est.train(train_input_fn, steps=200)

    # EVALUATE
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, y=data, batch_size=50, num_epochs=1, shuffle=True)
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(200, scores['global_step'])
    self.assertGreater(0.1, scores['absolute_error'])

    # PREDICT
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, y=None, batch_size=10, num_epochs=1, shuffle=False)
    predictions = list(est.predict(predict_input_fn))
    self.assertAllClose(data, predictions, atol=0.01)

    # EXPORT
    feature_spec = {'x': parsing_ops.FixedLenFeature([1], dtypes.float32)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir_base = tempfile.mkdtemp()
    # Delete the exported model once the test finishes, pass or fail.
    self.addCleanup(gfile.DeleteRecursively, export_dir_base)
    export_dir = est.export_savedmodel(export_dir_base,
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))
    def test_build_parsing_serving_input_receiver_fn(self):
        """Builds a parsing receiver fn and parses one serialized Example.

        Verifies the receiver exposes the expected feature and receiver-tensor
        keys, then feeds a serialized `Example` and checks the sparse indices
        and values produced for both features.
        """
        spec = {
            'int_feature': parsing_ops.VarLenFeature(dtypes.int64),
            'float_feature': parsing_ops.VarLenFeature(dtypes.float32),
        }
        receiver_fn = export.build_parsing_serving_input_receiver_fn(spec)

        # Text-format Example proto fed to the parser below.
        example_text = (
            "features: { "
            "  feature: { "
            "    key: 'int_feature' "
            "    value: { "
            "      int64_list: { "
            "        value: [ 21, 2, 5 ] "
            "      } "
            "    } "
            "  } "
            "  feature: { "
            "    key: 'float_feature' "
            "    value: { "
            "      float_list: { "
            "        value: [ 525.25 ] "
            "      } "
            "    } "
            "  } "
            "} ")

        with ops.Graph().as_default():
            receiver = receiver_fn()
            self.assertEqual({'int_feature', 'float_feature'},
                             set(receiver.features.keys()))
            self.assertEqual({'examples'},
                             set(receiver.receiver_tensors.keys()))

            example = example_pb2.Example()
            text_format.Parse(example_text, example)
            input_name = receiver.receiver_tensors['examples'].name
            serialized_batch = [example.SerializeToString()]

            with self.test_session() as sess:
                parsed = sess.run(
                    receiver.features,
                    feed_dict={input_name: serialized_batch})
                int_result = parsed['int_feature']
                float_result = parsed['float_feature']
                self.assertAllEqual([[0, 0], [0, 1], [0, 2]],
                                    int_result.indices)
                self.assertAllEqual([21, 2, 5], int_result.values)
                self.assertAllEqual([[0, 0]], float_result.indices)
                self.assertAllEqual([525.25], float_result.values)
    def test_export_savedmodel_proto_roundtrip(self):
        """Exports a trained estimator and reloads the resulting SavedModel.

        Checks the on-disk layout of the export directory, then loads it into
        a fresh graph and verifies the expected op names are present.
        """
        tmpdir = tempfile.mkdtemp()
        # Registered up front so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(gfile.DeleteRecursively, tmpdir)

        est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
        est.train(input_fn=dummy_input_fn, steps=1)
        feature_spec = {
            'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
            'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)
        }
        serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
            feature_spec)

        # Perform the export.
        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export'))
        export_dir = est.export_savedmodel(export_dir_base,
                                           serving_input_receiver_fn)

        # Check that all the files are in the right places.
        self.assertTrue(gfile.Exists(export_dir_base))
        self.assertTrue(gfile.Exists(export_dir))
        expected_paths = (
            'saved_model.pb',
            'variables',
            'variables/variables.index',
            'variables/variables.data-00000-of-00001',
        )
        for relative_path in expected_paths:
            full_path = os.path.join(compat.as_bytes(export_dir),
                                     compat.as_bytes(relative_path))
            self.assertTrue(gfile.Exists(full_path),
                            msg='missing from export: %s' % relative_path)

        # Restore, to validate that the export was well-formed.
        with ops.Graph().as_default() as graph:
            with session.Session(graph=graph) as sess:
                loader.load(sess, [tag_constants.SERVING], export_dir)
                graph_ops = [x.name for x in graph.get_operations()]
                # assertIn reports the searched list when the check fails.
                self.assertIn('input_example_tensor', graph_ops)
                self.assertIn('ParseExample/ParseExample', graph_ops)
                self.assertIn('weight', graph_ops)
    def test_scaffold_is_used_for_local_init(self):
        """Checks that a Scaffold's custom local_init_op runs on model load.

        The model_fn creates a local variable initialised to 1 and a custom
        local init op that assigns 12345 to it. After export and reload, the
        variable must read 12345.
        """
        # Imported locally so the cleanup below needs nothing from the file's
        # import block.
        import shutil

        tmpdir = tempfile.mkdtemp()
        # Remove the scratch directory whether or not the assertion passes.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

        def _model_fn_scaffold(features, labels, mode):
            _, _ = features, labels
            my_int = variables.Variable(
                1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES])
            scores = constant_op.constant([3.])
            # The assignment must run after the standard local initialisers.
            with ops.control_dependencies([
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer()
            ]):
                assign_op = state_ops.assign(my_int, 12345)

            # local_init_op must be an Operation, not a Tensor.
            custom_local_init_op = control_flow_ops.group(assign_op)
            return model_fn_lib.EstimatorSpec(
                mode=mode,
                predictions=constant_op.constant([[1.]]),
                loss=constant_op.constant(0.),
                train_op=constant_op.constant(0.),
                scaffold=training.Scaffold(local_init_op=custom_local_init_op),
                export_outputs={
                    'test': export_output.ClassificationOutput(scores)
                })

        est = estimator.Estimator(model_fn=_model_fn_scaffold)
        est.train(dummy_input_fn, steps=1)
        feature_spec = {
            'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
            'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)
        }
        serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
            feature_spec)

        # Perform the export.
        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export'))
        export_dir = est.export_savedmodel(export_dir_base,
                                           serving_input_receiver_fn)

        # Restore, to validate that the custom local_init_op runs.
        with ops.Graph().as_default() as graph:
            with session.Session(graph=graph) as sess:
                loader.load(sess, [tag_constants.SERVING], export_dir)
                my_int = graph.get_tensor_by_name('my_int:0')
                my_int_value = sess.run(my_int)
                self.assertEqual(12345, my_int_value)
Exemple #7
0
  def test_build_parsing_serving_input_receiver_fn(self):
    """Builds a parsing receiver fn and parses one serialized Example.

    Verifies the receiver exposes the expected feature and receiver-tensor
    keys, then feeds a serialized `Example` and checks the sparse indices and
    values produced for both features.
    """
    spec = {
        "int_feature": parsing_ops.VarLenFeature(dtypes.int64),
        "float_feature": parsing_ops.VarLenFeature(dtypes.float32),
    }
    receiver_fn = export.build_parsing_serving_input_receiver_fn(spec)

    # Text-format Example proto fed to the parser below.
    example_text = (
        "features: { "
        "  feature: { "
        "    key: 'int_feature' "
        "    value: { "
        "      int64_list: { "
        "        value: [ 21, 2, 5 ] "
        "      } "
        "    } "
        "  } "
        "  feature: { "
        "    key: 'float_feature' "
        "    value: { "
        "      float_list: { "
        "        value: [ 525.25 ] "
        "      } "
        "    } "
        "  } "
        "} ")

    with ops.Graph().as_default():
      receiver = receiver_fn()
      self.assertEqual({"int_feature", "float_feature"},
                       set(receiver.features.keys()))
      self.assertEqual({"examples"},
                       set(receiver.receiver_tensors.keys()))

      example = example_pb2.Example()
      text_format.Parse(example_text, example)
      input_name = receiver.receiver_tensors["examples"].name
      serialized_batch = [example.SerializeToString()]

      with self.test_session() as sess:
        parsed = sess.run(
            receiver.features, feed_dict={input_name: serialized_batch})
        int_result = parsed["int_feature"]
        float_result = parsed["float_feature"]
        self.assertAllEqual([[0, 0], [0, 1], [0, 2]], int_result.indices)
        self.assertAllEqual([21, 2, 5], int_result.values)
        self.assertAllEqual([[0, 0]], float_result.indices)
        self.assertAllEqual([525.25], float_result.values)
Exemple #8
0
  def test_scaffold_is_used_for_local_init(self):
    """Checks that a Scaffold's custom local_init_op runs on model load.

    The model_fn creates a local variable initialised to 1 and a custom local
    init op that assigns 12345 to it. After export and reload, the variable
    must read 12345.
    """
    # Imported locally so the cleanup below needs nothing from the file's
    # import block.
    import shutil

    tmpdir = tempfile.mkdtemp()
    # Remove the scratch directory whether or not the assertion passes.
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = variables.Variable(1, name='my_int',
                                  collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scores = constant_op.constant([3.])
      # The assignment must run after the standard local initialisers.
      with ops.control_dependencies(
          [variables.local_variables_initializer(),
           data_flow_ops.tables_initializer()]):
        assign_op = state_ops.assign(my_int, 12345)

      # local_init_op must be an Operation, not a Tensor.
      custom_local_init_op = control_flow_ops.group(assign_op)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export_output.ClassificationOutput(scores)})

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(export_dir_base,
                                       serving_input_receiver_fn)

    # Restore, to validate that the custom local_init_op runs.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        my_int = graph.get_tensor_by_name('my_int:0')
        my_int_value = sess.run(my_int)
        self.assertEqual(12345, my_int_value)
Exemple #9
0
  def test_export_savedmodel_proto_roundtrip(self):
    """Exports a trained estimator and reloads the resulting SavedModel.

    Checks the on-disk layout of the export directory, then loads it into a
    fresh graph and verifies the expected op names are present.
    """
    tmpdir = tempfile.mkdtemp()
    # Registered up front so the directory is removed even when an assertion
    # below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)

    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_fn)

    # Check that all the files are in the right places.
    self.assertTrue(gfile.Exists(export_dir_base))
    self.assertTrue(gfile.Exists(export_dir))
    expected_paths = (
        'saved_model.pb',
        'variables',
        'variables/variables.index',
        'variables/variables.data-00000-of-00001',
    )
    for relative_path in expected_paths:
      full_path = os.path.join(
          compat.as_bytes(export_dir), compat.as_bytes(relative_path))
      self.assertTrue(gfile.Exists(full_path),
                      msg='missing from export: %s' % relative_path)

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        # assertIn reports the searched list when the check fails.
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('weight', graph_ops)
    def test_export_savedmodel_extra_assets(self):
        """Checks that `assets_extra` files are copied into the export.

        Writes a file, maps it to a destination path through `assets_extra`,
        exports, and verifies the file exists under `assets.extra/` with the
        same content.
        """
        tmpdir = tempfile.mkdtemp()
        # Registered up front so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(gfile.DeleteRecursively, tmpdir)

        est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
        est.train(input_fn=dummy_input_fn, steps=1)
        feature_spec = {
            'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
            'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)
        }
        serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
            feature_spec)

        # Create a fake asset; the context manager closes the handle even if
        # the write raises.
        extra_file_name = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('my_extra_file'))
        with gfile.GFile(extra_file_name, mode='w') as extra_file:
            extra_file.write(_EXTRA_FILE_CONTENT)

        # Perform the export.
        assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export'))
        export_dir = est.export_savedmodel(export_dir_base,
                                           serving_input_receiver_fn,
                                           assets_extra=assets_extra)

        # Check that the asset files are in the right places.
        expected_extra_path = os.path.join(
            compat.as_bytes(export_dir),
            compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets.extra'))))
        self.assertTrue(gfile.Exists(expected_extra_path))
        # Read through a context manager so the handle is closed.
        with gfile.GFile(expected_extra_path) as exported_file:
            exported_content = exported_file.read()
        self.assertEqual(compat.as_bytes(_EXTRA_FILE_CONTENT),
                         compat.as_bytes(exported_content))
Exemple #11
0
  def test_export_savedmodel_extra_assets(self):
    """Checks that `assets_extra` files are copied into the export.

    Writes a file, maps it to a destination path through `assets_extra`,
    exports, and verifies the file exists under `assets.extra/` with the same
    content.
    """
    tmpdir = tempfile.mkdtemp()
    # Registered up front so the directory is removed even when an assertion
    # below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)

    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Create a fake asset; the context manager closes the handle even if the
    # write raises.
    extra_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
    with gfile.GFile(extra_file_name, mode='w') as extra_file:
      extra_file.write(_EXTRA_FILE_CONTENT)

    # Perform the export.
    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(export_dir_base,
                                       serving_input_receiver_fn,
                                       assets_extra=assets_extra)

    # Check that the asset files are in the right places.
    expected_extra_path = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
    self.assertTrue(gfile.Exists(os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
    self.assertTrue(gfile.Exists(expected_extra_path))
    # Read through a context manager so the handle is closed.
    with gfile.GFile(expected_extra_path) as exported_file:
      exported_content = exported_file.read()
    self.assertEqual(
        compat.as_bytes(_EXTRA_FILE_CONTENT),
        compat.as_bytes(exported_content))
Exemple #12
0
  def test_export_savedmodel_assets(self):
    """Checks that files in the ASSET_FILEPATHS collection are exported.

    Writes a vocab file, registers its path in the graph's asset collection
    via a wrapped serving input receiver fn, exports, and verifies the file
    is copied under `assets/` and that the reloaded graph is well-formed.
    """
    tmpdir = tempfile.mkdtemp()
    # Registered up front so the directory is removed even when an assertion
    # below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)

    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Create a fake asset; the context manager closes the handle even if the
    # write raises.
    vocab_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
    with gfile.GFile(vocab_file_name, mode='w') as vocab_file:
      vocab_file.write(_VOCAB_FILE_CONTENT)

    # hack in an op that uses the asset, in order to test asset export.
    # this is not actually valid, of course.
    def serving_input_receiver_with_asset_fn():
      features, receiver_tensor = serving_input_receiver_fn()
      filename = ops.convert_to_tensor(vocab_file_name,
                                       dtypes.string,
                                       name='asset_filepath')
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
      features['bogus_filename'] = filename

      return export.ServingInputReceiver(features, receiver_tensor)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_with_asset_fn)

    # Check that the asset files are in the right places.
    expected_vocab_file_name = os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets/my_vocab_file'))
    self.assertTrue(gfile.Exists(os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets'))))
    self.assertTrue(gfile.Exists(expected_vocab_file_name))
    # Read through a context manager so the handle is closed.
    with gfile.GFile(expected_vocab_file_name) as exported_vocab:
      exported_content = exported_vocab.read()
    self.assertEqual(
        compat.as_bytes(_VOCAB_FILE_CONTENT),
        compat.as_bytes(exported_content))

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [
            x.eval()
            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([vocab_file_name], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        # assertIn reports the searched list when the check fails.
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('asset_filepath', graph_ops)
        self.assertIn('weight', graph_ops)
Exemple #13
0
  def test_export_savedmodel_assets(self):
    """Checks that files in the ASSET_FILEPATHS collection are exported.

    Writes a vocab file, registers its path in the graph's asset collection
    via a wrapped serving input receiver fn, exports, and verifies the file
    is copied under `assets/` and that the reloaded graph is well-formed.
    """
    tmpdir = tempfile.mkdtemp()
    # Registered up front so the directory is removed even when an assertion
    # below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)

    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Create a fake asset; the context manager closes the handle even if the
    # write raises.
    vocab_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
    with gfile.GFile(vocab_file_name, mode='w') as vocab_file:
      vocab_file.write(_VOCAB_FILE_CONTENT)

    # hack in an op that uses the asset, in order to test asset export.
    # this is not actually valid, of course.
    def serving_input_receiver_with_asset_fn():
      features, receiver_tensor = serving_input_receiver_fn()
      filename = ops.convert_to_tensor(vocab_file_name,
                                       dtypes.string,
                                       name='asset_filepath')
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
      features['bogus_filename'] = filename

      return export.ServingInputReceiver(features, receiver_tensor)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_with_asset_fn)

    # Check that the asset files are in the right places.
    expected_vocab_file_name = os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets/my_vocab_file'))
    self.assertTrue(gfile.Exists(os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets'))))
    self.assertTrue(gfile.Exists(expected_vocab_file_name))
    # Read through a context manager so the handle is closed.
    with gfile.GFile(expected_vocab_file_name) as exported_vocab:
      exported_content = exported_vocab.read()
    self.assertEqual(
        compat.as_bytes(_VOCAB_FILE_CONTENT),
        compat.as_bytes(exported_content))

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [
            x.eval()
            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([vocab_file_name], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        # assertIn reports the searched list when the check fails.
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('asset_filepath', graph_ops)
        self.assertIn('weight', graph_ops)