Example #1
0
    def testCustomSaveable(self):
        """Round-trips a CheckpointedOp table through a SavedModel.

        A table holding ("k1", 3.0) is built in one graph and exported under
        the "foo" tag. The model is then loaded into a fresh graph, and the
        table is rebuilt from the reference stored in the "table_ref"
        collection so its keys and values can be checked.
        """
        export_dir = os.path.join(test.get_temp_dir(), "custom_saveable")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        def _two_cpu_session():
            # Both the export and the load phase run in a brand-new graph on
            # a session configured with two CPU devices.
            return session.Session(
                graph=ops.Graph(),
                config=config_pb2.ConfigProto(device_count={"CPU": 2}))

        with _two_cpu_session() as sess:
            # CheckpointedOp is a key-value table whose contents survive
            # across sessions; it registers itself in the SAVEABLE_OBJECTS
            # collection.
            table = saver_test_utils.CheckpointedOp(name="v1")
            variables.global_variables_initializer().run()
            table.insert("k1", 3.0).run()
            # Keep a handle on the table so it can be found again after load.
            ops.add_to_collection("table_ref", table.table_ref)
            builder.add_meta_graph_and_variables(sess, ["foo"])

        # Write the SavedModel to export_dir.
        builder.save()

        with _two_cpu_session() as sess:
            loader.load(sess, ["foo"], export_dir)
            # Wrap the restored table reference in a fresh CheckpointedOp.
            restored_ref = ops.get_collection("table_ref")[0]
            restored = saver_test_utils.CheckpointedOp(
                name="v1", table_ref=restored_ref)
            self.assertEqual(b"k1", restored.keys().eval())
            self.assertEqual(3.0, restored.values().eval())
Example #2
0
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
  """Builds an EstimatorSpec whose only state is a CheckpointedOp named 'v2'.

  The train op inserts ('k1', 30.0) into the table. Predictions are the
  table's value for 'k1' (default 0.0) and are also exposed under the
  'test' export output. The loss is the constant 1.0.
  """
  del features, labels  # Unused.
  table = saver_test_utils.CheckpointedOp(name='v2')
  insert_op = table.insert('k1', 30.0)
  looked_up = table.lookup('k1', 0.0)
  loss = constant_op.constant(1.)
  export_outputs = {'test': export.PredictOutput({'prediction': looked_up})}
  return model_fn_lib.EstimatorSpec(
      mode,
      predictions=looked_up,
      loss=loss,
      train_op=insert_op,
      export_outputs=export_outputs)
Example #3
0
def model_fn_with_trackable(features, labels, mode):
    """Wraps `model_fn_diff_modes`, adding a CheckpointedOp table named 'v2'.

    In TRAIN mode, an insert of ('key1', 2.2) is registered in the
    TABLE_INITIALIZERS collection, and predictions come from the wrapped
    spec. In every other mode, predictions are 503.0 plus the table's value
    for 'key1' (default 0.0). Loss, train op and eval metrics always come
    from the wrapped spec.
    """
    base_spec = model_fn_diff_modes(features, labels, mode)
    table = saver_test_utils.CheckpointedOp(name='v2')

    if mode != ModeKeys.TRAIN:
        key1_value = table.lookup('key1', 0.0)
        predictions = tf.constant([503.0]) + key1_value
    else:
        # Register the insert as a table initializer.
        add_to_collection(GraphKeys.TABLE_INITIALIZERS,
                          table.insert('key1', 2.2))
        predictions = base_spec.predictions

    return model_fn_lib.EstimatorSpec(
        mode,
        loss=base_spec.loss,
        train_op=base_spec.train_op,
        eval_metric_ops=base_spec.eval_metric_ops,
        predictions=predictions)