def testLegacyInitOp(self):
    """Checks that a legacy_init_op runs after variables are restored.

    `v3` is created with `collections=[]`, so the global initializer does not
    cover it; after loading, its value must come from the legacy_init_op,
    which assigns `v1 + v2` to it.
    """
    model_dir = self._get_export_dir("test_legacy_init_op")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    with self.test_session(graph=ops.Graph()) as export_sess:
      # Two ordinary variables, also tracked in the "v" collection.
      first = variables.Variable(1, name="v1")
      ops.add_to_collection("v", first)
      second = variables.Variable(2, name="v2")
      ops.add_to_collection("v", second)

      # A third variable (initial value 42) kept out of the standard
      # variable collections, but still tracked in "v".
      third = variables.Variable(
          42, name="v3", trainable=False, collections=[])
      ops.add_to_collection("v", third)

      # The op the loader runs after restoring: v3 <- v1 + v2.
      sum_into_third = state_ops.assign(third, math_ops.add(first, second))
      init_group = control_flow_ops.group(
          sum_into_third, name="legacy_init_op")

      export_sess.run(variables.global_variables_initializer())
      saved_model.add_meta_graph_and_variables(
          export_sess, ["foo"], legacy_init_op=init_group)

    # Persist the SavedModel.
    saved_model.save()

    with self.test_session(graph=ops.Graph()) as import_sess:
      loader.load(import_sess, ["foo"], model_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      # 1 + 2, written by the legacy_init_op once the restore has happened.
      self.assertEqual(3, restored[2].eval())
  def testCustomMainOp(self):
    """Checks that a custom main_op runs after variables are restored.

    The custom op is sequenced after the standard `main_op.main_op()` and
    overwrites `v3` (initially 42) with `v1 + v2`.
    """
    model_dir = self._get_export_dir("test_main_op")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    with self.test_session(graph=ops.Graph()) as export_sess:
      # Create v1, v2, v3 in order and register each in the "v" collection.
      created = []
      for initial, var_name in ((1, "v1"), (2, "v2"), (42, "v3")):
        var = variables.Variable(initial, name=var_name)
        ops.add_to_collection("v", var)
        created.append(var)
      first, second, third = created

      # v3 <- v1 + v2, run only after the standard main_op.
      with ops.control_dependencies([main_op.main_op()]):
        summed = math_ops.add(first._ref(), second._ref())
        custom_op = control_flow_ops.group(state_ops.assign(third, summed))

      export_sess.run(custom_op)
      saved_model.add_meta_graph_and_variables(
          export_sess, ["foo"], main_op=custom_op)

    # Persist the SavedModel.
    saved_model.save()

    with self.test_session(graph=ops.Graph()) as import_sess:
      loader.load(import_sess, ["foo"], model_dir)
      restored = ops.get_collection("v")
      # The last value (1 + 2) is written by the main_op after the restore.
      for position, expected in enumerate((1, 2, 3)):
        self.assertEqual(expected, restored[position].eval())
  def testGraphWithoutVariables(self):
    """Saves and reloads two meta graphs that contain only constants."""
    model_dir = self._get_export_dir("test_graph_has_variables")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # Meta graph "foo": a single constant 5.0 and no variables.
    with self.test_session(graph=ops.Graph()) as sess:
      five_name = constant_op.constant(5.0).name
      saved_model.add_meta_graph_and_variables(sess, ["foo"])

    # Meta graph "bar": a single constant 6.0 and no variables.
    with self.test_session(graph=ops.Graph()) as sess:
      six_name = constant_op.constant(6.0).name
      saved_model.add_meta_graph(["bar"])

    # Persist the SavedModel.
    saved_model.save()

    # Reload each tag and multiply its constant by the other value: 30.0.
    for tag, saved_name, multiplier in (("foo", five_name, 6.0),
                                        ("bar", six_name, 5.0)):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], model_dir)
        restored = ops.get_default_graph().get_tensor_by_name(saved_name)
        product = restored * constant_op.constant(multiplier)
        self.assertEqual(30.0, sess.run(product))
  def testTrainOpGroup(self):
    """Checks that a grouped train op survives a save/load round trip."""
    model_dir = self._get_export_dir("test_train_op_group")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    with self.test_session(graph=ops.Graph()) as export_sess:
      first = variables.Variable(1, name="v1")
      ops.add_to_collection("v", first)
      second = variables.Variable(2, name="v2")
      ops.add_to_collection("v", second)

      export_sess.run(variables.global_variables_initializer())
      # `group()` with no arguments: a train op with nothing to do.
      noop_train = control_flow_ops.group()

      export_sess.run(noop_train)
      # TODO(karmel): remove explicit call when in the public method.
      saved_model._add_train_op(noop_train)
      saved_model.add_meta_graph_and_variables(export_sess, ["foo"])

    # Persist the SavedModel.
    saved_model.save()

    with self.test_session(graph=ops.Graph()) as import_sess:
      loader.load(import_sess, ["foo"], model_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      # A grouped train op comes back as an Operation.
      train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
      self.assertIsInstance(train_ops[0], ops.Operation)
  def testTrainOpAfterVariables(self):
    """Checks that a train op added late appears only in the later meta graph.

    "pre_foo" is written before the train op exists and must reload with an
    empty TRAIN_OP_KEY collection; "foo" must carry the op.
    """
    model_dir = self._get_export_dir("test_train_op_after_variables")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    with self.test_session(graph=ops.Graph()) as export_sess:
      first = variables.Variable(1, name="v1")
      ops.add_to_collection("v", first)
      second = variables.Variable(2, name="v2")
      ops.add_to_collection("v", second)

      export_sess.run(variables.global_variables_initializer())
      saved_model.add_meta_graph_and_variables(export_sess, ["pre_foo"])

      # Here the train op is the Tensor produced by assign_add.
      increment = state_ops.assign_add(first, second)
      export_sess.run(increment)
      # TODO(karmel): remove explicit call when in the public method.
      saved_model._add_train_op(increment)
      saved_model.add_meta_graph(["foo"])

    # Persist the SavedModel.
    saved_model.save()

    with self.test_session(graph=ops.Graph()) as import_sess:
      loader.load(import_sess, ["foo"], model_dir)
      train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
      self.assertIsInstance(train_ops[0], ops.Tensor)

    with self.test_session(graph=ops.Graph()) as import_sess:
      loader.load(import_sess, ["pre_foo"], model_dir)
      self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
  def testSaveAsText(self):
    """Round-trips a SavedModel written in text format.

    Only "foo" stores variable values; "bar" is added without weights, so
    both tags reload `v` with the value saved under "foo" (42).
    """
    model_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("astext"))
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # Build the same one-variable graph twice: "foo" (42) is saved with its
    # weights, "bar" (43) only contributes its meta graph.
    for initial, tag in ((42, "foo"), (43, "bar")):
      with self.test_session(graph=tf.Graph()) as sess:
        var = tf.Variable(initial, name="v")
        sess.run(tf.initialize_all_variables())
        self.assertEqual(initial, var.eval())
        if tag == "foo":
          saved_model.add_meta_graph_and_variables(sess, [tag])
        else:
          saved_model.add_meta_graph([tag])

    # Persist the SavedModel in text format.
    saved_model.save(as_text=True)

    # Both tags restore `v` from the weights saved with "foo".
    for tag in ("foo", "bar"):
      with self.test_session(graph=tf.Graph()) as sess:
        loader.load(sess, [tag], model_dir)
        restored = tf.get_collection(tf.GraphKeys.VARIABLES)
        self.assertEqual(42, restored[0].eval())
  def testCustomSaveable(self):
    """Round-trips a custom saveable object (a checkpointed key/value table)."""
    model_dir = self._get_export_dir("custom_saveable")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    def _two_cpu_session():
      # A fresh graph and a session configured with two CPU devices.
      return session.Session(
          graph=ops.Graph(),
          config=config_pb2.ConfigProto(device_count={"CPU": 2}))

    with _two_cpu_session() as sess:
      # CheckpointedOp is a key-value table that can be saved across
      # sessions; it registers itself in the SAVEABLE_OBJECTS collection.
      table = saver_test_utils.CheckpointedOp(name="v1")
      variables.global_variables_initializer().run()
      table.insert("k1", 3.0).run()
      # Keep a handle to the table so it can be rebuilt after loading.
      ops.add_to_collection("table_ref", table.table_ref)
      saved_model.add_meta_graph_and_variables(sess, ["foo"])

    # Persist the SavedModel.
    saved_model.save()

    with _two_cpu_session() as sess:
      loader.load(sess, ["foo"], model_dir)
      # Wrap the restored reference in a new CheckpointedOp to query it.
      restored_ref = ops.get_collection("table_ref")[0]
      table = saver_test_utils.CheckpointedOp(
          name="v1", table_ref=restored_ref)
      self.assertEqual(b"k1", table.keys().eval())
      self.assertEqual(3.0, table.values().eval())
  def testSaveAsText(self):
    """Round-trips a SavedModel written in text format.

    Only "foo" stores variable values; "bar" is added without weights, so
    both tags reload `v` with the value saved under "foo" (42).
    """
    model_dir = self._get_export_dir("test_astext")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # "foo": v = 42, saved together with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      saved_model.add_meta_graph_and_variables(sess, ["foo"])

    # "bar": the same variable with v = 43, but no weights are stored.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      saved_model.add_meta_graph(["bar"])

    # Persist the SavedModel in text format.
    saved_model.save(as_text=True)

    # Both tags restore `v` from the weights saved with "foo".
    for tag in ("foo", "bar"):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], model_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())
 def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
   """A wrapper to export to SavedModel, and convert it to other formats.

   Exports through `base_strategy`, reloads the exported SavedModel to read the
   serialized tree ensemble, optionally hands it to `convert_fn`, and writes a
   `feature_importances` file under `<result_dir>/assets.extra`.

   Relies on `base_strategy`, `convert_fn`, `sorted_feature_names`,
   `dense_floats`, `sparse_float_indices` and `sparse_int_indices` from the
   enclosing scope.

   Args:
     estimator: The estimator to export.
     export_dir: Base export directory, passed through to
       `base_strategy.export`.
     checkpoint_path: Optional checkpoint, passed through to the base export.
     eval_result: Optional evaluation result, passed through to the base
       export and to `convert_fn`.

   Returns:
     The directory returned by `base_strategy.export`.
   """
   # Local import keeps this closure self-contained.
   from tensorflow.python.util import compat  # pylint: disable=g-import-not-at-top

   result_dir = base_strategy.export(estimator, export_dir,
                                     checkpoint_path,
                                     eval_result)
   with ops.Graph().as_default() as graph:
     with tf_session.Session(graph=graph) as sess:
       saved_model_loader.load(
           sess, [tag_constants.SERVING], result_dir)
       # Note: This is GTFlow internal API and might change.
       ensemble_model = graph.get_operation_by_name(
           "ensemble_model/TreeEnsembleSerialize")
       _, dfec_str = sess.run(ensemble_model.outputs)
       dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
       dtec.ParseFromString(dfec_str)
       # Export the result in the same folder as the saved model.
       if convert_fn:
         convert_fn(dtec, sorted_feature_names,
                    len(dense_floats),
                    len(sparse_float_indices),
                    len(sparse_int_indices), result_dir, eval_result)
       feature_importances = _get_feature_importances(
           dtec, sorted_feature_names,
           len(dense_floats),
           len(sparse_float_indices), len(sparse_int_indices))
       # Most important features first.
       sorted_by_importance = sorted(
           feature_importances.items(), key=lambda x: -x[1])
       # NOTE(review): `base_strategy.export` may return a bytes path, and
       # os.path.join cannot mix bytes and str components on Python 3
       # (TypeError). Normalize every component to bytes before joining.
       assets_dir = os.path.join(
           compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
       gfile.MakeDirs(assets_dir)
       importances_path = os.path.join(
           assets_dir, compat.as_bytes("feature_importances"))
       with gfile.GFile(importances_path, "w") as f:
         f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
   return result_dir
  def testSignatureDefs(self):
    """Checks that each meta graph keeps its own SignatureDef map.

    "foo" is saved with one signature. "bar" is saved with two, one of which
    reuses the key "foo_key" with a different method name. Reloading each tag
    must yield exactly the map that was added for it.
    """
    model_dir = self._get_export_dir("test_signature_defs")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    def _empty_signature(method):
      # A SignatureDef with no inputs or outputs, told apart by method name.
      return signature_def_utils.build_signature_def(dict(), dict(), method)

    # "foo": one variable (saved with weights) and a one-entry signature map.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      foo_map = {"foo_key": _empty_signature("foo")}
      saved_model.add_meta_graph_and_variables(
          sess, ["foo"], signature_def_map=foo_map)

    # "bar": the same variable (no weights saved) and a two-entry map whose
    # "foo_key" entry differs from the one stored with "foo".
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      bar_map = {
          "bar_key": _empty_signature("bar"),
          "foo_key": _empty_signature("foo_new"),
      }
      saved_model.add_meta_graph(["bar"], signature_def_map=bar_map)

    # Persist the SavedModel.
    saved_model.save()

    # "foo" must come back with its single "foo_key" entry.
    with self.test_session(graph=ops.Graph()) as sess:
      foo_meta_graph = loader.load(sess, ["foo"], model_dir)
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
      foo_sigs = foo_meta_graph.signature_def
      self.assertEqual(len(foo_sigs), 1)
      self.assertEqual("foo", foo_sigs["foo_key"].method_name)

    # "bar" must come back with both entries, including its own "foo_key".
    with self.test_session(graph=ops.Graph()) as sess:
      bar_meta_graph = loader.load(sess, ["bar"], model_dir)
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
      bar_sigs = bar_meta_graph.signature_def
      self.assertEqual(len(bar_sigs), 2)
      self.assertEqual("bar", bar_sigs["bar_key"].method_name)
      self.assertEqual("foo_new", bar_sigs["foo_key"].method_name)
  def testVariables(self):
    """Checks restoring meta graphs whose variables overlap the saved ones.

    "foo" saves v1 and v2. "bar" declares only v2 (a subset) and restores
    it; "baz" declares only v3 (disjoint) and must fail to load.
    """
    model_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("variables"))
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # "foo": v1 = 1 and v2 = 2, saved together with their values.
    with self.test_session(graph=tf.Graph()) as sess:
      first = tf.Variable(1, name="v1")
      second = tf.Variable(2, name="v2")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(1, first.eval())
      self.assertEqual(2, second.eval())
      saved_model.add_meta_graph_and_variables(sess, ["foo"])

    # "bar" (v2 = 3, subset of the saved variables) and "baz" (v3 = 4,
    # disjoint from them) are added without saving weights.
    for var_name, initial, tag in (("v2", 3, "bar"), ("v3", 4, "baz")):
      with self.test_session(graph=tf.Graph()) as sess:
        var = tf.Variable(initial, name=var_name)
        sess.run(tf.initialize_all_variables())
        self.assertEqual(initial, var.eval())
        saved_model.add_meta_graph([tag])

    # Persist the SavedModel.
    saved_model.save()

    # "foo" restores both of its saved variables.
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, ["foo"], model_dir)
      restored = tf.get_collection(tf.GraphKeys.VARIABLES)
      self.assertEqual(len(restored), 2)
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())

    # "bar" declares only v2, which is restored with the checkpointed value.
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, ["bar"], model_dir)
      restored = tf.get_collection(tf.GraphKeys.VARIABLES)
      self.assertEqual(len(restored), 1)
      self.assertEqual(2, restored[0].eval())

    # "baz" declares only v3, which was never saved, so loading must fail.
    with self.test_session(graph=tf.Graph()) as sess:
      self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
                        model_dir)
  def testClearExtraneousSavers(self):
    """Checks that the builder drops user-added Savers from the exported graph.

    The exporting graph holds two user Savers (ops under "save/" and
    "save_1/"). After export and reload, only the builder's own Saver (ops
    under "save_2/") should remain.
    """
    export_dir = os.path.join(test.get_temp_dir(),
                              "test_clear_extraneous_savers")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Create a variable and a Saver.
    with ops.Graph().as_default() as graph:
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        self._init_and_validate_variable(sess, "v", 42)

        # Add two Savers, which should be removed in
        # add_meta_graph_and_variables() in favor of the builder's own Saver.
        saver1 = tf_saver.Saver()
        graph.add_to_collection(ops.GraphKeys.SAVERS, saver1)
        saver2 = tf_saver.Saver()
        graph.add_to_collection(ops.GraphKeys.SAVERS, saver2)

        # Confirm the SAVERS collection holds two Savers.
        savers = graph.get_collection(ops.GraphKeys.SAVERS)
        self.assertEqual(2, len(savers))

        # Confirm there are two Save and two Restore ops.
        save_op_names = set([x.name for x in graph.get_operations()
                             if x.type == "SaveV2"])
        self.assertSetEqual(set(["save/SaveV2", "save_1/SaveV2"]),
                            save_op_names)

        restore_op_names = set([x.name for x in graph.get_operations()
                                if x.type == "RestoreV2"])
        self.assertSetEqual(set(["save/RestoreV2", "save_1/RestoreV2"]),
                            restore_op_names)

        # The SavedModel builder adds its own Saver, for a total of three.
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING], clear_devices=True)

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        self.assertEqual(
            42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

        # Confirm that the reloaded graph has only one Saver.
        savers = ops.get_collection(ops.GraphKeys.SAVERS)
        self.assertEqual(1, len(savers))

        # The reloaded graph should have exactly one Save and one Restore op.
        # They carry the "save_2" scope because the builder's Saver was the
        # third one created in the exporting graph.
        save_op_names = set([x.name for x in graph.get_operations()
                             if x.type == "SaveV2"])
        self.assertSetEqual(set(["save_2/SaveV2"]), save_op_names)
        restore_op_names = set([x.name for x in graph.get_operations()
                                if x.type == "RestoreV2"])
        self.assertSetEqual(set(["save_2/RestoreV2"]), restore_op_names)
  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Loads a default-stripped SavedModel under modified "Complex" defaults.

    The model is exported with `strip_default_attrs=True`, so the "complex"
    node carries no "T"/"Tout" attrs. The test then edits the registered
    "Complex" OpDef in place:
      1. It removes all attr defaults; loading must fail with ValueError.
      2. It changes the "Tout" default to complex128; loading must fail with
         InvalidArgumentError, because no kernel supports those attrs.
    The registry entry is restored in a `finally` block so the edits cannot
    leak into other tests, even when an assertion fails.
    """
    if ops._USE_C_API: return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # The OpDef below is edited in place, which is what makes the loader see
    # the modified defaults. Keep a pristine copy so the global registry can
    # be restored however this test exits.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)
    try:
      # Remove the defaults for all attrs ("T", "Tout") from the "Complex"
      # OpDef.
      for attr_def in complex_op_def.attr:
        attr_def.ClearField("default_value")

      # Loading the SavedModel via the loader must fail because the
      # SavedModel does not have any attr values for the "Complex" node and
      # the current op registry does not have any default values for the
      # "Complex" op.
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            ValueError,
            "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
          loader.load(sess, ["foo"], export_dir)

      # Change the default for attr "Tout" (complex64 -> complex128).
      complex_op_def.CopyFrom(original_complex_op_def)
      for attr_def in complex_op_def.attr:
        if attr_def.name == "Tout":
          attr_def.default_value.type = types_pb2.DT_COMPLEX128

      # Loading the SavedModel via the loader must set "Tout" attr_value for
      # the "Complex" node according to the latest defaults (complex128).
      # This is expected to fail the model import as there is no OpKernel
      # registered to handle attrs "T" (float32) and "Tout" (complex128).
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            ".*No OpKernel was registered to support Op \'Complex\' with these "
            "attrs..*"):
          loader.load(sess, ["foo"], export_dir)
    finally:
      # Undo the in-place registry edits so later tests see the real defaults.
      complex_op_def.CopyFrom(original_complex_op_def)
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
  """Converts a SavedModel to a frozen graph.

  Args:
    saved_model_dir: SavedModel directory to convert.
    input_arrays: List of input tensors to freeze graph with. Uses input arrays
      from SignatureDef when none are provided.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
    output_arrays: List of output tensors to freeze graph with. Uses output
      arrays from SignatureDef when none are provided.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.

  Returns:
    frozen_graph_def: Frozen GraphDef, built around the returned out_tensors.
    in_tensors: List of input tensors for the graph.
    out_tensors: List of output tensors for the graph.

  Raises:
    ValueError:
      SavedModel doesn't contain a MetaGraphDef identified by tag_set.
      signature_key is not in the MetaGraphDef.
      assets/ directory is in the MetaGraphDef.
      input_shapes does not match the length of input_arrays.
      input_arrays or output_arrays are not valid.
  """
  # Read SignatureDef.
  meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
  signature_def = _get_signature_def(meta_graph, signature_key)
  inputs, outputs = _get_inputs_outputs(signature_def)

  # Check SavedModel for assets directory.
  collection_def = meta_graph.collection_def
  if constants.ASSETS_KEY in collection_def:
    raise ValueError("SavedModels with assets/ directory are not supported.")

  graph = ops.Graph()
  with session.Session(graph=graph) as sess:
    loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

    # Gets input and output tensors.
    # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
    in_tensors = _get_tensors(graph, inputs, input_arrays)
    out_tensors = _get_tensors(graph, outputs, output_arrays)
    set_tensor_shapes(in_tensors, input_shapes)

    # Freeze around the tensors actually returned to the caller. Taking the
    # node names from the SignatureDef outputs would ignore a user-supplied
    # `output_arrays`, and convert_variables_to_constants may then prune the
    # requested output nodes from the frozen graph.
    output_names = [tensor.op.name for tensor in out_tensors]
    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_names)

    return frozen_graph_def, in_tensors, out_tensors
  def __init__(self,
               export_dir,
               signature_def_key=None,
               signature_def=None,
               input_names=None,
               output_names=None,
               tags=None,
               graph=None):
    """Initialize a predictor that runs a `SavedModel`.

    Args:
      export_dir: a path to a directory containing a `SavedModel`.
      signature_def_key: Optional string specifying the signature to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
        `signature_def_key` and `signature_def` should be specified.
      signature_def: A `SignatureDef` proto specifying the inputs and outputs
        for prediction. Only one of `signature_def_key` and `signature_def`
        should be specified.
      input_names: A dictionary mapping strings to the names of the `Tensor`s
        in the `SavedModel` that represent the input. The keys can be any
        string of the user's choosing.
      output_names: A dictionary mapping strings to the names of the
        `Tensor`s in the `SavedModel` that represent the output. The keys can
        be any string of the user's choosing. Only used when `input_names`
        is given; otherwise both maps are taken from the `SignatureDef`.
      tags: Optional. A comma-separated string of tags used to load the
        `SavedModel` and to retrieve the correct `SignatureDef`. Defaults to
        `DEFAULT_TAGS`.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done. A new `Graph` is created when omitted.
    Raises:
      ValueError: If more than one of signature_def_key OR signature_def OR
        (input_names AND output_names) is specified.
    """
    _check_signature_arguments(
        signature_def_key, signature_def, input_names, output_names)
    tags = tags or DEFAULT_TAGS
    self._graph = graph or ops.Graph()

    with self._graph.as_default():
      # Created under `as_default()`, so the session runs `self._graph`.
      self._session = session.Session()
      try:
        loader.load(self._session, tags.split(','), export_dir)
      except Exception:
        # Release the session instead of leaking it when loading fails.
        self._session.close()
        raise

    if input_names is None:
      if signature_def is None:
        signature_def = _get_signature_def(signature_def_key, export_dir, tags)
      input_names = {k: v.name for k, v in signature_def.inputs.items()}
      output_names = {k: v.name for k, v in signature_def.outputs.items()}

    self._feed_tensors = {k: self._graph.get_tensor_by_name(v)
                          for k, v in input_names.items()}
    self._fetch_tensors = {k: self._graph.get_tensor_by_name(v)
                           for k, v in output_names.items()}
  def testVariables(self):
    """Checks restoring meta graphs whose variables overlap the saved ones.

    "foo" saves v1 and v2. "bar" declares only v2 (a subset) and restores
    it; "baz" declares only v3 (disjoint) and must fail to load.
    """
    model_dir = self._get_export_dir("test_variables")
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # "foo": v1 = 1 and v2 = 2, saved together with their values.
    with self.test_session(graph=ops.Graph()) as sess:
      for var_name, initial in (("v1", 1), ("v2", 2)):
        self._init_and_validate_variable(sess, var_name, initial)
      saved_model.add_meta_graph_and_variables(sess, ["foo"])

    # "bar" (v2 = 3, subset of the saved variables) and "baz" (v3 = 4,
    # disjoint from them) are added without saving weights.
    for var_name, initial, tag in (("v2", 3, "bar"), ("v3", 4, "baz")):
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, var_name, initial)
        saved_model.add_meta_graph([tag])

    # Persist the SavedModel.
    saved_model.save()

    # "foo" restores both of its saved variables.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], model_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(len(restored), 2)
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())

    # "bar" declares only v2, which is restored with the checkpointed value.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["bar"], model_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(len(restored), 1)
      self.assertEqual(2, restored[0].eval())

    # "baz" declares only v3, which was never saved, so loading must fail.
    with self.test_session(graph=ops.Graph()) as sess:
      self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
                        model_dir)
  def _TestStaticOp(self, use_function_backup):
    """Converts a SavedModel with TF-TRT and runs the result two ways.

    Both the converted GraphDef (imported into a fresh graph) and the
    converted SavedModel (loaded with the SERVING tag) are run at batch size
    1 (engine expected to run) and batch size 2 (engine expected not to run).
    Returns immediately when TensorRT is not enabled.

    Args:
      use_function_backup: forwarded to `_ConvertGraph` and `_TestRun`.
    """
    if not is_tensorrt_enabled():
      return

    scratch = self.get_temp_dir()
    src_dir = os.path.join(scratch, "in_dir3")
    dst_dir = os.path.join(scratch, "out_dir3")
    self._WriteInputSavedModel(src_dir)
    converted_graph_def = self._ConvertGraph(
        input_saved_model_dir=src_dir,
        output_saved_model_dir=dst_dir,
        maximum_cached_engines=2,  # This is noop, added just for testing.
        use_function_backup=use_function_backup)

    def _run_small_and_large_batches(sess):
      # Batch size 1 uses the default engine embedded in the graph. Batch
      # size 2 exceeds max_batch_size, so execution should fall back to the
      # TF function instead of the engine.
      for batch_size, engine_expected in ((1, True), (2, False)):
        self._TestRun(
            sess,
            batch_size,
            use_function_backup=use_function_backup,
            expect_engine_is_run=engine_expected)

    # Exercise the converted GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(converted_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        _run_small_and_large_batches(sess)

    # Exercise the converted SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], dst_dir)
        _run_small_and_large_batches(sess)
  def testCollections(self):
    """Checks that each meta graph keeps its own collection definitions.

    "foo" registers `v` in "foo_vars" and saves weights (42). "bar" registers
    the same variable in "bar_vars" without saving weights. Each tag must
    reload only its own collection, holding the value saved under "foo".
    """
    model_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("collections"))
    saved_model = saved_model_builder.SavedModelBuilder(model_dir)

    # Build the one-variable graph twice, registering `v` in a different
    # collection each time. Only the first meta graph writes weights.
    for initial, tag, collection_key in ((42, "foo", "foo_vars"),
                                         (43, "bar", "bar_vars")):
      with self.test_session(graph=tf.Graph()) as sess:
        var = tf.Variable(initial, name="v")
        tf.add_to_collection(collection_key, var)
        sess.run(tf.initialize_all_variables())
        self.assertEqual(initial, var.eval())
        if tag == "foo":
          saved_model.add_meta_graph_and_variables(sess, [tag])
        else:
          saved_model.add_meta_graph([tag])

    # Persist the SavedModel.
    saved_model.save()

    # Each tag sees one element (value 42, from the "foo" weights) in its own
    # collection and nothing in the other tag's collection.
    for tag, own_key, other_key in (("foo", "foo_vars", "bar_vars"),
                                    ("bar", "bar_vars", "foo_vars")):
      with self.test_session(graph=tf.Graph()) as sess:
        loader.load(sess, [tag], model_dir)
        own_vars = tf.get_collection(own_key)
        self.assertEqual(len(own_vars), 1)
        self.assertEqual(42, own_vars[0].eval())

        self.assertEqual(len(tf.get_collection(other_key)), 0)
Exemple #19
0
  def process(self, inputs):
    """Runs the SavedModel on one element and yields its id and predictions.

    The TensorFlow session is created lazily, once per worker, because it is
    not pickleable and so cannot be built in the DoFn constructor.

    Args:
      inputs: mapping for a single element. Keys that match signature input
        aliases are fed; `inputs[self.id_key]` is echoed in the output.

    Yields:
      A dict with the element's 'id' and 'predictions', the first entry of the
      `self.meta_predictions` output converted with `tolist()`.
    """
    if not self.session:
      self.graph = ops.Graph()
      with self.graph.as_default():
        self.session = tf.Session()
        metagraph_def = loader.load(
            self.session, {self.meta_tag}, self.model_dir)
      signature = metagraph_def.signature_def[self.meta_signature]

      lookup = self.graph.get_tensor_by_name
      # Signature aliases -> input tensors.
      self.feed_tensors = {
          alias: lookup(info.name) for alias, info in signature.inputs.items()
      }
      # Signature aliases -> output/prediction tensors.
      self.fetch_tensors = {
          alias: lookup(info.name) for alias, info in signature.outputs.items()
      }

    # Feed a batch of one: wrap every available input value in a list.
    feed_dict = {}
    for alias, tensor in self.feed_tensors.items():
      if alias in inputs:
        feed_dict[tensor] = [inputs[alias]]
    outputs = self.session.run(self.fetch_tensors, feed_dict)

    yield {
        'id': inputs[self.id_key],
        'predictions': outputs[self.meta_predictions][0].tolist()
    }
def local_predict(args):
  """Runs prediction locally over TFRecord files with a serving SavedModel.

  The SavedModel graph must hold JSON alias->tensor-name maps in its
  'inputs' and 'outputs' collections. Every record of each input file is fed
  to the tensor aliased 'examples_bytes'. One JSON line per result row is
  printed, keyed by output alias.

  Args:
    args: Parsed arguments providing `model_dir` (SavedModel directory),
      `input` (iterable of TFRecord file paths) and `dry_run` (if true, only
      print the feed dict and fetched tensor names instead of running).
  """
  # The context manager guarantees the session is closed when done.
  with session.Session() as sess:
    loader.load(sess, [tag_constants.SERVING], args.model_dir)

    # get the mappings between aliases and tensor names
    # for both inputs and outputs
    input_alias_map = json.loads(sess.graph.get_collection('inputs')[0])
    output_alias_map = json.loads(sess.graph.get_collection('outputs')[0])
    aliases, tensor_names = zip(*output_alias_map.items())

    for input_file in args.input:
      feed_dict = collections.defaultdict(list)
      for line in tf_record.tf_record_iterator(input_file):
        feed_dict[input_alias_map['examples_bytes']].append(line)

      if args.dry_run:
        print('Feed data dict %s to graph and fetch %s' % (
            feed_dict, tensor_names))
      else:
        result = sess.run(fetches=tensor_names, feed_dict=feed_dict)
        # Transpose per-output batches into per-example rows.
        for row in zip(*result):
          print(json.dumps(
              {name: (value.tolist() if getattr(value, 'tolist', None) else value)
               for name, value in zip(aliases, row)}))
  def testAssets(self):
    """Only files in the asset collection are copied into the SavedModel.

    A file written to the temp dir but never registered as an asset must not
    show up in the export's assets directory.
    """
    export_dir = self._get_export_dir("test_assets")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)

      # A stray file that is deliberately left out of the asset collection.
      tmp_dir = compat.as_bytes(test.get_temp_dir())
      stray_path = os.path.join(tmp_dir, compat.as_bytes("ignored.txt"))
      file_io.write_string_to_file(stray_path, "will be ignored")

      assets = self._build_asset_collection(
          "hello42.txt", "foo bar baz", "asset_file_tensor")
      builder.add_meta_graph_and_variables(
          sess, ["foo"], assets_collection=assets)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      meta_graph = loader.load(sess, ["foo"], export_dir)
      self._validate_asset_collection(
          export_dir, meta_graph.collection_def, "hello42.txt", "foo bar baz",
          "asset_file_tensor:0")
      assets_dir = os.path.join(
          compat.as_bytes(export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY))
      exported_stray_path = os.path.join(
          assets_dir, compat.as_bytes("ignored.txt"))
      self.assertFalse(file_io.file_exists(exported_stray_path))
  def testDuplicateAssets(self):
    """An asset name shared by several meta graphs is stored only once.

    The "foo" and "bar" graphs both register an asset called `foo.txt`, with
    different contents. Only the first copy (from "foo") is written, and both
    graphs resolve to it after loading.
    """
    export_dir = self._get_export_dir("test_duplicate_assets")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def make_assets(sess, content):
      # Same variable and same asset name in every graph; only content varies.
      self._init_and_validate_variable(sess, "v", 42)
      return self._build_asset_collection("foo.txt", content,
                                          "asset_file_tensor")

    # Graph tagged "foo", saved together with its variables.
    with self.test_session(graph=ops.Graph()) as sess:
      builder.add_meta_graph_and_variables(
          sess, ["foo"], assets_collection=make_assets(sess, "content_foo"))

    # Graph tagged "bar", reusing the asset name with `bar` specific content.
    with self.test_session(graph=ops.Graph()) as sess:
      builder.add_meta_graph(
          ["bar"], assets_collection=make_assets(sess, "content_bar"))

    # Save the SavedModel to disk.
    builder.save()

    # Each graph must see the content written for the first ("foo") graph.
    for tag in ("foo", "bar"):
      with self.test_session(graph=ops.Graph()) as sess:
        meta_graph = loader.load(sess, [tag], export_dir)
        self._validate_asset_collection(export_dir, meta_graph.collection_def,
                                        "foo.txt", "content_foo",
                                        "asset_file_tensor:0")
Exemple #23
0
  def test_scaffold_is_used_for_local_init(self):
    """A Scaffold's custom `local_init_op` is exported and run on load.

    The model_fn's `local_init_op` assigns 12345 to the local variable
    `my_int`. After exporting and loading the SavedModel, `my_int` must hold
    that value.
    """
    tmpdir = tempfile.mkdtemp()
    # Remove the temp dir even if an assertion below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = variables.Variable(1, name='my_int',
                                  collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scores = constant_op.constant([3.])
      with ops.control_dependencies(
          [variables.local_variables_initializer(),
           data_flow_ops.tables_initializer()]):
        assign_op = state_ops.assign(my_int, 12345)

      # local_init_op must be an Operation, not a Tensor.
      custom_local_init_op = control_flow_ops.group(assign_op)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export_output.ClassificationOutput(scores)})

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(export_dir_base,
                                       serving_input_receiver_fn)

    # Restore, to validate that the custom local_init_op runs.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        my_int = graph.get_tensor_by_name('my_int:0')
        my_int_value = sess.run(my_int)
        self.assertEqual(12345, my_int_value)
  def test_export_savedmodel_with_saveables_proto_roundtrip(self):
    """A model with saveables exports a SavedModel that loads back.

    Checks the expected export files exist. It then checks that the loaded
    graph contains the parsing ops and the `save/LookupTableImport` op.
    """
    tmpdir = tempfile.mkdtemp()
    # Remove the temp dir even if an assertion below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)
    est = estimator.Estimator(
        model_fn=_model_fn_with_saveables_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_fn)

    # Check that all the files are in the right places.
    self.assertTrue(gfile.Exists(export_dir_base))
    self.assertTrue(gfile.Exists(export_dir))
    for relative_path in ('saved_model.pb',
                          'variables',
                          'variables/variables.index',
                          'variables/variables.data-00000-of-00001'):
      self.assertTrue(
          gfile.Exists(os.path.join(compat.as_bytes(export_dir),
                                    compat.as_bytes(relative_path))),
          msg=relative_path)

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('save/LookupTableImport', graph_ops)
  def testOp(self):
    """An unsaved variable can be recomputed by a stored init op after load.

    `v1` and `v2` live on separate CPU devices and are saved. `v3` is created
    with `collections=[]`, and the op in the "init_op" collection recomputes
    it as `v1 + v2` in the restored graph.
    """
    export_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("op"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with tf.Session(
        graph=tf.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v1 = tf.Variable(1, name="v1")
      with sess.graph.device("/cpu:1"):
        v2 = tf.Variable(2, name="v2")

      # v3 is an unsaved variable derived from v1 and v2.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      v3 = tf.Variable(1, name="v3", trainable=False, collections=[])
      assign_v3 = tf.assign(v3, tf.add(v1, v2))
      init_op = tf.group(assign_v3, name="init_op")

      tf.add_to_collection("v", v1)
      tf.add_to_collection("v", v2)
      tf.add_to_collection("v", v3)
      tf.add_to_collection("init_op", init_op)

      # Replaces the deprecated `tf.initialize_all_variables()`.
      sess.run(tf.global_variables_initializer())
      self.assertEqual(1, tf.get_collection("v")[0].eval())
      self.assertEqual(2, tf.get_collection("v")[1].eval())

      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk.
    builder.save()

    with tf.Session(
        graph=tf.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      loader.load(sess, ["foo"], export_dir)

      # Validate variables, run the init op and verify result.
      self.assertEqual(1, tf.get_collection("v")[0].eval())
      self.assertEqual(2, tf.get_collection("v")[1].eval())
      tf.get_collection("init_op")[0].run()
      self.assertEqual(3, tf.get_collection("v")[2].eval())
  def testVerifySessionGraphUsage(self):
    """The loader restores into the supplied session's own graph.

    The load call is made while that session is NOT the default one. The
    restored variable is then read inside the session's context.
    """
    export_dir = self._get_export_dir("test_verify_session_graph_usage")
    model_builder = saved_model_builder.SavedModelBuilder(export_dir)
    training_tags = [tag_constants.TRAINING]

    with self.test_session(graph=ops.Graph()) as save_sess:
      self._init_and_validate_variable(save_sess, "v", 42)
      model_builder.add_meta_graph_and_variables(save_sess, training_tags)

    # Save the SavedModel to disk.
    model_builder.save()

    # Deliberately load outside of the session's context manager.
    restore_sess = session.Session(graph=ops.Graph())
    loader.load(restore_sess, [tag_constants.TRAINING], export_dir)

    # Read the variable back within the scope of the session and its graph.
    with restore_sess:
      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(42, global_vars[0].eval())
    def testClearDevices(self):
        """A variable saved with `clear_devices=True` restores with no device.

        The variable is created under "/cpu:0" in a two-CPU session. It is
        saved with device information cleared, then loaded into a fresh
        graph.
        """
        export_dir = os.path.join(tf.test.get_temp_dir(), "test_clear_devices")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Specify a device and save a variable.
        tf.reset_default_graph()
        with tf.Session(
                target="",
                config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
            with sess.graph.device("/cpu:0"):
                self._init_and_validate_variable(sess, "v", 42)
                builder.add_meta_graph_and_variables(
                    sess, [tag_constants.TRAINING], clear_devices=True)

        # Save the SavedModel to disk.
        builder.save()

        # Restore the graph with a single predefined tag whose variables were saved
        # without any device information.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [tag_constants.TRAINING], export_dir)
            # GLOBAL_VARIABLES replaces the deprecated GraphKeys.VARIABLES
            # alias for the same collection.
            self.assertEqual(
                42, tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)[0].eval())
Exemple #28
0
  def testCreateInferenceGraph_DynamicOp(self):
    """Dynamic-op TensorRT conversion with at most two cached engines.

    Converts a SavedModel with `is_dynamic_op=True` and
    `maximum_cached_engines=2`. Both the returned GraphDef and the written
    SavedModel are then run with batch sizes 1, 2 and 3. Skipped (returns
    early) when TensorRT is not enabled.
    """
    if not trt_convert.is_tensorrt_enabled():
      return
    trt_convert.enable_test_value()

    def run_all_batch_sizes(sess):
      # Batch sizes 1 and 2 each create and cache a new engine. At batch
      # size 3 the cache is full, so an old engine is evicted for a new one.
      for batch_size in (1, 2, 3):
        self._TestRun(sess, batch_size, True)

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = trt_convert.create_inference_graph(
        None,
        None,
        max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        session_config=self._GetConfigProto())

    # Exercise the returned GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.test_session(config=self._GetConfigProto()) as sess:
        run_all_batch_sizes(sess)

    # Exercise the SavedModel written to disk.
    with ops.Graph().as_default():
      with self.test_session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        run_all_batch_sizes(sess)
  def testOp(self):
    """An unsaved variable can be recomputed by a stored init op after load.

    `v1` and `v2` are placed on separate CPU devices and saved. `v3` is
    created with `collections=[]`. Running the op kept in the "init_op"
    collection of the restored graph sets it to `v1 + v2`.
    """
    export_dir = self._get_export_dir("test_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with session.Session(
        graph=ops.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        first = variables.Variable(1, name="v1")
      with sess.graph.device("/cpu:1"):
        second = variables.Variable(2, name="v2")

      # v3 is never saved; the init op rebuilds it from v1 and v2 so the
      # test can run an init op on a restored graph.
      derived = variables.Variable(1, name="v3", trainable=False,
                                   collections=[])
      init_op = control_flow_ops.group(
          state_ops.assign(derived, math_ops.add(first, second)),
          name="init_op")

      for var in (first, second, derived):
        ops.add_to_collection("v", var)
      ops.add_to_collection("init_op", init_op)

      sess.run(variables.global_variables_initializer())
      collected = ops.get_collection("v")
      self.assertEqual(1, collected[0].eval())
      self.assertEqual(2, collected[1].eval())

      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk.
    builder.save()

    with session.Session(
        graph=ops.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      loader.load(sess, ["foo"], export_dir)

      # The saved variables come back as-is; v3 needs the init op.
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      ops.get_collection("init_op")[0].run()
      self.assertEqual(3, restored[2].eval())
  def test_export_savedmodel_with_resource(self):
    """An estimator using resources exports a loadable SavedModel.

    Checks the expected export files exist. It then checks that the loaded
    graph contains the `LookupTableModel` op but not
    `LookupTableTrainingState`.
    """
    tmpdir = tempfile.mkdtemp()
    # Remove the temp dir even if an assertion below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)
    est, serving_input_fn = _build_estimator_for_resource_export_test()

    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(export_dir_base, serving_input_fn)

    self.assertTrue(gfile.Exists(export_dir_base))
    self.assertTrue(gfile.Exists(export_dir))
    for relative_path in ('saved_model.pb',
                          'variables',
                          'variables/variables.index',
                          'variables/variables.data-00000-of-00001'):
      self.assertTrue(
          gfile.Exists(
              os.path.join(
                  compat.as_bytes(export_dir),
                  compat.as_bytes(relative_path))),
          msg=relative_path)

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session_lib.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('LookupTableModel', graph_ops)
        self.assertNotIn('LookupTableTrainingState', graph_ops)
  def testInconsistentConsumerDefaultAttrs(self):
    """Loading fails when a node's attrs disagree with its OpDef.

    A text-format SavedModel with a `TestAttr` node is edited by hand:
    - first to drop the node's `T` attr, so loading must raise ValueError;
    - then to set `T` to `DT_DOUBLE`, so loading must raise
      InvalidArgumentError.
    """
    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with a single variable and a test op with a defaultless
    # float32 attr, "test_attr".
    with session.Session(graph=ops.Graph()) as sess:
      variables.Variable(1.0, dtype=dtypes.float64, name="var")
      test_ops.test_attr(T=dtypes.float32, name="test_attr")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Rewrite the SavedModel to remove the T attr from "test_attr".
    saved_model_file = os.path.join(
        export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
    with open(saved_model_file) as f:
      original_saved_model = f.read()

    no_attr_saved_model = original_saved_model.replace("""
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }""", "")
    with open(saved_model_file, "w") as f:
      f.write(no_attr_saved_model)

    # Loading the SavedModel via the loader must fail because the SavedModel
    # does not have any attr values for the "TestAttr" node, and there is no
    # default specified in the TestAttr OpDef.
    if ops._USE_C_API:
      error_message = "NodeDef missing attr 'T' from Op<name=TestAttr"
    else:
      error_message = ("Expected one attr with name .*T(out)?.* in name: "
                       "\"test_attr\".*")
    # The context manager closes each throwaway session once the check ends.
    with session.Session(graph=ops.Graph()) as sess:
      with self.assertRaisesRegexp(ValueError, error_message):
        loader.load(sess, ["foo"], export_dir)

    # Rewrite the SavedModel to change the type of the T attr in "test_attr"
    bad_type_saved_model = original_saved_model.replace("""
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }""", """
      attr {
        key: "T"
        value {
          type: DT_DOUBLE
        }
      }""")
    with open(saved_model_file, "w") as f:
      f.write(bad_type_saved_model)

    # Loading the SavedModel via the loader must fail because there is no
    # OpKernel registered to handle T = double.
    with session.Session(graph=ops.Graph()) as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          ".*No OpKernel was registered to support Op \'TestAttr\' with these "
          "attrs..*"):
        loader.load(sess, ["foo"], export_dir)
Exemple #32
0
def convert(saved_model_dir,
            output_tflite=None,
            output_arrays=None,
            tag_set=None,
            signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
            batch_size=1):
  """Convert a savedmodel to tflite flatbuffer.

  Args:
    saved_model_dir: Saved model directory to convert.
    output_tflite: File path to write result flatbuffer.
    output_arrays: Comma separated string of output tensor names (whitespace
      around each name is ignored). The default value is None, which means
      conversion keeps all output tensors. This is also used to filter
      tensors that are from Op currently not supported in tflite, e.g.,
      Argmax.
    tag_set: This is the set of tags to get meta_graph_def in saved_model.
      Defaults to {SERVING}.
    signature_key: This is the signature key to extract inputs, outputs.
    batch_size: If input tensor shape has None at first dimension,
      e.g. (None,224,224,3), replace None with batch_size. Scalar (rank-0)
      inputs are left unchanged.

  Returns:
    The converted data. For example if tflite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    ValueError: If tag_set does not indicate any meta_graph_def in saved_model,
      or signature_key is not in relevant meta_graph_def,
      or input shape has None beyond 1st dimension, e.g., (1,None, None, 3),
      or given output_arrays are not valid causing empty outputs.
  """
  if tag_set is None:
    tag_set = {tag_constants.SERVING}

  meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
  signature_def = get_signature_def(meta_graph, signature_key)
  inputs, outputs = get_inputs_outputs(signature_def)

  graph = ops.Graph()
  with session.Session(graph=graph) as sess:

    loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

    in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs]

    # Users can use output_arrays to filter output tensors for conversion.
    # If output_arrays is None, we keep all output tensors. In future, we may
    # use tflite supported Op list and check whether op is custom Op to
    # automatically filter output arrays.
    # TODO(zhixianyan): Use tflite supported Op list to filter outputs.
    if output_arrays is not None:
      output_arrays = [name.strip() for name in output_arrays.split(",")]
      out_tensors = [
          graph.get_tensor_by_name(output)
          for output in outputs
          if output.split(":")[0] in output_arrays
      ]
    else:
      out_tensors = [graph.get_tensor_by_name(output) for output in outputs]

    output_names = [node.split(":")[0] for node in outputs]

    if not out_tensors:
      raise ValueError(
          "No valid output tensors for '{}', possible values are '{}'".format(
              output_arrays, output_names))

    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_names)

    # Toco requires fully defined tensor shape, for input tensor with None in
    # their shape, e.g., (None, 224, 224, 3), we need to replace first None with
    # a given batch size. For shape with more None, e.g. (None, None, None, 3),
    # still be able to replace and convert, but require further investigation.
    # TODO(zhixianyan): Add supports for input tensor with more None in shape.
    for input_name, in_tensor in zip(inputs, in_tensors):
      shape = in_tensor.get_shape().as_list()
      # Scalars (empty shape) have no batch dimension; indexing [0] would
      # raise IndexError.
      if shape and shape[0] is None:
        shape[0] = batch_size
      if None in shape[1:]:
        raise ValueError(
            "Only support None shape at 1st dim as batch_size. But tensor "
            "'{}' 's shape '{}' has None at other dimension. ".format(
                input_name, shape))
      in_tensor.set_shape(shape)

    result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors)

    if output_tflite is not None:
      with gfile.Open(output_tflite, "wb") as f:
        f.write(result)
      logging.info("Successfully converted to: %s", output_tflite)

    return result
  def test_export_savedmodel_assets(self):
    """Files in the ASSET_FILEPATHS collection are copied into the export.

    A fake vocab file is registered as an asset by the serving input
    receiver. It must be copied to `assets/` with its content intact. After
    loading, the asset filepath collection must again point at the vocab
    file.
    """
    tmpdir = tempfile.mkdtemp()
    # Remove the temp dir even if an assertion below fails.
    self.addCleanup(gfile.DeleteRecursively, tmpdir)
    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Create a fake asset.
    vocab_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
    with gfile.GFile(vocab_file_name, mode='w') as vocab_file:
      vocab_file.write(_VOCAB_FILE_CONTENT)

    # hack in an op that uses the asset, in order to test asset export.
    # this is not actually valid, of course.
    def serving_input_receiver_with_asset_fn():
      features, receiver_tensor = serving_input_receiver_fn()
      filename = ops.convert_to_tensor(vocab_file_name,
                                       dtypes.string,
                                       name='asset_filepath')
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
      features['bogus_filename'] = filename

      return export.ServingInputReceiver(features, receiver_tensor)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_with_asset_fn)

    # Check that the asset files are in the right places.
    expected_vocab_file_name = os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets/my_vocab_file'))
    self.assertTrue(gfile.Exists(os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets'))))
    self.assertTrue(gfile.Exists(expected_vocab_file_name))
    with gfile.GFile(expected_vocab_file_name) as exported_vocab:
      exported_content = exported_vocab.read()
    self.assertEqual(
        compat.as_bytes(_VOCAB_FILE_CONTENT),
        compat.as_bytes(exported_content))

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [
            x.eval()
            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([vocab_file_name], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('asset_filepath', graph_ops)
        self.assertIn('weight', graph_ops)
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Tag(s) of the MetaGraphDef to load (optional): either a
                      list of tag strings or a single comma separated string.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef` (also written to `output_graph` when that is set).

  Raises:
    ValueError: If the checkpoint does not exist (and no SavedModel dir is
      given), if `output_node_names` is empty, if a checkpoint-only model
      contains partition variables, or if the model has no variables.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  def _split_names(names):
    # Comma separated names with all spaces removed.
    return names.replace(" ", "").split(",")

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError(
        "Input checkpoint '{}' doesn't exist!".format(input_checkpoint))

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      elif isinstance(saved_model_tags, str):
        # A raw string would otherwise be treated as a set of characters.
        saved_model_tags = [tag for tag in _split_names(saved_model_tags)
                            if tag]
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        if any(key in name for name in all_partition_variable_names):
          has_partition_var = True
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        if _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
        # Bare `raise` keeps the original traceback.
        raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))

    variable_names_whitelist = (
        _split_names(variable_names_whitelist)
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        _split_names(variable_names_blacklist)
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        _split_names(output_node_names),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
    def testTags(self):
        """Checks that meta graphs saved under different tag sets load by tag.

        Only the first meta graph (TRAINING) is added together with its
        variables; the later ones are added with `add_meta_graph` only. Every
        successful load below is therefore expected to observe the saved value
        42, whatever value the variable held when the later meta graphs were
        added. Tag sets that do not fully match a saved meta graph must make
        the loader raise RuntimeError.
        """
        export_dir = os.path.join(test.get_temp_dir(), "test_tags")
        model_builder = saved_model_builder.SavedModelBuilder(export_dir)

        # First meta graph: variable initialised to 42 and saved together with
        # its weights under a single predefined tag (TRAINING).
        with self.test_session(graph=ops.Graph()) as sess:
            self._init_and_validate_variable(sess, "v", 42)
            model_builder.add_meta_graph_and_variables(
                sess, [tag_constants.TRAINING])

        # Further meta graphs re-initialise the variable (43, then 44) but are
        # added without weights: one under the predefined SERVING tag, one
        # under multiple custom tags.
        graph_only_specs = (
            (43, [tag_constants.SERVING]),
            (44, ["foo", "bar"]),
        )
        for init_value, graph_tags in graph_only_specs:
            with self.test_session(graph=ops.Graph()) as sess:
                self._init_and_validate_variable(sess, "v", init_value)
                model_builder.add_meta_graph(graph_tags)

        # Persist the SavedModel to disk.
        model_builder.save()

        # Loading by the tag saved with weights, by a tag saved without
        # weights, or by a multi-tag set (the duplicate "foo" exercises set
        # semantics) must all restore the weights written with the first
        # meta graph.
        loadable_tag_sets = (
            [tag_constants.TRAINING],
            [tag_constants.SERVING],
            ["foo", "bar", "foo"],
        )
        for load_tags in loadable_tag_sets:
            with self.test_session(graph=ops.Graph()) as sess:
                loader.load(sess, load_tags, export_dir)
                restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
                self.assertEqual(42, restored.eval())

        # A non-existent tag, and a set in which only some tags match (tag
        # matching uses "all" semantics), must both be rejected.
        for bad_tags in (["INVALID"], ["foo", "baz"]):
            with self.test_session(graph=ops.Graph()) as sess:
                with self.assertRaises(RuntimeError):
                    loader.load(sess, bad_tags, export_dir)
Exemple #36
0
def _run_model(iterator, args):
  """Run single-node inferencing on a checkpoint/saved_model using input tensors obtained from a Spark partition iterator and returning output tensors.

  Args:
    iterator: iterator over the rows of one Spark partition. It is batched via
      ``yield_batch(iterator, args.batch_size, len(input_tensor_names))``.
    args: argument namespace. Fields read here: ``input_mapping`` (unpacked
      as ``{column: tensor}``), ``output_mapping`` (unpacked as
      ``{tensor: column}``), ``signature_def_key``, ``export_dir``,
      ``tag_set`` (comma-separated), ``model_dir`` and ``batch_size``.
      When ``signature_def_key`` is set, the "tensor" entries are keys into
      the signature's inputs/outputs rather than raw tensor names.

  Returns:
    A list of tuples, one per input row, holding one value per mapped output.
    Outputs appear in sorted order of the ``output_mapping`` keys.

  Raises:
    AssertionError: if required arguments are missing or inconsistent, no
      checkpoint is found in ``model_dir``, or an output batch size differs
      from the input batch size.
    ValueError: if neither ``model_dir`` nor ``export_dir`` is given.
  """
  single_node_env(args)

  logging.info("===== input_mapping: {}".format(args.input_mapping))
  logging.info("===== output_mapping: {}".format(args.output_mapping))
  # Both mappings are sorted by key so the feed/fetch order is deterministic.
  input_tensor_names = [tensor for _col, tensor in sorted(args.input_mapping.items())]
  output_tensor_names = [tensor for tensor, _col in sorted(args.output_mapping.items())]

  # if using a signature_def_key, get input/output tensor info from the requested signature
  if args.signature_def_key:
    assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
    logging.info("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}".format(args.tag_set, args.export_dir))
    meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
    signature = signature_def_utils.get_signature_def_by_key(meta_graph_def, args.signature_def_key)
    logging.info("signature: {}".format(signature))
    inputs_tensor_info = signature.inputs
    logging.info("inputs_tensor_info: {0}".format(inputs_tensor_info))
    outputs_tensor_info = signature.outputs
    logging.info("outputs_tensor_info: {0}".format(outputs_tensor_info))

  result = []
  with tf.Session(graph=ops_lib.Graph()) as sess:
    if args.export_dir:
      assert args.tag_set, "Inferencing from a saved_model requires --tag_set"
      # load graph from a saved_model
      logging.info("===== restoring from saved_model: {}".format(args.export_dir))
      loader.load(sess, args.tag_set.split(','), args.export_dir)
    elif args.model_dir:
      # load graph from a checkpoint
      ckpt = tf.train.latest_checkpoint(args.model_dir)
      assert ckpt, "Invalid model checkpoint path: {}".format(args.model_dir)
      logging.info("===== restoring from checkpoint: {}".format(ckpt + ".meta"))
      saver = tf.train.import_meta_graph(ckpt + ".meta", clear_devices=True)
      saver.restore(sess, ckpt)
    else:
      # ValueError subclasses Exception, so existing `except Exception` callers still match.
      raise ValueError("Inferencing requires either --model_dir or --export_dir argument")

    # get list of input/output tensors (by name)
    if args.signature_def_key:
      input_tensors = [inputs_tensor_info[t].name for t in input_tensor_names]
      # Fetch every mapped output (previously only the first was fetched, so
      # result tuples lost columns when output_mapping had several entries).
      output_tensors = [outputs_tensor_info[t].name for t in output_tensor_names]
    else:
      input_tensors = [t + ':0' for t in input_tensor_names]
      output_tensors = [t + ':0' for t in output_tensor_names]

    logging.info("input_tensors: {0}".format(input_tensors))
    logging.info("output_tensors: {0}".format(output_tensors))

    # feed data in batches and return output tensors
    for tensors in yield_batch(iterator, args.batch_size, len(input_tensor_names)):
      # assumes yield_batch returns one per-column sequence per input tensor,
      # in the same order as input_tensor_names -- TODO confirm
      inputs_feed_dict = dict(zip(input_tensors, tensors))

      outputs = sess.run(output_tensors, feed_dict=inputs_feed_dict)
      lengths = [len(output) for output in outputs]
      input_size = len(tensors[0])
      assert all(l == input_size for l in lengths), "Output array sizes {} must match input size: {}".format(lengths, input_size)
      python_outputs = [output.tolist() for output in outputs]  # convert from numpy to standard python types
      result.extend(zip(*python_outputs))                        # convert to an array of tuples of "output columns"
  return result
def main():

    args = parse_args()

    np.random.seed(args.random_seed)
    tf.set_random_seed(args.random_seed)

    use_crf = args.use_crf
    print('Use CRF: ', use_crf)
    BATCH_SIZE = args.batch_size

    print('Setting up...')
    train_filenames = pd.read_csv(TRAIN_FILENAME,
                                  dtype=object,
                                  keep_default_na=False,
                                  na_values=[]).as_matrix()
    train0_filenames = pd.read_csv(TRAIN0_FILENAME,
                                   dtype=object,
                                   keep_default_na=False,
                                   na_values=[]).as_matrix()
    val_filenames = pd.read_csv(VAL_FILENAME,
                                dtype=object,
                                keep_default_na=False,
                                na_values=[]).as_matrix()
    test_filenames = pd.read_csv(TEST_FILENAME,
                                 dtype=object,
                                 keep_default_na=False,
                                 na_values=[]).as_matrix()
    subsets = ['train', 'val', 'test']
    image_mean = IMAGE_MEAN
    image_std = IMAGE_STD
    image_mean_tensor = tf.constant(image_mean, dtype=tf.float32)
    image_std_tensor = tf.constant(image_std, dtype=tf.float32)

    print('num_train: ', NUM_TRAIN_SAMPLES)
    STEPS_ONE_EPOCH = int(NUM_TRAIN_SAMPLES / BATCH_SIZE)
    max_steps = STEPS_ONE_EPOCH * 21
    DECAY_STEPS = STEPS_ONE_EPOCH * 5
    print('steps one epoch: ', STEPS_ONE_EPOCH)
    print('decay steps: ', DECAY_STEPS)

    train_filenames_tfrecords = [
        f[1] + '/data2.tfrecords' for f in train_filenames
    ]
    val_filenames_tfrecords = [
        f[1] + '/data2.tfrecords' for f in val_filenames
    ]
    train_images, train_labels, train_case_names = dataset_input_from_tfrecords(
        train_filenames_tfrecords,
        batch_size=BATCH_SIZE,
        num_epochs=5000,
        shuffle=True)
    val_images, val_labels, val_case_names = dataset_input_from_tfrecords(
        val_filenames_tfrecords,
        batch_size=BATCH_SIZE,
        num_epochs=5000,
        shuffle=False)

    phase_train = tf.placeholder(tf.bool, name='phase_train')
    global_step = tf.Variable(0, trainable=False)
    image_node = tf.placeholder(tf.float32,
                                shape=[None, DEPTH, HEIGHT, WIDTH, 2])
    label_node = tf.placeholder(tf.int32,
                                shape=[None, DEPTH, HEIGHT, WIDTH, 2])

    if args.norm_type == 'nonorm':
        image_node_new = image_node
    elif args.norm_type == 'globalnorm_mean':
        image_node_new = image_node - image_mean_tensor
    elif args.norm_type == 'globalnorm_meanstd':
        image_node_new = image_node - image_mean_tensor
        image_node_new /= image_std_tensor
    elif args.norm_type == 'instancenorm_mean':
        image_node_new = tf.map_fn(
            lambda frame: frame - tf.reduce_mean(
                frame, axis=[0, 1, 2], keep_dims=True), image_node)
    elif args.norm_type == 'instancenorm_meanstd':
        batch_mean, batch_var = tf.nn.moments(image_node,
                                              axes=[1, 2, 3],
                                              keep_dims=True)
        image_node_new = (image_node - batch_mean) / tf.sqrt(batch_var + 1e-6)

    if args.net_type == 'myunet3d_bn_crf':
        from myunet3d_basic import myunet3d_bn_crf
        net_output_ops = myunet3d_bn_crf(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)
    elif args.net_type == 'myunet3d_crf':
        from myunet3d_basic import myunet3d_crf
        net_output_ops = myunet3d_crf(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.glorot_uniform_initializer(
            ),  #tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)
    elif args.net_type == 'myunet3d_isbi2018':
        from myunet3d_basic import myunet3d_isbi2018
        net_output_ops = myunet3d_isbi2018(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)
    elif args.net_type == 'myunet3d_isbi2018_2':
        from myunet3d_basic import myunet3d_isbi2018_2
        net_output_ops = myunet3d_isbi2018_2(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)
    elif args.net_type == 'myunet3d_improved':
        from myunet3d_improved import myunet3d_improved
        net_output_ops = myunet3d_improved(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)
    elif args.net_type == 'myunet3d_improved_usingresize':
        from myunet3d_improved_usingresize import myunet3d_improved_usingresize
        net_output_ops = myunet3d_improved_usingresize(
            name='ct' if args.feat_index == 0 else 'pt',
            inputs=image_node_new[..., args.feat_index][..., tf.newaxis],
            num_classes=NUM_CLASSES,
            phase_train=phase_train,
            use_bias=True,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            bias_initializer=tf.constant_initializer(value=0.1),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            bias_regularizer=tf.contrib.layers.l2_regularizer(1e-4),
            use_crf=args.use_crf,
            args=args)

    modes = ['ct', 'pt']
    image_node_shaped = tf.reshape(image_node[..., args.feat_index],
                                   [-1, DEPTH, HEIGHT * WIDTH, 1])
    label_node_shaped = tf.cast(
        tf.reshape(label_node[..., args.feat_index],
                   [-1, DEPTH, HEIGHT * WIDTH, 1]), tf.float32)
    for bi in range(BATCH_SIZE):
        tf.summary.image('input_{}_{}'.format(modes[args.feat_index], str(bi)),
                         image_node_shaped[bi][tf.newaxis, ...], 1)
        tf.summary.image('label_{}_{}'.format(modes[args.feat_index], str(bi)),
                         label_node_shaped[bi][tf.newaxis, ...], 1)

    prob_flat = tf.unstack(net_output_ops['y_prob'], num=2,
                           axis=4)  # [-1, 32, 32, 32]
    prob_flat = [
        tf.reshape(v, [-1, DEPTH, HEIGHT * WIDTH]) for v in prob_flat
    ]  #[-1,32,32*32]
    prob_flat = tf.concat(prob_flat, axis=1)  #[-1, 32*16, 32*32]
    prob_flat = tf.cast(
        tf.reshape(prob_flat, [-1, 2 * DEPTH, HEIGHT * WIDTH, 1]), tf.float32)
    pred_flat = tf.cast(
        tf.reshape(net_output_ops['y_'], [-1, DEPTH, HEIGHT * WIDTH, 1]),
        tf.float32)

    for bi in range(BATCH_SIZE):
        tf.summary.image('prob_{}_{}'.format(modes[args.feat_index], str(bi)),
                         prob_flat[bi][tf.newaxis, ...], 1)
        tf.summary.image('pred_{}_{}'.format(modes[args.feat_index], str(bi)),
                         pred_flat[bi][tf.newaxis, ...], 1)

    pred_op = net_output_ops['y_']
    prob_op = net_output_ops['y_prob']
    print('pred_op shape: ', pred_op.shape)
    print('prob_op shape: ', prob_op.shape)
    dice_op = dice_tf(labels=label_node[..., args.feat_index],
                      logits=prob_op[..., 1])

    if args.use_crf:
        prob_crf_op = net_output_ops['y_prob_crf']
        prob_crf_flat = tf.unstack(prob_crf_op, num=2,
                                   axis=4)  # [-1, 32, 32, 32]
        prob_crf_flat = [
            tf.reshape(v, [-1, DEPTH, HEIGHT * WIDTH]) for v in prob_crf_flat
        ]  #[-1,32,32*32]
        prob_crf_flat = tf.concat(prob_crf_flat, axis=1)  #[-1, 32*16, 32*32]
        prob_crf_flat = tf.cast(
            tf.reshape(prob_crf_flat, [-1, 2 * DEPTH, HEIGHT * WIDTH, 1]),
            tf.float32)

        pred_crf_flat = tf.cast(
            tf.reshape(net_output_ops['y_crf'],
                       [-1, DEPTH, HEIGHT * WIDTH, 1]), tf.float32)

        for bi in range(BATCH_SIZE):
            tf.summary.image(
                'prob_crf_{}_{}'.format(modes[args.feat_index], str(bi)),
                prob_crf_flat[bi][tf.newaxis, ...], 1)
            tf.summary.image(
                'pred_crf_{}_{}'.format(modes[args.feat_index], str(bi)),
                pred_crf_flat[bi][tf.newaxis, ...], 1)

        pred_crf_op = net_output_ops['y_crf']
        dice_crf_op = dice_tf(labels=label_node[..., args.feat_index],
                              logits=prob_crf_op[..., 1])

    # 2. set up a loss function    # regularization loss
    reg_constant = tf.constant(args.reg_coef, dtype=tf.float32)
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if len(reg_losses) != 0:
        loss_reg_op = tf.add_n(reg_losses)
    else:
        loss_reg_op = tf.constant(0, dtype=tf.float32)
    tf.summary.scalar('loss_reg_{}'.format(modes[args.feat_index]),
                      loss_reg_op)
    # 2. set up a loss function
    if args.loss_type == 'ce':
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=net_output_ops['logits'],
            labels=label_node[..., args.feat_index])
        loss_op = tf.reduce_mean(ce)
    elif args.loss_type == 'dice':
        loss_op = dice_loss(labels=label_node[..., args.feat_index],
                            logits=net_output_ops['logits'])
    elif args.loss_type == 'focal':
        loss_op = focal_loss(labels=label_node[..., args.feat_index],
                             logits=net_output_ops['logits'])
    elif args.loss_type == 'bce':
        loss_op = binary_cross_entropy(labels=label_node[..., args.feat_index],
                                       logits=net_output_ops['logits'])
    tf.summary.scalar('loss_{}'.format(modes[args.feat_index]), loss_op)
    total_loss_op = loss_op + tf.multiply(reg_constant, loss_reg_op)
    tf.summary.scalar('total_loss_{}'.format(modes[args.feat_index]),
                      total_loss_op)

    if args.use_crf:
        if args.loss_type == 'ce':
            ce_crf = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=net_output_ops['logits_crf'],
                labels=label_node[..., args.feat_index])
            loss_crf_op = tf.reduce_mean(ce_crf)
        elif args.loss_type == 'dice':
            loss_crf_op = dice_loss(labels=label_node[..., args.feat_index],
                                    logits=net_output_ops['logits_crf'])
        elif args.loss_type == 'focal':
            loss_crf_op = focal_loss(labels=label_node[..., args.feat_index],
                                     logits=net_output_ops['logits_crf'])
        elif args.loss_type == 'bce':
            loss_crf_op = binary_cross_entropy(
                labels=label_node[..., args.feat_index],
                logits=net_output_ops['logits_crf'])
        tf.summary.scalar('loss_crf_{}'.format(modes[args.feat_index]),
                          loss_crf_op)
        total_loss_crf_op = loss_crf_op + tf.multiply(reg_constant,
                                                      loss_reg_op)
        tf.summary.scalar('total_loss_{}_crf'.format(modes[args.feat_index]),
                          total_loss_crf_op)

    if args.use_crf:
        crf_sw_op = tf.get_default_graph().get_tensor_by_name(
            modes[args.feat_index] + '/crf/spatial_weights:0')
        crf_bw_op = tf.get_default_graph().get_tensor_by_name(
            modes[args.feat_index] + '/crf/bilateral_weights:0')
        crf_cm_op = tf.get_default_graph().get_tensor_by_name(
            modes[args.feat_index] + '/crf/compatibility_matrix:0')

        noncrf_vars = []
        crf_vars = []
        for v in tf.trainable_variables():
            if 'crf' in v.name:
                crf_vars.append(v)
            else:
                noncrf_vars.append(v)

    if args.lr_policy == 'constant':
        lr_op = args.base_lr
    elif args.lr_policy == 'piecewise':
        #### piecewise learning rate
        boundaries = [i * DECAY_STEPS for i in [1, 2, 3, 4, 5]]
        staged_lr = [
            args.base_lr * x for x in [1, 0.5, 0.25, 0.125, 0.0625, 0.03125]
        ]

        lr_op = tf.train.piecewise_constant(global_step, boundaries, staged_lr)
    elif args.lr_policy == 'expdecay':
        learning_rate_decay_factor = 0.5
        decay_steps = 1000

        # Decay the learning rate exponentially based on the number of steps.
        lr_op = tf.train.exponential_decay(args.base_lr,
                                           global_step,
                                           decay_steps,
                                           learning_rate_decay_factor,
                                           staircase=True)

    tf.summary.scalar('learning_rate_{}'.format(modes[args.feat_index]), lr_op)

    if args.opt_type == 'gd':
        optimiser = tf.train.GradientDescentOptimizer(learning_rate=lr_op)
    elif args.opt_type == 'gd_diff_lr':
        optimiser = tf.train.GradientDescentOptimizer(learning_rate=lr_op)
        optimiser_crf = tf.train.GradientDescentOptimizer(learning_rate=lr_op *
                                                          args.crf_lr_scale)
    elif args.opt_type == 'adam_diff_lr':
        optimiser = tf.train.AdamOptimizer(learning_rate=lr_op)
        optimiser_crf = tf.train.AdamOptimizer(learning_rate=lr_op *
                                               args.crf_lr_scale)
    elif args.opt_type == 'adam_gd_diff_lr':
        optimiser = tf.train.AdamOptimizer(learning_rate=lr_op)
        optimiser_crf = tf.train.GradientDescentOptimizer(learning_rate=lr_op *
                                                          args.crf_lr_scale)
    elif args.opt_type == 'momentum':
        optimiser = tf.train.MomentumOptimizer(learning_rate=lr_op,
                                               momentum=0.9)
    elif args.opt_type == 'adam':
        optimiser = tf.train.AdamOptimizer(learning_rate=lr_op)
    elif args.opt_type == 'adadelta':
        optimiser = tf.train.AdadeltaOptimizer(learning_rate=lr_op)
    elif args.opt_type == 'adagrad':
        optimiser = tf.train.AdagradOptimizer(learning_rate=lr_op)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if 'diff_lr' in args.opt_type:
        grads = tf.gradients(total_loss_crf_op, noncrf_vars + crf_vars)
        grads1 = grads[:len(noncrf_vars)]
        grads2 = grads[len(noncrf_vars):]
        with tf.control_dependencies(update_ops):
            train_op1 = optimiser.apply_gradients(zip(grads1, noncrf_vars))
        train_op2 = optimiser_crf.apply_gradients(zip(grads2, crf_vars))
        train_op = tf.group(train_op1, train_op2)
    else:
        with tf.control_dependencies(update_ops):
            if args.use_crf:
                train_op = optimiser.minimize(total_loss_crf_op,
                                              global_step=global_step)
            else:
                train_op = optimiser.minimize(total_loss_op,
                                              global_step=global_step)

    log_postfix = '{}_{}x{}x{}_f{}_{}_opt{}_lr{}{}_b{}_loss{}_crf{}_reg{}_rs{}'.format(
        args.net_type, HEIGHT, WIDTH, DEPTH, modes[args.feat_index],
        args.norm_type, args.opt_type, args.lr_policy, args.base_lr,
        args.batch_size, args.loss_type, args.use_crf, args.reg_coef,
        args.random_seed)
    if args.use_crf:
        log_postfix = '{}/{}_{}_{}_{}_{}_{}_lr_scale_{}'.format(
            log_postfix, args.sw_weight, args.bw_weight, args.cm_weight,
            args.theta_alpha, args.theta_beta, args.theta_gamma,
            args.crf_lr_scale)
    log_dir = '{}_{}'.format(args.log_dir, log_postfix)

    if args.restore_ckpt == "" and args.action == 'train':
        if tf.gfile.Exists(log_dir):
            tf.gfile.DeleteRecursively(log_dir)
        tf.gfile.MakeDirs(log_dir)

    # set up log file
    current_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f")
    log_filename = '{}/log_{}.txt'.format(log_dir, current_time)
    cmd_filename = '{}/log_curve_{}.sh'.format(log_dir, current_time)
    log_file_handle = open(log_filename, 'w')
    with open(cmd_filename, 'w') as cmd_file_handle:
        cmd_file_handle.write(
            'python plot_learning_curves_single.py {} {} {}\n'.format(
                log_dir, log_postfix, log_filename))

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=0)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    if args.action == 'train':
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    if args.restore_ckpt != "" and args.action == 'train':
        variables = tf.global_variables()
        restore_vars = []
        restore_var_names = [line.strip() for line in \
                             open('trainable_variables_{}_only.txt'.format(modes[args.feat_index]),'r').readlines()]
        for v in variables:
            found = False
            for name in restore_var_names:
                if name in v.name:
                    found = True
                    break
            if found:
                restore_vars.append(v)

        for v in restore_vars:
            print(v.name)

        restore_saver = tf.train.Saver(restore_vars)
        restore_saver.restore(sess, args.restore_ckpt)

    if args.restore_ckpt != "" and args.action == 'test':
        saver1 = tf.train.import_meta_graph('{}.meta'.format(
            args.restore_ckpt))
        saver1.restore(sess, '{}'.format(args.restore_ckpt))

    if args.restore_from_saved_model != "":

        pretrained_variables = {}
        with tf.Graph().as_default():
            # restore_from = '/media/ubuntu/working/petct_cnn/logs_dltk_ctpt_1e-4_v1(dont delete)/ct_only/1513006564'
            with tf.Session() as sess1:
                graph_old = loader.load(sess1, ['serve'],
                                        args.restore_from_saved_model)
                for v in tf.global_variables():
                    pretrained_variables[v.name] = v.eval()

        count = 0
        for v in tf.global_variables():
            for name in pretrained_variables.keys():
                if modes[args.feat_index] + '/' + name == v.name:
                    print('{}: {} ==> {}'.format(count, name, v.name))
                    sess.run(v.assign(pretrained_variables[name]))
                    count += 1
                    break
        print(count, ' variables copied from previous model!')

    if args.action == 'test':
        if args.save_prob_dir != '' and not os.path.exists(args.save_prob_dir):
            os.makedirs(args.save_prob_dir)
        if args.save_dir != '' and not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        for idx, filename in enumerate(args.test_filenames.split(',')):
            temp = pd.read_csv(DATA_ROOT + '/' + filename,
                               dtype=object,
                               keep_default_na=False,
                               na_values=[]).as_matrix()
            if idx == 0:
                test_filenames = temp
            else:
                test_filenames = np.concatenate([test_filenames, temp], axis=0)

        dice_val = []
        if args.use_crf:
            dice_val_crf = []
        alldata = {}
        for f in test_filenames:
            # for f in np.concatenate([val_filenames, test_filenames], axis=0):
            subject_id = f[0]
            img_fn = f[1]
            case_name = img_fn.split('/')[-1]
            # if '2141' in case_name:
            #     continue

            # Read the image nii with sitk and keep the pointer to the sitk.Image
            # of an input
            ct_sitk = sitk.ReadImage(
                str(os.path.join(img_fn, 'InputCT_ROI.nii.gz')))
            ct = sitk.GetArrayFromImage(ct_sitk).astype((np.float32))
            pt_sitk = sitk.ReadImage(
                str(os.path.join(img_fn, 'InputPET_SUV_ROI.nii.gz')))
            pt = sitk.GetArrayFromImage(pt_sitk).astype((np.float32))
            lbl_ct = sitk.GetArrayFromImage(
                sitk.ReadImage(
                    str(
                        os.path.join(
                            img_fn, 'GTV_Primary_ROI_CT{}.nii.gz'.format(
                                GT_POSTFIX))))).astype(np.uint8)
            lbl_pt = sitk.GetArrayFromImage(
                sitk.ReadImage(
                    str(
                        os.path.join(
                            img_fn, 'GTV_Primary_ROI_PET{}.nii.gz'.format(
                                GT_POSTFIX))))).astype(np.uint8)

            ct[ct > 200.] = 200.
            ct[ct < -500.] = -500.
            ct = 255 * (ct + 500) / (700.)

            pt[pt < 0.01] = 0.01
            pt[pt > 20.] = 20.
            pt = 255 * (pt - 0.01) / (19.99)

            image = np.concatenate([ct[..., np.newaxis], pt[..., np.newaxis]],
                                   axis=3)
            label = np.concatenate(
                [lbl_ct[..., np.newaxis], lbl_pt[..., np.newaxis]], axis=3)

            #             if image_mean != None:
            #                image -= np.reshape(image_mean, [1,1,1,2])
            #             if image_std != None:
            #                image /= np.reshape(image_std, [1,1,1,2])

            if args.use_crf:
                pred, pred_crf, prob, prob_crf = sess.run(
                    [pred_op, pred_crf_op, prob_op, prob_crf_op],
                    feed_dict={
                        image_node: image[np.newaxis, ...],
                        label_node: label[np.newaxis, ...],
                        phase_train: False
                    })
            else:
                pred, prob = sess.run(
                    [pred_op, prob_op],
                    feed_dict={
                        image_node: image[np.newaxis, ...],
                        label_node: label[np.newaxis, ...],
                        phase_train: False
                    })
            dice_val_ = computeDice(label[..., args.feat_index], pred[0])
            dice_val.append(dice_val_)
            if args.use_crf:
                dice_val_crf_ = computeDice(label[..., args.feat_index],
                                            pred_crf[0])
                dice_val_crf.append(dice_val_crf_)

            if args.save_prob_dir != '':
                alldata[case_name] = {}
                alldata[case_name]['image'] = image
                alldata[case_name]['label'] = label
                alldata[case_name]['pred'] = pred[0]
                alldata[case_name]['prob'] = prob[0]
                #alldata[case_name]['ct_sitk'] = ct_sitk
            log_file_handle.write('{} {} {}'.format(
                case_name, dice_val_,
                dice_val_crf_ if args.use_crf == 1 else ''))

            if args.save_dir != '':
                case_save_dir = '{}/{}'.format(args.save_dir, case_name)
                if not os.path.exists(case_save_dir):
                    os.makedirs(case_save_dir)

                if args.use_crf == 1:
                    new_sitk_ct = sitk.GetImageFromArray(pred[0].astype(
                        np.int32))
                    new_sitk_ct.CopyInformation(ct_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf1_pred_{}_before.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))
                    new_sitk_ct = sitk.GetImageFromArray(pred_crf[0].astype(
                        np.int32))
                    new_sitk_ct.CopyInformation(ct_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf1_pred_{}_after.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))
                    new_sitk_ct = sitk.GetImageFromArray(
                        prob[0][..., 1].astype(np.float32))
                    new_sitk_ct.CopyInformation(pt_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf1_prob_{}_before.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))
                    new_sitk_ct = sitk.GetImageFromArray(
                        prob_crf[0][..., 1].astype(np.float32))
                    new_sitk_ct.CopyInformation(pt_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf1_prob_{}_after.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))
                else:
                    new_sitk_ct = sitk.GetImageFromArray(pred[0].astype(
                        np.int32))
                    new_sitk_ct.CopyInformation(ct_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf0_pred_{}.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))
                    new_sitk_ct = sitk.GetImageFromArray(
                        prob[0][..., 1].astype(np.float32))
                    new_sitk_ct.CopyInformation(pt_sitk)
                    sitk.WriteImage(
                        new_sitk_ct,
                        str('{}/crf0_prob_{}.nii.gz'.format(
                            case_save_dir,
                            'ct' if args.feat_index == 0 else 'pt')))

        if args.save_prob_dir != '':
            alldata['val_filenames'] = val_filenames
            alldata['test_filenames'] = test_filenames
            pickle.dump(alldata, open(args.save_prob_dir + '/alldata.p', "wb"))

        log_file_handle.write('Mean {} Dice(Before CRF): {}'.format(
            modes[args.feat_index], np.mean(np.array(dice_val))))
        if args.use_crf:
            log_file_handle.write('Mean {} Dice(After  CRF): {}'.format(
                modes[args.feat_index], np.mean(np.array(dice_val_crf))))

        return

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    step = 0
    if args.step > 0:
        step = args.step
    while not coord.should_stop():

        if step % STEPS_ONE_EPOCH == 0:
            # val one epoch
            # for subi, filenames in enumerate([train0_filenames, val_filenames, test_filenames]):
            for subi, filenames in enumerate([
                    train0_filenames,
                    np.concatenate([val_filenames, test_filenames], axis=0)
            ]):
                dice_val = []
                if args.use_crf:
                    dice_val_crf = []
                for f in filenames:
                    subject_id = f[0]
                    img_fn = f[1]
                    case_name = img_fn.split('/')[-1]
                    # if '2141' in case_name:
                    #     continue

                    # Read the image nii with sitk and keep the pointer to the sitk.Image
                    # of an input
                    ct_sitk = sitk.ReadImage(
                        str(os.path.join(img_fn, 'InputCT_ROI.nii.gz')))
                    ct = sitk.GetArrayFromImage(ct_sitk).astype((np.float32))
                    pt_sitk = sitk.ReadImage(
                        str(os.path.join(img_fn, 'InputPET_SUV_ROI.nii.gz')))
                    pt = sitk.GetArrayFromImage(pt_sitk).astype((np.float32))
                    lbl_ct = sitk.GetArrayFromImage(
                        sitk.ReadImage(
                            str(
                                os.path.join(
                                    img_fn,
                                    'GTV_Primary_ROI_CT{}.nii.gz'.format(
                                        GT_POSTFIX))))).astype(np.uint8)
                    lbl_pt = sitk.GetArrayFromImage(
                        sitk.ReadImage(
                            str(
                                os.path.join(
                                    img_fn,
                                    'GTV_Primary_ROI_PET{}.nii.gz'.format(
                                        GT_POSTFIX))))).astype(np.uint8)

                    ct[ct > 200.] = 200.
                    ct[ct < -500.] = -500.
                    ct = 255 * (ct + 500) / (700.)

                    pt[pt < 0.01] = 0.01
                    pt[pt > 20.] = 20.
                    pt = 255 * (pt - 0.01) / (19.99)

                    image = np.concatenate(
                        [ct[..., np.newaxis], pt[..., np.newaxis]], axis=3)
                    label = np.concatenate(
                        [lbl_ct[..., np.newaxis], lbl_pt[..., np.newaxis]],
                        axis=3)

                    #                     if image_mean != None:
                    #                        image -= np.reshape(image_mean, [1,1,1,2])
                    #                     if image_std != None:
                    #                        image /= np.reshape(image_std, [1,1,1,2])

                    if args.use_crf:
                        pred, pred_crf = sess.run(
                            [pred_op, pred_crf_op],
                            feed_dict={
                                image_node: image[np.newaxis, ...],
                                label_node: label[np.newaxis, ...],
                                phase_train: False
                            })
                    else:
                        pred = sess.run(
                            [pred_op],
                            feed_dict={
                                image_node: image[np.newaxis, ...],
                                label_node: label[np.newaxis, ...],
                                phase_train: False
                            })
                    dice_val_ = computeDice(label[..., args.feat_index],
                                            pred[0])
                    dice_val.append(dice_val_)
                    if args.use_crf:
                        dice_val_crf_ = computeDice(
                            label[..., args.feat_index], pred_crf[0])
                        dice_val_crf.append(dice_val_crf_)
                    if args.use_crf:
                        log_file_handle.write('{} {} {}\n'.format(
                            case_name, dice_val_, dice_val_crf_))
                    else:
                        log_file_handle.write('{} {}\n'.format(
                            case_name, dice_val_))

                log_file_handle.write(
                    '{} {} Mean {} Dice(Before CRF): {}\n'.format(
                        step, subsets[subi], modes[args.feat_index],
                        np.mean(np.array(dice_val))))
                if args.use_crf:
                    log_file_handle.write(
                        '{} {} Mean {} Dice(After CRF): {}\n'.format(
                            step, subsets[subi], modes[args.feat_index],
                            np.mean(np.array(dice_val_crf))))

            if args.use_crf:
                crf_sw_value, crf_bw_value, crf_cm_value \
                = sess.run([crf_sw_op, crf_bw_op, crf_cm_op])
                print('sw:', crf_sw_value)
                print('bw:', crf_bw_value)
                print('cm:', crf_cm_value)

        if step % STEPS_ONE_EPOCH == 0:
            if args.use_crf:
                checkpoint_path = os.path.join(log_dir, 'model_crf.ckpt')
            else:
                checkpoint_path = os.path.join(log_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        # train one epoch
        start_time = time.time()
        image_batch, label_batch, case_name_batch = sess.run(
            [train_images, train_labels, train_case_names])
        if args.use_crf:
            summary, _, loss_reg_value, \
            total_loss_crf_value, loss_crf_value, \
            total_loss_value, loss_value, \
            pred, pred_crf, \
            dice, dice_crf \
            = sess.run([summary_op, train_op, loss_reg_op,
                        total_loss_crf_op, loss_crf_op,
                        total_loss_op, loss_op,
                        pred_op, pred_crf_op,
                        dice_op, dice_crf_op],
                     feed_dict={image_node: image_batch,
                                label_node: label_batch,
                                phase_train: False})
        else:
            summary, _, loss_reg_value, \
            total_loss_value, loss_value, \
            pred, dice = sess.run([summary_op, train_op, loss_reg_op,
                                     total_loss_op, loss_op,
                                     pred_op, dice_op],
                             feed_dict={image_node: image_batch,
                                        label_node: label_batch,
                                        phase_train: True})

        duration = time.time() - start_time

        #         if args.use_crf:
        #             log_file_handle.write('{}, step {:d}, {:.6f}, ' \
        #                              '{:.6f}, {:.6f}, ' \
        #                              '{:.6f}, {:.6f}, ' \
        #                              '{:.3f}, {:.3f}, {}\n'.format(modes[args.feat_index],
        #                                                          step, loss_reg_value,
        #                                                          total_loss_crf_value, loss_crf_value,
        #                                                          total_loss_value, loss_value,
        #                                                          dice, dice_crf, case_name_batch[0]))
        #             log_file_handle.flush()
        #         else:
        #             log_file_handle.write('{}, step {:d}, {:.6f}, ' \
        #                              '{:.6f}, {:.6f}, ' \
        #                              '{:.3f}, {}\n'.format(modes[args.feat_index],
        #                                                          step, loss_reg_value,
        #                                                          total_loss_value, loss_value,
        #                                                          dice, case_name_batch[0]))
        #             log_file_handle.flush()

        if step % STEPS_ONE_EPOCH == 0:
            summary_writer.add_summary(summary, global_step=step)

        step += 1
        if step == max_steps:
            break

    log_file_handle.close()
    coord.request_stop()
    coord.join(threads)
    sess.close()
Exemple #38
0
def load_tf_model(model_path, tags=(tag_constants.SERVING, ), config=None):
    """Loads the SavedModel at the specified path into a new session.

    Args:
      model_path: path to a SavedModel directory. Anything for which
        `loader.maybe_saved_model_directory` returns False is rejected.
      tags: the tags that determine which MetaGraph to load.
      config: tf.ConfigProto containing session configuration options.

    Returns:
      A pair of (Session, map<string, SignatureDef>) objects. Signatures for
      which `_update_dtypes` raised ValueError are logged and removed from the
      returned map.

    Raises:
      PredictionError: if `model_path` is not a SavedModel directory, if the
        model could not be loaded, or if the loaded MetaGraph has no
        signature_def.
    """
    if not loader.maybe_saved_model_directory(model_path):
        raise PredictionError(
            PredictionError.FAILED_TO_LOAD_MODEL,
            "Cloud ML only supports TF 1.0 or above and models "
            "saved in SavedModel format.")

    session = None
    try:
        logging.info("Importing tensorflow.contrib in load_tf_model")
        # pylint: disable=redefined-outer-name,unused-variable,g-import-not-at-top
        import tensorflow as tf
        from tensorflow.python.framework.ops import Graph
        # pylint: enable=redefined-outer-name,unused-variable,g-import-not-at-top
        # On TF 1.0 the session is created with graph=None (TF's default
        # graph); newer versions get a dedicated Graph.
        if tf.__version__.startswith("1.0"):
            graph = None
        else:
            graph = Graph()
        session = tf_session.Session(target="", graph=graph, config=config)
        meta_graph = loader.load(session,
                                 tags=list(tags),
                                 export_dir=model_path)
    except Exception as e:  # pylint: disable=broad-except
        # Release the session created above; otherwise every failed load
        # leaks it.
        if session is not None:
            session.close()
        raise PredictionError(
            PredictionError.FAILED_TO_LOAD_MODEL,
            "Failed to load the model due to bad model data."
            " tags: %s\n%s" % (list(tags), str(e)))

    if not meta_graph.signature_def:
        raise PredictionError(
            PredictionError.FAILED_TO_LOAD_MODEL,
            "MetaGraph must have at least one signature_def.")

    # Remove invalid signatures from the signature map. Names are collected
    # first so the map is not mutated while it is being iterated.
    invalid_signatures = []
    for signature_name, signature in meta_graph.signature_def.items():
        try:
            _update_dtypes(session.graph, signature.inputs)
            _update_dtypes(session.graph, signature.outputs)
        except ValueError as e:
            logging.warning("Error updating signature %s: %s", signature_name,
                            str(e))
            invalid_signatures.append(signature_name)
    for signature_name in invalid_signatures:
        del meta_graph.signature_def[signature_name]

    return session, meta_graph.signature_def
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
    """Converts a SavedModel to a frozen graph.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. The input
        arrays from the SignatureDef are used when none are provided (this
        selection is performed by `_get_tensors`). May be None.
      input_shapes: Map of strings representing input tensor names to lists
        of integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
        A tensor that is absent from the map, or maps to None, keeps the
        shape recorded in the graph (e.g., {"foo": None}). May be None.
      output_arrays: List of output tensors to freeze graph with. The output
        arrays from the SignatureDef are used when none are provided. May be
        None.
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel
        to analyze. All tags in the tag set must be present. Defaults to
        {tag_constants.SERVING} when None.
      signature_key: Key identifying SignatureDef containing inputs and
        outputs. Defaults to DEFAULT_SERVING_SIGNATURE_DEF_KEY when None.

    Returns:
      frozen_graph_def: Frozen GraphDef containing the subgraph needed to
        compute `out_tensors`.
      in_tensors: List of input tensors for the graph.
      out_tensors: List of output tensors for the graph.

    Raises:
      ValueError:
        SavedModel doesn't contain a MetaGraphDef identified by tag_set.
        signature_key is not in the MetaGraphDef.
        input_shapes does not match the length of input_arrays.
        input_arrays or output_arrays are not valid.
    """
    # Set default values for inputs if they are set to None.
    if signature_key is None:
        signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    if tag_set is None:
        tag_set = set([tag_constants.SERVING])

    # Read SignatureDef.
    meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
    signature_def = _get_signature_def(meta_graph, signature_key)
    inputs, outputs = _get_inputs_outputs(signature_def)

    graph = ops.Graph()
    with session.Session(graph=graph) as sess:
        # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
        loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

        # Gets input and output tensors.
        # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
        in_tensors = _get_tensors(graph, inputs, input_arrays)
        out_tensors = _get_tensors(graph, outputs, output_arrays)

        # Gets fully defined tensor shape.
        for tensor in in_tensors:
            if (input_shapes and tensor.name in input_shapes
                    and input_shapes[tensor.name] is not None):
                shape = input_shapes[tensor.name]
            else:
                shape = tensor.get_shape().as_list()
            tensor.set_shape(shape)

        # Freeze the subgraph of the tensors actually returned. Using the
        # SignatureDef outputs here would prune away any user-supplied
        # `output_arrays` that are not signature outputs, leaving
        # `out_tensors` pointing at nodes missing from the frozen graph.
        output_names = [tensor.op.name for tensor in out_tensors]
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_names)

        return frozen_graph_def, in_tensors, out_tensors
 def test_savedmodel_state_override(self):
   """Split and combined filtering through a SavedModel must agree.

   A StateSpaceRegressor with one exogenous feature is trained for a single
   step and exported. The export is then driven from an explicitly supplied
   start state in two ways: filtering times [1, 2] and then [3, 4] through two
   continuation calls, and filtering [1, 2, 3, 4] in a single call. The final
   filtering results (except the TIMES entry) and a one-step prediction with
   exogenous value -5 must match between the two paths.
   """
   model = RandomStateSpaceModel(
       state_dimension=5,
       state_noise_dimension=4,
       configuration=state_space_model.StateSpaceModelConfiguration(
           exogenous_feature_columns=[layers.real_valued_column("exogenous")],
           dtype=dtypes.float64, num_features=1))
   regressor = estimators.StateSpaceRegressor(
       model=model,
       optimizer=gradient_descent.GradientDescentOptimizer(0.1))
   times_key = feature_keys.FilteringFeatures.TIMES
   values_key = feature_keys.FilteringFeatures.VALUES
   training_input_fn = input_pipeline.WholeDatasetInputFn(
       input_pipeline.NumpyReader({
           times_key: [1, 2, 3, 4],
           values_key: [1., 2., 3., 4.],
           "exogenous": [-1., -2., -3., -4.]
       }))
   regressor.train(training_input_fn, steps=1)
   export_location = regressor.export_savedmodel(
       self.get_temp_dir(),
       regressor.build_raw_serving_input_receiver_fn())

   # Evaluate the model's start state in its own graph, then prepend a batch
   # dimension of size one to every state element.
   with ops.Graph().as_default() as graph:
     model.initialize_graph()
     with self.session(graph=graph) as session:
       variables.global_variables_initializer().run()
       raw_start_state = session.run(model.get_start_state())
   start_state = [element[None, ...] for element in raw_start_state]

   with ops.Graph().as_default() as graph:
     with self.session(graph=graph) as session:
       signatures = loader.load(
           session, [tag_constants.SERVING], export_location)

       def run_filtering(continue_from, times, values, exogenous):
         # One filter_continuation call against the loaded signatures.
         return saved_model_utils.filter_continuation(
             continue_from=continue_from,
             signatures=signatures,
             session=session,
             features={
                 times_key: times,
                 values_key: values,
                 "exogenous": exogenous
             })

       def predict_one_step(continue_from):
         # One-step prediction with a fixed exogenous value of -5.
         return saved_model_utils.predict_continuation(
             continue_from=continue_from,
             signatures=signatures,
             session=session,
             steps=1,
             exogenous_features={"exogenous": [[[-5.]]]})

       first_half = run_filtering(
           {feature_keys.FilteringResults.STATE_TUPLE: start_state},
           [1, 2], [1., 2.], [[-1.], [-2.]])
       second_half = run_filtering(
           first_half, [3, 4], [3., 4.], [[-3.], [-4.]])
       full_run = run_filtering(
           {feature_keys.FilteringResults.STATE_TUPLE: start_state},
           [1, 2, 3, 4], [1., 2., 3., 4.], [[-1.], [-2.], [-3.], [-4.]])
       split_prediction = predict_one_step(second_half)
       full_prediction = predict_one_step(full_run)

   # The TIMES entry of the filtering results is deliberately not compared.
   for key, full_value in full_run.items():
     if key != feature_keys.FilteringResults.TIMES:
       self.assertAllClose(full_value, second_half[key])
   for key, full_value in full_prediction.items():
     self.assertAllClose(full_value, split_prediction[key])
Exemple #41
0
def run_saved_model_with_feed_dict(saved_model_dir,
                                   tag_set,
                                   signature_def_key,
                                   input_tensor_key_feed_dict,
                                   outdir,
                                   overwrite_flag,
                                   tf_debug=False):
    """Runs SavedModel and fetch all outputs.

    Runs the input dictionary through the MetaGraphDef within a SavedModel
    specified by the given tag_set and SignatureDef. Also save the outputs to
    file if outdir is not None.

    Args:
      saved_model_dir: Directory containing the SavedModel to execute.
      tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map,
          in string format, separated by ','. For tag-set contains multiple
          tags, all tags must be passed in.
      signature_def_key: A SignatureDef key string.
      input_tensor_key_feed_dict: A dictionary maps input keys to numpy
          ndarrays.
      outdir: A directory to save the outputs to. If the directory doesn't
          exist, it will be created. Each output is written to
          `<outdir>/<output key>.npy`.
      overwrite_flag: A boolean flag to allow overwrite output file if file
          with the same name exists.
      tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe
          the intermediate Tensor values and runtime GraphDefs while running
          the SavedModel.

    Raises:
      RuntimeError: An error when output file already exists and overwrite is
      not enabled. This is checked for every output before the model is run,
      so no output file is written when any of them conflicts.
    """
    # Load the MetaGraphDef for the requested tag-set.
    meta_graph_def = get_meta_graph_def(saved_model_dir, tag_set)

    # Re-create feed_dict based on input tensor name instead of key as
    # session.run uses tensor name.
    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    inputs_feed_dict = {
        inputs_tensor_info[key].name: tensor
        for key, tensor in input_tensor_key_feed_dict.items()
    }
    # Get outputs
    outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    # Sort to preserve order because we need to go from value to key later.
    output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
    output_tensor_names_sorted = [
        outputs_tensor_info[tensor_key].name
        for tensor_key in output_tensor_keys_sorted
    ]

    # Resolve every output path and check for conflicts up front, so an
    # existing file is reported before anything is computed or written
    # (previously earlier outputs could be saved before the error was raised).
    output_full_paths = {}
    if outdir:
        for output_tensor_key in output_tensor_keys_sorted:
            output_full_path = os.path.join(outdir,
                                            output_tensor_key + '.npy')
            # If overwrite not enabled and file already exist, error out
            if not overwrite_flag and os.path.exists(output_full_path):
                raise RuntimeError(
                    'Output file %s already exists. Add \"--overwrite\" to overwrite'
                    ' the existing output files.' % output_full_path)
            output_full_paths[output_tensor_key] = output_full_path

    with session.Session(graph=ops_lib.Graph()) as sess:
        loader.load(sess, tag_set.split(','), saved_model_dir)

        if tf_debug:
            sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

        outputs = sess.run(output_tensor_names_sorted,
                           feed_dict=inputs_feed_dict)

        # Create the output directory once, and only if something is saved.
        if output_full_paths and not os.path.isdir(outdir):
            os.makedirs(outdir)

        for output_tensor_key, output in zip(output_tensor_keys_sorted,
                                             outputs):
            print('Result for output key %s:\n%s' %
                  (output_tensor_key, output))

            # Only save if outdir is specified.
            if outdir:
                output_full_path = output_full_paths[output_tensor_key]
                np.save(output_full_path, output)
                print('Output %s is saved to %s' %
                      (output_tensor_key, output_full_path))
Exemple #42
0
    def testClearExtraneousSavers(self):
        """The builder must drop pre-existing Savers and keep only its own.

        Two Savers are registered in the SAVERS collection before export.
        After add_meta_graph_and_variables(clear_devices=True) and a reload,
        the graph must hold a single Saver and exactly one SaveV2/RestoreV2
        pair, named under the "save_2" scope.
        """
        def ops_of_type(g, op_type):
            # Names of every operation in `g` whose type is `op_type`.
            return {op.name for op in g.get_operations() if op.type == op_type}

        export_dir = os.path.join(test.get_temp_dir(),
                                  "test_clear_extraneous_savers")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Build a variable plus two Savers, then export with the builder.
        with ops.Graph().as_default() as graph, \
                session.Session(target="",
                                config=config_pb2.ConfigProto(
                                    device_count={"CPU": 2})) as sess:
            self._init_and_validate_variable(sess, "v", 42)

            # These two Savers should be removed by
            # add_meta_graph_and_variables() in favor of the locally added one.
            for _ in range(2):
                graph.add_to_collection(ops.GraphKeys.SAVERS, tf_saver.Saver())

            self.assertEqual(2, len(graph.get_collection(ops.GraphKeys.SAVERS)))
            self.assertSetEqual({"save/SaveV2", "save_1/SaveV2"},
                                ops_of_type(graph, "SaveV2"))
            self.assertSetEqual({"save/RestoreV2", "save_1/RestoreV2"},
                                ops_of_type(graph, "RestoreV2"))

            # The builder adds a third Saver of its own.
            builder.add_meta_graph_and_variables(sess,
                                                 [tag_constants.TRAINING],
                                                 clear_devices=True)

        builder.save()

        # Reload and confirm only the builder's Saver survived.
        with ops.Graph().as_default() as graph, \
                self.test_session(graph=graph) as sess:
            loader.load(sess, [tag_constants.TRAINING], export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())

            self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVERS)))
            self.assertSetEqual({"save_2/SaveV2"},
                                ops_of_type(graph, "SaveV2"))
            self.assertSetEqual({"save_2/RestoreV2"},
                                ops_of_type(graph, "RestoreV2"))
Exemple #43
0
  def test_export_savedmodel(self):
    """Exports an estimator as a SavedModel and validates what was written.

    Checks saved_model.pb, the variables checkpoint files, the copied vocab
    asset and the extra asset passed through `assets_extra`. It then reloads
    the export and checks the ASSET_FILEPATHS collection and a few expected
    op names.
    """
    tmpdir = tempfile.mkdtemp()
    # Registered up front so the scratch directory is removed even when an
    # assertion below fails (a trailing delete only ran on success).
    self.addCleanup(tf.gfile.DeleteRecursively, tmpdir)
    est, export_input_fn = _build_estimator_for_export_tests(tmpdir)

    def read_file(path):
      # Read a whole file and always close the handle.
      with tf.gfile.GFile(path) as f:
        return f.read()

    extra_file_name = os.path.join(compat.as_bytes(tmpdir),
                                   compat.as_bytes('my_extra_file'))
    with tf.gfile.GFile(extra_file_name, mode='w') as extra_file:
      extra_file.write(EXTRA_FILE_CONTENT)
    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}

    export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                   compat.as_bytes('export'))
    export_dir = est.export_savedmodel(export_dir_base, export_input_fn,
                                       assets_extra=assets_extra)

    def in_export(relative_path):
      # Byte-string path of `relative_path` inside the export directory.
      return os.path.join(compat.as_bytes(export_dir),
                          compat.as_bytes(relative_path))

    self.assertTrue(tf.gfile.Exists(export_dir_base))
    self.assertTrue(tf.gfile.Exists(export_dir))
    for relative_path in ('saved_model.pb',
                          'variables',
                          'variables/variables.index',
                          'variables/variables.data-00000-of-00001',
                          'assets',
                          'assets/my_vocab_file',
                          'assets.extra'):
      self.assertTrue(tf.gfile.Exists(in_export(relative_path)),
                      relative_path)

    self.assertEqual(
        compat.as_bytes(VOCAB_FILE_CONTENT),
        compat.as_bytes(read_file(in_export('assets/my_vocab_file'))))

    expected_extra_path = in_export(
        'assets.extra/some/sub/directory/my_extra_file')
    self.assertTrue(tf.gfile.Exists(expected_extra_path))
    self.assertEqual(
        compat.as_bytes(EXTRA_FILE_CONTENT),
        compat.as_bytes(read_file(expected_extra_path)))

    expected_vocab_file = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('my_vocab_file'))
    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [x.eval()
                  for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)]
        self.assertItemsEqual([expected_vocab_file], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertIn('input_example_tensor', graph_ops)
        self.assertIn('ParseExample/ParseExample', graph_ops)
        self.assertIn('linear/linear/feature/matmul', graph_ops)
  def testStripDefaultAttrs(self):
    """Checks that `strip_default_attrs` is honored per MetaGraph.

    Saves graph "foo" with default-valued attrs stripped and graph "bar"
    without stripping, then verifies that (a) loading "foo" through the loader
    yields a "Complex" node carrying "T" and "Tout", and (b) the raw protos on
    disk lack those attrs for "foo" but keep them for "bar".
    """
    export_dir = self._get_export_dir("test_strip_default_attrs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Add a graph with the same float32 variables and a Complex Op composing
    # them with strip_default_attrs disabled.
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph(["bar"], strip_default_attrs=False)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading graph "foo" via the loader must restore the defaults for the
    # "Complex" node based on the "Complex" OpDef in the Op registry. The
    # session is scoped with `with` so it is closed instead of leaking.
    with session.Session(graph=ops.Graph()) as sess:
      loaded_meta_graph_def = loader.load(sess, ["foo"], export_dir)
      complex_node = test_util.get_node_def_from_graph(
          "complex", loaded_meta_graph_def.graph_def)
      self.assertIn("T", complex_node.attr)
      self.assertIn("Tout", complex_node.attr)

    # Load graph "foo" from disk as-is to verify default attrs are stripped.
    # pylint: disable=protected-access
    saved_model_pb = loader_impl._parse_saved_model(export_dir)
    self.assertIsNotNone(saved_model_pb)
    # pylint: enable=protected-access

    # Pick out the raw "foo" and "bar" MetaGraphDefs by their exact tag sets.
    meta_graph_foo_def = None
    meta_graph_bar_def = None
    for candidate in saved_model_pb.meta_graphs:
      candidate_tags = set(candidate.meta_info_def.tags)
      if candidate_tags == set(["foo"]):
        meta_graph_foo_def = candidate
      elif candidate_tags == set(["bar"]):
        meta_graph_bar_def = candidate

    self.assertIsNotNone(meta_graph_foo_def)
    self.assertIsNotNone(meta_graph_bar_def)

    # "Complex" Op has 2 attributes with defaults:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)

    # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
    # Graph "foo" was saved with strip_default_attrs set to True.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_foo_def.graph_def)
    self.assertNotIn("T", node_def.attr)
    self.assertNotIn("Tout", node_def.attr)

    # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
    # Graph "bar" was saved with strip_default_attrs set to False.
    node_def = test_util.get_node_def_from_graph("complex",
                                                 meta_graph_bar_def.graph_def)
    self.assertIn("T", node_def.attr)
    self.assertIn("Tout", node_def.attr)
Exemple #45
0
    def testTags(self):
        """Checks tag-based MetaGraph selection when saving and loading.

        Adds three MetaGraphs under different tag sets. Variables are saved
        only with the first one. The test then verifies which tag sets load
        (always restoring the saved value 42) and which raise `RuntimeError`.
        """
        export_dir = os.path.join(compat.as_bytes(tf.test.get_temp_dir()),
                                  compat.as_bytes("tags"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        # Debug dumps of the restored graphs go under the per-test temp dir
        # instead of a hard-coded, shared "/tmp/" that other runs may clobber.
        graph_dump_dir = tf.test.get_temp_dir()

        # Graph with a single variable. SavedModel invoked to:
        # - add with weights.
        # - a single tag (from predefined constants).
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, v.eval())
            builder.add_meta_graph_and_variables(sess,
                                                 [constants.TAG_TRAINING])

        # Graph that updates the single variable. SavedModel invoked to:
        # - simply add the model (weights are not updated).
        # - a single tag (from predefined constants).
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(43, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(43, v.eval())
            builder.add_meta_graph([constants.TAG_SERVING])

        # Graph that updates the single variable. SavedModel is invoked:
        # - to add the model (weights are not updated).
        # - multiple custom tags.
        with self.test_session(graph=tf.Graph()) as sess:
            v = tf.Variable(44, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(44, v.eval())
            builder.add_meta_graph(["foo", "bar"])

        # Save the SavedModel to disk.
        builder.save()

        # Restore the graph with a single predefined tag whose variables were saved.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [constants.TAG_TRAINING], export_dir)
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 graph_dump_dir,
                                 "training_graph",
                                 as_text=True)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Restore the graph with a single predefined tag whose variables were not
        # saved.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [constants.TAG_SERVING], export_dir)
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 graph_dump_dir,
                                 "serving_graph",
                                 as_text=True)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Restore the graph with multiple tags. Provide duplicate tags to test set
        # semantics.
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, ["foo", "bar", "foo"], export_dir)
            self.assertEqual(
                42,
                tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

        # Try restoring a graph with a non-existent tag. This should yield a runtime
        # error.
        with self.test_session(graph=tf.Graph()) as sess:
            self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"],
                              export_dir)

        # Try restoring a graph where a subset of the tags match. Since tag matching
        # for meta graph defs follows "all" semantics, this should yield a runtime
        # error.
        with self.test_session(graph=tf.Graph()) as sess:
            self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"],
                              export_dir)
Exemple #46
0
  def _fit_restore_fit_test_template(self, estimator_fn, dtype):
    """Tests restoring previously fit models.

    Trains an Estimator and checks that its eval loss decreases. Builds a
    second Estimator on the same `model_dir` and exports the first one as a
    SavedModel. The reloaded SavedModel is then used to continue prediction
    and filtering from the second Estimator's evaluation output, and to
    cold-start filtering on a fresh batch.

    Args:
      estimator_fn: Callable `(model_dir, exogenous_feature_columns)` that
        returns an Estimator supporting `train`, `evaluate`, `predict`,
        `export_savedmodel` and `build_raw_serving_input_receiver_fn`.
      dtype: A TensorFlow DType; its `as_numpy_dtype` is used for the series
        values and the exogenous feature.
    """
    model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    # A single numeric exogenous feature, fed under the key "exogenous".
    exogenous_feature_columns = (
        feature_column.numeric_column("exogenous"),
    )
    first_estimator = estimator_fn(model_dir, exogenous_feature_columns)
    # A 20-step series; both the values and the exogenous feature are 0..19.
    times = numpy.arange(20, dtype=numpy.int64)
    values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
    exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
    features = {
        feature_keys.TrainEvalFeatures.TIMES: times,
        feature_keys.TrainEvalFeatures.VALUES: values,
        "exogenous": exogenous
    }
    # Train and eval read random 16-step windows, with different shuffle seeds.
    train_input_fn = input_pipeline.RandomWindowInputFn(
        input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
        batch_size=16, window_size=16)
    eval_input_fn = input_pipeline.RandomWindowInputFn(
        input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
        batch_size=16, window_size=16)
    # Train briefly, then for longer; the eval loss must go down.
    first_estimator.train(input_fn=train_input_fn, steps=5)
    first_loss_before_fit = first_estimator.evaluate(
        input_fn=eval_input_fn, steps=1)["loss"]
    first_estimator.train(input_fn=train_input_fn, steps=50)
    first_loss_after_fit = first_estimator.evaluate(
        input_fn=eval_input_fn, steps=1)["loss"]
    self.assertLess(first_loss_after_fit, first_loss_before_fit)
    # A second Estimator on the same model_dir; presumably it resumes from the
    # checkpoint written by the first one (TODO confirm for each estimator_fn).
    second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
    second_estimator.train(input_fn=train_input_fn, steps=2)
    # Evaluate over the whole series so that the prediction and filtering
    # calls below can continue from this evaluation's output.
    whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader(features))
    whole_dataset_evaluation = second_estimator.evaluate(
        input_fn=whole_dataset_input_fn, steps=1)
    # Exogenous inputs for 10 prediction steps, shaped (1, 10, 1).
    exogenous_values_ten_steps = {
        "exogenous": numpy.arange(
            10, dtype=dtype.as_numpy_dtype)[None, :, None]
    }
    predict_input_fn = input_pipeline.predict_continuation_input_fn(
        evaluation=whole_dataset_evaluation,
        exogenous_features=exogenous_values_ten_steps,
        steps=10)
    # Also tests that limit_epochs in predict_continuation_input_fn prevents
    # infinite iteration
    (estimator_predictions,
    ) = list(second_estimator.predict(input_fn=predict_input_fn))
    # The single prediction must cover 10 steps of 1 feature.
    self.assertAllEqual([10, 1], estimator_predictions["mean"].shape)
    # Export the first Estimator and reload the SavedModel in a fresh graph.
    input_receiver_fn = first_estimator.build_raw_serving_input_receiver_fn()
    export_location = first_estimator.export_savedmodel(self.get_temp_dir(),
                                                        input_receiver_fn)
    with ops.Graph().as_default():
      with session.Session() as sess:
        signatures = loader.load(sess, [tag_constants.SERVING], export_location)
        # Test that prediction and filtering can continue from evaluation output
        saved_prediction = saved_model_utils.predict_continuation(
            continue_from=whole_dataset_evaluation,
            steps=10,
            exogenous_features=exogenous_values_ten_steps,
            signatures=signatures,
            session=sess)
        # Saved model predictions should be the same as Estimator predictions
        # starting from the same evaluation.
        for prediction_key, prediction_value in estimator_predictions.items():
          self.assertAllClose(prediction_value,
                              numpy.squeeze(
                                  saved_prediction[prediction_key], axis=0))
        # Filter one new observation two time steps after the last one seen.
        first_filtering = saved_model_utils.filter_continuation(
            continue_from=whole_dataset_evaluation,
            features={
                feature_keys.FilteringFeatures.TIMES: times[None, -1] + 2,
                feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.,
                "exogenous": values[None, -1, None] + 12.
            },
            signatures=signatures,
            session=sess)
        # Test that prediction and filtering can continue from filtering output
        second_saved_prediction = saved_model_utils.predict_continuation(
            continue_from=first_filtering,
            steps=1,
            exogenous_features={
                "exogenous": numpy.arange(
                    1, dtype=dtype.as_numpy_dtype)[None, :, None]
            },
            signatures=signatures,
            session=sess)
        # One step after the filtered time (last time + 2) is last time + 3.
        self.assertEqual(
            times[-1] + 3,
            numpy.squeeze(
                second_saved_prediction[feature_keys.PredictionResults.TIMES]))
        saved_model_utils.filter_continuation(
            continue_from=first_filtering,
            features={
                feature_keys.FilteringFeatures.TIMES: times[-1] + 3,
                feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,
                "exogenous": values[-1, None] + 13.
            },
            signatures=signatures,
            session=sess)

        # Test cold starting
        six.assertCountEqual(
            self,
            [feature_keys.FilteringFeatures.TIMES,
             feature_keys.FilteringFeatures.VALUES,
             "exogenous"],
            signatures.signature_def[
                feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
        # A batch of 10 series of 30 steps each: values are all 1 and the
        # exogenous feature is all 11.
        batch_numpy_times = numpy.tile(
            numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
        batch_numpy_values = numpy.ones([10, 30, 1])
        state = saved_model_utils.cold_start_filter(
            signatures=signatures,
            session=sess,
            features={
                feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
                feature_keys.FilteringFeatures.VALUES: batch_numpy_values,
                "exogenous": 10. + batch_numpy_values
            }
        )
        # Predict 15 further steps (times 30..44) for each series in the batch.
        predict_times = numpy.tile(
            numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
        predictions = saved_model_utils.predict_continuation(
            continue_from=state,
            times=predict_times,
            exogenous_features={
                "exogenous": numpy.tile(numpy.arange(
                    15, dtype=dtype.as_numpy_dtype), (10,))[None, :, None]
            },
            signatures=signatures,
            session=sess)
        self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
Exemple #47
0
    def test_one_shot_prediction_head_export(self):
        """Exports a OneShotPredictionHead model two ways and runs both exports.

        The first export uses the raw serving input receiver and feeds feature
        arrays directly. The second uses the one-shot parsing receiver and
        feeds serialized `tf.Example`s. For a batch of two, both must produce a
        "mean" prediction of shape (2, 15, 5).
        """
        def _new_temp_dir():
            """Returns a unique (not yet created) path under the test temp dir."""
            return os.path.join(test.get_temp_dir(), str(ops.uid()))

        model_dir = _new_temp_dir()
        categorical_column = feature_column.categorical_column_with_hash_bucket(
            key="categorical_exogenous_feature", hash_bucket_size=16)
        exogenous_feature_columns = [
            feature_column.numeric_column("2d_exogenous_feature", shape=(2, )),
            feature_column.embedding_column(
                categorical_column=categorical_column, dimension=10)
        ]
        estimator = ts_estimators.TimeSeriesRegressor(
            model=ar_model.ARModel(
                periodicities=10,
                input_window_size=10,
                output_window_size=6,
                num_features=5,
                exogenous_feature_columns=exogenous_feature_columns,
                prediction_model_factory=functools.partial(
                    ar_model.LSTMPredictionModel, num_units=10)),
            head_type=ts_head_lib.OneShotPredictionHead,
            model_dir=model_dir)

        def train_input_fn():
            """Returns one 16-step series with 5 features plus exogenous inputs."""
            num_range = math_ops.range(16, dtype=dtypes.int64)
            features = {
                feature_keys.TrainEvalFeatures.TIMES:
                array_ops.expand_dims(num_range, axis=0),
                feature_keys.TrainEvalFeatures.VALUES:
                array_ops.expand_dims(array_ops.tile(num_range[:, None],
                                                     [1, 5]),
                                      axis=0),
                "2d_exogenous_feature":
                array_ops.ones([1, 16, 2]),
                "categorical_exogenous_feature":
                array_ops.expand_dims(array_ops.tile(["strkey"], [16])[:,
                                                                       None],
                                      axis=0)
            }
            return features

        estimator.train(input_fn=train_input_fn, steps=5)
        result = estimator.evaluate(input_fn=train_input_fn, steps=1)
        self.assertIn("average_loss", result)
        self.assertNotIn(feature_keys.State.STATE_TUPLE, result)
        input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
        export_location = estimator.export_saved_model(_new_temp_dir(),
                                                       input_receiver_fn)
        graph = ops.Graph()
        with graph.as_default():
            with session_lib.Session() as session:
                signatures = loader.load(session, [tag_constants.SERVING],
                                         export_location)
                # A one-shot head exports only the PREDICT signature.
                self.assertEqual([feature_keys.SavedModelLabels.PREDICT],
                                 list(signatures.signature_def.keys()))
                predict_signature = signatures.signature_def[
                    feature_keys.SavedModelLabels.PREDICT]
                six.assertCountEqual(self, [
                    feature_keys.FilteringFeatures.TIMES,
                    feature_keys.FilteringFeatures.VALUES,
                    "2d_exogenous_feature", "categorical_exogenous_feature"
                ], predict_signature.inputs.keys())
                # Batch of 2: 20 observed steps, times for 35 steps (20 to
                # filter + 15 to predict) and exogenous features for all 35.
                features = {
                    feature_keys.TrainEvalFeatures.TIMES:
                    numpy.tile(
                        numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]),
                    feature_keys.TrainEvalFeatures.VALUES:
                    numpy.tile(
                        numpy.arange(20, dtype=numpy.float32)[None, :, None],
                        [2, 1, 5]),
                    "2d_exogenous_feature":
                    numpy.ones([2, 35, 2]),
                    "categorical_exogenous_feature":
                    numpy.tile(
                        numpy.array(["strkey"] * 35)[None, :, None], [2, 1, 1])
                }
                feeds = {
                    graph.as_graph_element(input_value.name):
                    features[input_key]
                    for input_key, input_value in
                    predict_signature.inputs.items()
                }
                fetches = {
                    output_key: graph.as_graph_element(output_value.name)
                    for output_key, output_value in
                    predict_signature.outputs.items()
                }
                output = session.run(fetches, feed_dict=feeds)
                self.assertEqual((2, 15, 5), output["mean"].shape)
        # Build a parsing input function, then make a tf.Example for it to parse.
        export_location = estimator.export_saved_model(
            _new_temp_dir(),
            estimator.build_one_shot_parsing_serving_input_receiver_fn(
                filtering_length=20, prediction_length=15))
        graph = ops.Graph()
        with graph.as_default():
            with session_lib.Session() as session:
                example = example_pb2.Example()
                times = example.features.feature[
                    feature_keys.TrainEvalFeatures.TIMES]
                values = example.features.feature[
                    feature_keys.TrainEvalFeatures.VALUES]
                times.int64_list.value.extend(range(35))
                # 20 observed steps, each with 5 feature values.
                for i in range(20):
                    values.float_list.value.extend([
                        float(i) * 2. + feature_number
                        for feature_number in range(5)
                    ])
                real_feature = example.features.feature["2d_exogenous_feature"]
                categorical_feature = example.features.feature[
                    "categorical_exogenous_feature"]
                # 35 steps: each has a 2-d real feature of ones and one
                # categorical string key, stored flattened.
                real_feature.float_list.value.extend([1, 1] * 35)
                categorical_feature.bytes_list.value.extend([b"strkey"] * 35)
                # Serialize the tf.Example for feeding to the Session
                examples = [example.SerializeToString()] * 2
                signatures = loader.load(session, [tag_constants.SERVING],
                                         export_location)
                predict_signature = signatures.signature_def[
                    feature_keys.SavedModelLabels.PREDICT]
                # The parsing receiver exposes exactly one input: the
                # serialized examples.
                ((_, input_value), ) = predict_signature.inputs.items()
                feeds = {graph.as_graph_element(input_value.name): examples}
                fetches = {
                    output_key: graph.as_graph_element(output_value.name)
                    for output_key, output_value in
                    predict_signature.outputs.items()
                }
                output = session.run(fetches, feed_dict=feeds)
                self.assertEqual((2, 15, 5), output["mean"].shape)
    def testStripDefaultAttrsInconsistentConsumerDefaults(self):
        """Stripped attrs are refilled from the *consumer's* OpDef defaults.

        Saves a "Complex" node with its default-valued attrs stripped. It then
        edits the registered "Complex" OpDef in place so that those defaults
        are either missing or different, and checks that loading fails in the
        expected way. These edits change the OpDef object held by the Op
        registry (that is what makes the loads fail), so the original OpDef is
        restored in a `finally` block to keep later tests unaffected.
        """
        export_dir = os.path.join(
            test.get_temp_dir(),
            "test_strip_default_attrs_no_consumer_defaults")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Add a graph with two float32 variables and a Complex Op composing them
        # with strip_default_attrs enabled. This must remove the following
        # defaults for the "Complex" Op:
        #   o "T"    : float32.   (input type)
        #   o "Tout" : complex64. (output type)
        with session.Session(graph=ops.Graph()) as sess:
            real_num = variables.Variable(1.0,
                                          dtype=dtypes.float32,
                                          name="real")
            imag_num = variables.Variable(2.0,
                                          dtype=dtypes.float32,
                                          name="imag")
            math_ops.complex(real_num, imag_num, name="complex")
            sess.run(variables.global_variables_initializer())
            builder.add_meta_graph_and_variables(sess, ["foo"],
                                                 strip_default_attrs=True)

        # Save the SavedModel to disk in text format.
        builder.save(as_text=True)

        # Snapshot the registered "Complex" OpDef so it can be put back.
        complex_op_def = op_def_registry.get_registered_ops()["Complex"]
        original_complex_op_def = op_def_pb2.OpDef()
        original_complex_op_def.CopyFrom(complex_op_def)
        try:
            # Update the Op registry to remove defaults for all attrs("T",
            # "Tout") from the "Complex" OpDef.
            for attr_def in complex_op_def.attr:
                attr_def.ClearField("default_value")

            # Loading the SavedModel via the loader must fail because the
            # SavedModel does not have any attr values for the "Complex" node
            # and the current op registry does not have any default values for
            # the "Complex" op.
            with session.Session(graph=ops.Graph()) as sess:
                with self.assertRaisesRegexp(
                        ValueError,
                        "Expected one attr with name .*T(out)?.* in name: \"complex\".*"
                ):
                    loader.load(sess, ["foo"], export_dir)

            # Update the Op registry to change the defaults for attr "Tout"
            # (complex64 -> complex128).
            complex_op_def.CopyFrom(original_complex_op_def)
            for attr_def in complex_op_def.attr:
                if attr_def.name == "Tout":
                    attr_def.default_value.type = types_pb2.DT_COMPLEX128

            # Loading the SavedModel via the loader must set "Tout" attr_value
            # for the "Complex" node according to the latest defaults
            # (complex128). This is expected to fail the model import as there
            # is no OpKernel registered to handle attrs "T" (float32) and
            # "Tout" (complex128).
            with session.Session(graph=ops.Graph()) as sess:
                with self.assertRaisesRegexp(
                        errors.InvalidArgumentError,
                        ".*No OpKernel was registered to support Op \'Complex\' with these "
                        "attrs..*"):
                    loader.load(sess, ["foo"], export_dir)
        finally:
            # Undo the registry edits even if an assertion above failed.
            complex_op_def.CopyFrom(original_complex_op_def)
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variable values are restored into a new session from the first available
    source, in this order: a SaverDef plus checkpoint, a MetaGraphDef plus
    checkpoint, a SavedModel directory, or a bare checkpoint read tensor by
    tensor. The graph is then rewritten with those values as constants.

    Args:
      input_graph_def: GraphDef to import and freeze. May be None when
        `input_meta_graph_def` or `input_saved_model_dir` supplies the graph.
      input_saver_def: Optional SaverDef used to restore `input_checkpoint`.
      input_checkpoint: Checkpoint path, or prefix for Saver V2. Its existence
        is not checked when `input_saved_model_dir` is given.
      output_node_names: Comma-separated names of the output nodes to keep.
      restore_op_name: Unused; kept for backward compatibility.
      filename_tensor_name: Unused; kept for backward compatibility.
      output_graph: If non-empty, path the frozen GraphDef is written to.
      clear_devices: Whether to clear device placements in the input protos.
      initializer_nodes: Comma-separated ops to run after restoring (only on
        the MetaGraphDef and bare-checkpoint paths).
      variable_names_whitelist: Comma-separated variable names to convert;
        empty passes None (no whitelist).
      variable_names_blacklist: Comma-separated variable names not to convert;
        empty passes None (no blacklist).
      input_meta_graph_def: Optional MetaGraphDef to import and freeze.
      input_saved_model_dir: Optional SavedModel directory to load from.
      saved_model_tags: Tags selecting the SavedModel MetaGraph; None -> [].
      checkpoint_version: SaverDef version for the Savers created here.

    Returns:
      The frozen GraphDef, or -1 if the checkpoint is missing or no output
      node names were supplied.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    def _split_names(names):
        """Splits a comma-separated name list, ignoring all spaces."""
        return names.replace(" ", "").split(",")

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not saver_lib.checkpoint_exists(input_checkpoint)):
        # %-formatting (not `+`) so non-str paths (e.g. bytes) don't raise here.
        print("Input checkpoint '%s' doesn't exist!" % (input_checkpoint,))
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            # Restore only checkpoint entries that have a matching tensor in
            # the graph.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))

        variable_names_whitelist = (_split_names(variable_names_whitelist)
                                    if variable_names_whitelist else None)
        variable_names_blacklist = (_split_names(variable_names_blacklist)
                                    if variable_names_blacklist else None)

        if input_meta_graph_def:
            graph_def_to_freeze = input_meta_graph_def.graph_def
        elif input_graph_def is not None:
            graph_def_to_freeze = input_graph_def
        else:
            # No graph proto was supplied (e.g. only a SavedModel directory):
            # freeze the graph that was loaded into this session instead of
            # handing None to the converter.
            graph_def_to_freeze = sess.graph.as_graph_def()

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            _split_names(output_node_names),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Exemple #50
0
 def load(self, tags, path):
     """Loads the SavedModel at `path` selected by `tags` into `self.session`.

     Raises:
       AssertionError: If `self.session` is None. The check is an explicit
         raise rather than an `assert` statement so it still runs under
         `python -O`; the exception type is kept for existing callers.
     """
     if self.session is None:
         raise AssertionError("load() requires an open session; "
                              "self.session is None.")
     loader.load(self.session, tags, path)
Exemple #51
0
def run_saved_model_with_feed_dict(saved_model_dir,
                                   tag_set,
                                   signature_def_key,
                                   input_tensor_key_feed_dict,
                                   outdir,
                                   overwrite_flag,
                                   worker=None,
                                   init_tpu=False,
                                   use_tfrt=False,
                                   tf_debug=False):
  """Runs a SavedModel and fetches all outputs.

  Runs the input dictionary through the MetaGraphDef within a SavedModel
  specified by the given tag_set and SignatureDef. Also saves the outputs to
  file if outdir is not None.

  Args:
    saved_model_dir: Directory containing the SavedModel to execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
        string format, separated by ','. For a tag-set containing multiple
        tags, all tags must be passed in.
    signature_def_key: A SignatureDef key string.
    input_tensor_key_feed_dict: A dictionary mapping input keys to numpy
        ndarrays.
    outdir: A directory to save the outputs to. If the directory doesn't exist,
        it will be created.
    overwrite_flag: A boolean flag to allow overwriting an output file if a
        file with the same name exists.
    worker: If provided, the session will be run on the worker.  Valid worker
        specification is a bns or gRPC path.
    init_tpu: If true, the TPU system will be initialized after the session
        is created.
    use_tfrt: If true, TFRT session will be used.
    tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
        intermediate Tensor values and runtime GraphDefs while running the
        SavedModel.

  Raises:
    ValueError: When any of the input tensor keys is not valid.
    RuntimeError: An error when output file already exists and overwrite is not
    enabled.
  """
  # Get a list of output tensor names.
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                        tag_set)

  # Re-create feed_dict based on input tensor name instead of key as session.run
  # uses tensor name.
  inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)

  # Check if input tensor keys are valid.
  for input_key_name in input_tensor_key_feed_dict.keys():
    if input_key_name not in inputs_tensor_info:
      raise ValueError(
          '"%s" is not a valid input key. Please choose from %s, or use '
          '--show option.' %
          (input_key_name, '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))

  inputs_feed_dict = {
      inputs_tensor_info[key].name: tensor
      for key, tensor in input_tensor_key_feed_dict.items()
  }
  # Get outputs
  outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)
  # Sort to preserve order because we need to go from value to key later.
  output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
  output_tensor_names_sorted = [
      outputs_tensor_info[tensor_key].name
      for tensor_key in output_tensor_keys_sorted
  ]

  config = None
  if use_tfrt:
    logging.info('Using TFRT session.')
    config = config_pb2.ConfigProto(
        experimental=config_pb2.ConfigProto.Experimental(use_tfrt=True))
  with session.Session(worker, graph=ops_lib.Graph(), config=config) as sess:
    if init_tpu:
      print('Initializing TPU System ...')
      # This is needed for freshly started worker, or if the job
      # restarts after a preemption.
      sess.run(tpu.initialize_system())

    loader.load(sess, tag_set.split(','), saved_model_dir)

    if tf_debug:
      sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

    outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)

    # Create the output directory once, before the loop, if anything will be
    # saved. exist_ok avoids the isdir-then-makedirs race with other writers.
    if outdir and outputs:
      os.makedirs(outdir, exist_ok=True)

    # `outputs` is in the same order as `output_tensor_keys_sorted`.
    for output_tensor_key, output in zip(output_tensor_keys_sorted, outputs):
      print('Result for output key %s:\n%s' % (output_tensor_key, output))

      # Only save if outdir is specified.
      if not outdir:
        continue
      output_full_path = os.path.join(outdir, output_tensor_key + '.npy')

      # If overwrite not enabled and file already exist, error out
      if not overwrite_flag and os.path.exists(output_full_path):
        raise RuntimeError(
            'Output file %s already exists. Add \"--overwrite\" to overwrite'
            ' the existing output files.' % output_full_path)

      np.save(output_full_path, output)
      print('Output %s is saved to %s' % (output_tensor_key,
                                          output_full_path))
Exemple #52
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variable values are restored from exactly one source, checked in this
    order: `input_saver_def`, `input_meta_graph_def`, `input_saved_model_dir`,
    and finally a plain checkpoint read with `NewCheckpointReader`.

    Args:
      input_graph_def: `GraphDef` imported into the default graph and frozen
        (unless `input_meta_graph_def` is given). May be falsy.
      input_saver_def: Optional `SaverDef` used to restore `input_checkpoint`.
      input_checkpoint: Checkpoint to restore variable values from.
      output_node_names: Comma-separated names of the output nodes to keep.
      restore_op_name: Unused; accepted for backward compatibility.
      filename_tensor_name: Unused; accepted for backward compatibility.
      output_graph: Optional path. When set, the frozen `GraphDef` is written
        there in serialized binary form.
      clear_devices: If true, clear the `device` field of every node of the
        input graph (the meta graph's if given, else `input_graph_def`'s).
      initializer_nodes: Comma-separated op names to run after restoring.
        Only used on the meta-graph and plain-checkpoint paths.
      variable_names_whitelist: Comma-separated variable names to convert.
        An empty string passes `None` to `convert_variables_to_constants`.
      variable_names_blacklist: Comma-separated variable names not to
        convert. An empty string passes `None`.
      input_meta_graph_def: Optional `MetaGraphDef` to import and freeze.
      input_saved_model_dir: Optional SavedModel directory to load from.
      saved_model_tags: Tags used to load `input_saved_model_dir`; defaults to
        an empty list.
      checkpoint_version: `SaverDef` write version passed to `Saver`.

    Returns:
      The frozen `GraphDef`, or -1 if `output_node_names` is empty or the
      plain checkpoint contains partitioned variables that `Saver` rejects.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for every node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        # NOTE: `initializer_nodes` is only run on the meta-graph and plain
        # checkpoint paths below, not for `input_saver_def` or SavedModels.
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                # Substring match against the heuristic list; once one hit is
                # found there is no need to keep scanning.
                if not has_partition_var and any(
                        key in name for name in all_partition_variable_names):
                    has_partition_var = True
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    print(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                    return -1
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        # Freeze the meta graph's GraphDef when one was supplied, otherwise the
        # plain `input_graph_def`.
        graph_def_to_freeze = (input_meta_graph_def.graph_def
                               if input_meta_graph_def else input_graph_def)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            output_node_names.replace(" ", "").split(","),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
    def testSignatureDefs(self):
        """Each saved meta graph keeps its own SignatureDef map.

        The "foo" meta graph is written together with the variable weights
        (v=42). The "bar" meta graph is added without weights and binds
        "foo_key" to a different SignatureDef. Loading either tag restores 42.
        """
        temp_root = compat.as_bytes(tf.test.get_temp_dir())
        export_dir = os.path.join(temp_root, compat.as_bytes("signature_defs"))
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # First graph: a single variable and one signature entry; the weights
        # are saved along with this meta graph.
        with self.test_session(graph=tf.Graph()) as sess:
            var = tf.Variable(42, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(42, var.eval())
            # An empty SignatureDef is enough to exercise the map handling.
            first_map = {
                "foo_key": utils.build_signature_def(dict(), dict(), "foo"),
            }
            builder.add_meta_graph_and_variables(
                sess, ["foo"], signature_def_map=first_map)

        # Second graph: same variable name, two signature entries, and no
        # weights saved. "foo_key" here points at a new SignatureDef.
        with self.test_session(graph=tf.Graph()) as sess:
            var = tf.Variable(43, name="v")
            sess.run(tf.initialize_all_variables())
            self.assertEqual(43, var.eval())
            second_map = {
                "bar_key": utils.build_signature_def(dict(), dict(), "bar"),
                "foo_key": utils.build_signature_def(dict(), dict(),
                                                     "foo_new"),
            }
            builder.add_meta_graph(["bar"], signature_def_map=second_map)

        builder.save()

        # (tag, expected method_name per signature key), checked in order.
        expectations = (
            ("foo", (("foo_key", "foo"),)),
            ("bar", (("bar_key", "bar"), ("foo_key", "foo_new"))),
        )
        for tag, expected_methods in expectations:
            with self.test_session(graph=tf.Graph()) as sess:
                meta_graph = loader.load(sess, [tag], export_dir)
                # Both tags share the weights saved with the first graph.
                self.assertEqual(
                    42,
                    tf.get_collection(tf.GraphKeys.VARIABLES)[0].eval())

                loaded_defs = meta_graph.signature_def
                self.assertEqual(len(loaded_defs), len(expected_methods))
                for sig_key, method in expected_methods:
                    self.assertEqual(method, loaded_defs[sig_key].method_name)
Exemple #54
0
def _run_model(iterator, args, tf_args):
    """mapPartitions function to run single-node inferencing from a checkpoint/saved_model.

    The model's input/output mappings (``args.input_mapping`` and
    ``args.output_mapping``) decide which tensors are fed and fetched. The
    loaded session is cached in the module globals ``global_sess`` /
    ``global_args`` and reused while ``args`` compares equal.

    Args:
      :iterator: input RDD partition iterator.
      :args: arguments for TFModel, in argparse format
      :tf_args: arguments for TensorFlow inferencing code, in argparse or ARGV format.

    Returns:
      A list of tuples, one per input row, each holding that row's output
      columns (one entry per fetched output tensor) as standard Python types.
    """
    single_node_env(tf_args)

    logging.info("===== input_mapping: {}".format(args.input_mapping))
    logging.info("===== output_mapping: {}".format(args.output_mapping))
    # input_mapping is unpacked as (column, tensor) and output_mapping as
    # (tensor, column); both are sorted so the feed/fetch order is stable.
    input_tensor_names = [
        tensor for col, tensor in sorted(args.input_mapping.items())
    ]
    output_tensor_names = [
        tensor for tensor, col in sorted(args.output_mapping.items())
    ]

    # if using a signature_def_key, get input/output tensor info from the requested signature
    if args.signature_def_key:
        assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
        logging.info(
            "===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}"
            .format(args.tag_set, args.export_dir))
        meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
        signature = signature_def_utils.get_signature_def_by_key(
            meta_graph_def, args.signature_def_key)
        logging.debug("signature: {}".format(signature))
        inputs_tensor_info = signature.inputs
        logging.debug("inputs_tensor_info: {0}".format(inputs_tensor_info))
        outputs_tensor_info = signature.outputs
        logging.debug("outputs_tensor_info: {0}".format(outputs_tensor_info))

    result = []

    global global_sess, global_args
    if global_sess and global_args == args:
        # if graph/session already loaded/started (and using same args), just reuse it
        sess = global_sess
    else:
        # Release the session cached for different args before replacing it;
        # otherwise it stays open for the life of the worker process. Clear
        # the cache first so a failed reload never leaves a closed session
        # behind for the next call to reuse.
        if global_sess:
            global_sess.close()
            global_sess = None

        # otherwise, create new session and load graph from disk
        tf.reset_default_graph()
        sess = tf.Session(graph=tf.get_default_graph())
        if args.export_dir:
            assert args.tag_set, "Inferencing from a saved_model requires --tag_set"
            # load graph from a saved_model
            logging.info("===== restoring from saved_model: {}".format(
                args.export_dir))
            loader.load(sess, args.tag_set.split(','), args.export_dir)
        elif args.model_dir:
            # load graph from a checkpoint
            ckpt = tf.train.latest_checkpoint(args.model_dir)
            assert ckpt, "Invalid model checkpoint path: {}".format(
                args.model_dir)
            logging.info("===== restoring from checkpoint: {}".format(ckpt +
                                                                      ".meta"))
            saver = tf.train.import_meta_graph(ckpt + ".meta",
                                               clear_devices=True)
            saver.restore(sess, ckpt)
        else:
            raise Exception(
                "Inferencing requires either --model_dir or --export_dir argument"
            )
        global_sess = sess
        global_args = args

    # get list of input/output tensors (by name)
    if args.signature_def_key:
        # Map every mapped input/output key through the signature, so the
        # signature path fetches the same number of outputs as the
        # name-based path below.
        input_tensors = [
            inputs_tensor_info[t].name for t in input_tensor_names
        ]
        output_tensors = [
            outputs_tensor_info[t].name for t in output_tensor_names
        ]
    else:
        input_tensors = [t + ':0' for t in input_tensor_names]
        output_tensors = [t + ':0' for t in output_tensor_names]

    logging.info("input_tensors: {0}".format(input_tensors))
    logging.info("output_tensors: {0}".format(output_tensors))

    # feed data in batches and return output tensors
    for tensors in yield_batch(iterator, args.batch_size,
                               len(input_tensor_names)):
        inputs_feed_dict = {
            name: tensors[i] for i, name in enumerate(input_tensors)
        }

        outputs = sess.run(output_tensors, feed_dict=inputs_feed_dict)
        lengths = [len(output) for output in outputs]
        input_size = len(tensors[0])
        assert all(length == input_size for length in lengths), \
            "Output array sizes {} must match input size: {}".format(
                lengths, input_size)
        # convert from numpy to standard python types
        python_outputs = [output.tolist() for output in outputs]
        # convert to an array of tuples of "output columns"
        result.extend(zip(*python_outputs))

    return result
    def __init__(self,
                 export_dir,
                 signature_def_key=None,
                 signature_def=None,
                 input_names=None,
                 output_names=None,
                 tags=None,
                 graph=None,
                 config=None):
        """Loads a `SavedModel` and resolves the tensors used for prediction.

        Args:
          export_dir: Path to a directory containing a `SavedModel`.
          signature_def_key: Optional name of the signature to predict with.
            Passed to `_get_signature_def` when neither `signature_def` nor
            explicit tensor names are given. Only one of `signature_def_key`
            and `signature_def` should be specified.
          signature_def: Optional `SignatureDef` proto specifying the
            prediction inputs and outputs.
          input_names: Optional dict mapping user-chosen keys to the names of
            input `Tensor`s in the `SavedModel`.
          output_names: Optional dict mapping user-chosen keys to the names of
            output `Tensor`s in the `SavedModel`.
          tags: Optional comma-separated string of tags selecting the meta
            graph to load. Defaults to `DEFAULT_TAGS`.
          graph: Optional graph to load into; a new `Graph` is created
            otherwise.
          config: Optional `ConfigProto` used to configure the session.

        Raises:
          ValueError: (raised by `_check_signature_arguments`) if more than
            one of `signature_def_key`, `signature_def`, or
            (`input_names`, `output_names`) is specified.
        """
        _check_signature_arguments(signature_def_key, signature_def,
                                   input_names, output_names)
        tag_string = tags or DEFAULT_TAGS
        self._graph = graph or ops.Graph()

        # The session is created while the target graph is the default one,
        # so the model is loaded into `self._graph`.
        with self._graph.as_default():
            self._session = session.Session(config=config)
            loader.load(self._session, tag_string.split(','), export_dir)

        if input_names is None:
            # No explicit tensor names: derive feeds and fetches from a
            # SignatureDef, fetching it from disk if none was passed in.
            chosen_signature = signature_def
            if chosen_signature is None:
                chosen_signature = _get_signature_def(signature_def_key,
                                                      export_dir, tag_string)
            input_names = {
                alias: info.name
                for alias, info in chosen_signature.inputs.items()
            }
            output_names = {
                alias: info.name
                for alias, info in chosen_signature.outputs.items()
            }

        def _resolve(name_map):
            # Turn {alias: tensor name} into {alias: Tensor} in self._graph.
            return {
                alias: self._graph.get_tensor_by_name(tensor_name)
                for alias, tensor_name in name_map.items()
            }

        self._feed_tensors = _resolve(input_names)
        self._fetch_tensors = _resolve(output_names)
 def test_one_shot_prediction_head_export(self):
     """Exports a OneShotPredictionHead model and runs its PREDICT signature.

     Trains an LSTM `TimeSeriesRegressor` with a numeric and an embedded
     categorical exogenous feature for a few steps, exports it as a
     SavedModel, reloads it under the SERVING tag, and checks the signature
     inputs and the shape of the "mean" output.
     """
     model_dir = self.get_temp_dir()
     hashed_category = feature_column.categorical_column_with_hash_bucket(
         key="categorical_exogenous_feature", hash_bucket_size=16)
     exogenous_columns = [
         feature_column.numeric_column("2d_exogenous_feature", shape=(2, )),
         feature_column.embedding_column(
             categorical_column=hashed_category, dimension=10)
     ]
     lstm_model = lstm_example._LSTMModel(
         num_features=5,
         num_units=128,
         exogenous_feature_columns=exogenous_columns)
     estimator = ts_estimators.TimeSeriesRegressor(
         model=lstm_model,
         optimizer=adam.AdamOptimizer(0.001),
         config=estimator_lib.RunConfig(tf_random_seed=4),
         state_manager=state_management.ChainingStateManager(),
         head_type=ts_head_lib.OneShotPredictionHead,
         model_dir=model_dir)

     # 20 training steps with 5 identical value series and constant
     # exogenous features.
     training_data = {
         feature_keys.TrainEvalFeatures.TIMES:
         numpy.arange(20, dtype=numpy.int64),
         feature_keys.TrainEvalFeatures.VALUES:
         numpy.tile(numpy.arange(20, dtype=numpy.float32)[:, None], [1, 5]),
         "2d_exogenous_feature":
         numpy.ones([20, 2]),
         "categorical_exogenous_feature":
         numpy.array(["strkey"] * 20)[:, None]
     }
     training_input_fn = input_pipeline.RandomWindowInputFn(
         input_pipeline.NumpyReader(training_data),
         shuffle_seed=2,
         num_threads=1,
         batch_size=16,
         window_size=16)
     estimator.train(input_fn=training_input_fn, steps=5)

     serving_input_fn = estimator.build_raw_serving_input_receiver_fn()
     export_location = estimator.export_savedmodel(self.get_temp_dir(),
                                                   serving_input_fn)

     graph = ops.Graph()
     with graph.as_default(), session_lib.Session() as sess:
         meta_graph = loader.load(sess, [tag_constants.SERVING],
                                  export_location)
         self.assertEqual([feature_keys.SavedModelLabels.PREDICT],
                          list(meta_graph.signature_def.keys()))
         predict_signature = meta_graph.signature_def[
             feature_keys.SavedModelLabels.PREDICT]
         expected_inputs = [
             feature_keys.FilteringFeatures.TIMES,
             feature_keys.FilteringFeatures.VALUES,
             "2d_exogenous_feature", "categorical_exogenous_feature"
         ]
         six.assertCountEqual(self, expected_inputs,
                              predict_signature.inputs.keys())

         # Batch of 2: 20 observed values and 35 time/exogenous entries.
         prediction_data = {
             feature_keys.TrainEvalFeatures.TIMES:
             numpy.tile(
                 numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]),
             feature_keys.TrainEvalFeatures.VALUES:
             numpy.tile(
                 numpy.arange(20, dtype=numpy.float32)[None, :, None],
                 [2, 1, 5]),
             "2d_exogenous_feature":
             numpy.ones([2, 35, 2]),
             "categorical_exogenous_feature":
             numpy.tile(
                 numpy.array(["strkey"] * 35)[None, :, None], [2, 1, 1])
         }

         feed_dict = {}
         for input_key, tensor_info in predict_signature.inputs.items():
             element = graph.as_graph_element(tensor_info.name)
             feed_dict[element] = prediction_data[input_key]
         fetches = {}
         for output_key, tensor_info in predict_signature.outputs.items():
             fetches[output_key] = graph.as_graph_element(tensor_info.name)

         predictions = sess.run(fetches, feed_dict=feed_dict)
         self.assertAllEqual((2, 15, 5), predictions["mean"].shape)
Exemple #57
0
# Name of the output tensor to fetch.
model_outputs = "map_1/TensorArrayStack/TensorArrayGatherV3:0"

# Run our session
with tf.Session() as sess:
    # Load the SavedModel into `sess` exactly once; the returned MetaGraphDef
    # also carries the signatures used below.
    metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],
                                           model_dir)

    # Signature inputs and outputs (informational only, not required to run
    # the script).
    inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
    outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
    # print(inputs_mapping)
    # print(outputs_mapping)

    # Prepare the image to pass into the TF session. This will need to be
    # improved to allow all images in a given directory to be read as bytes,
    # appended to a numpy array, and then passed into the session.
    with open("StaM_6805_nf_duo_tripod_20171112_21_58_33p.JPG", "rb") as image:
        image_bytes = image.read()

    # Run the inference on a 1-element batch holding the encoded image bytes.
    a = sess.run(model_outputs,
                 feed_dict={model_input: np.array(image_bytes).reshape(-1)})
    print(a)