def testCustomSaveable(self):
    """Round-trip a custom saveable table through a SavedModel.

    A ``CheckpointedOp`` table named "v1" is populated with ``"k1" -> 3.0``
    in a first session and exported via ``SavedModelBuilder`` under the
    "foo" tag. It is then loaded into a fresh graph, where the table is
    re-wrapped from its stored handle and its keys/values are checked.
    """

    def _two_cpu_session():
        # Every phase uses a brand-new graph and a session configured
        # with two CPU devices.
        return session.Session(
            graph=ops.Graph(),
            config=config_pb2.ConfigProto(device_count={"CPU": 2}))

    saved_dir = os.path.join(test.get_temp_dir(), "custom_saveable")
    model_builder = saved_model_builder.SavedModelBuilder(saved_dir)

    with _two_cpu_session() as sess:
        # CheckpointedOp is a key-value table that can be saved across
        # sessions; it registers itself in the SAVEABLE_OBJECTS collection.
        table = saver_test_utils.CheckpointedOp(name="v1")
        variables.global_variables_initializer().run()
        table.insert("k1", 3.0).run()
        # Keep the table handle in a collection so it can be recovered
        # after the model is restored.
        ops.add_to_collection("table_ref", table.table_ref)
        model_builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk.
    model_builder.save()

    with _two_cpu_session() as sess:
        loader.load(sess, ["foo"], saved_dir)
        # Re-wrap the restored handle in a CheckpointedOp so it can be queried.
        restored_ref = ops.get_collection("table_ref")[0]
        restored = saver_test_utils.CheckpointedOp(
            name="v1", table_ref=restored_ref)
        self.assertEqual(b"k1", restored.keys().eval())
        self.assertEqual(3.0, restored.values().eval())
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
    """Build an EstimatorSpec backed by a ``CheckpointedOp`` table named 'v2'.

    The train op inserts ``'k1' -> 30.0`` into the table. The prediction is
    the table lookup of ``'k1'`` (default ``0.0``). The loss is the constant
    ``1.`` and a single ``'test'`` PredictOutput is exported.

    Args:
      features: Ignored.
      labels: Ignored.
      mode: Passed straight through to the EstimatorSpec.
    """
    del features, labels  # Unused; the spec depends only on the table.
    saveable_table = saver_test_utils.CheckpointedOp(name='v2')
    insert_op = saveable_table.insert('k1', 30.0)
    looked_up = saveable_table.lookup('k1', 0.0)
    constant_loss = constant_op.constant(1.)
    outputs = {'test': export.PredictOutput({'prediction': looked_up})}
    return model_fn_lib.EstimatorSpec(
        mode,
        predictions=looked_up,
        loss=constant_loss,
        train_op=insert_op,
        export_outputs=outputs)
def model_fn_with_trackable(features, labels, mode):
    """Extend ``model_fn_diff_modes`` with a ``CheckpointedOp`` table named 'v2'.

    In TRAIN mode, an op inserting ``'key1' -> 2.2`` into the table is added
    to the TABLE_INITIALIZERS collection, and the base spec's predictions are
    kept. In any other mode, predictions become ``[503.0]`` plus the table's
    value for ``'key1'`` (default ``0.0``). Loss, train op and eval metric
    ops always come from the base spec.
    """
    base_spec = model_fn_diff_modes(features, labels, mode)
    table = saver_test_utils.CheckpointedOp(name='v2')

    if mode == ModeKeys.TRAIN:
        add_to_collection(GraphKeys.TABLE_INITIALIZERS,
                          table.insert('key1', 2.2))
        predictions = base_spec.predictions
    else:
        stored_value = table.lookup('key1', 0.0)
        predictions = tf.constant([503.0]) + stored_value

    return model_fn_lib.EstimatorSpec(
        mode,
        loss=base_spec.loss,
        train_op=base_spec.train_op,
        eval_metric_ops=base_spec.eval_metric_ops,
        predictions=predictions)