def _get_var_list(model):
  """Returns list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list
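
A minimal usage sketch (not from the original source): the saveables returned by _get_var_list can be handed to a name-based tf.compat.v1.train.Saver. MyModel, the build call, and the checkpoint path are assumed for illustration.

import tensorflow as tf
tf.compat.v1.disable_eager_execution()   # Saver/Session are graph-mode APIs

model = MyModel()                        # assumed: the test's Keras model
model(tf.constant([[3.]]))               # build it so its variables exist
saver = tf.compat.v1.train.Saver(var_list=_get_var_list(model))
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saver.save(sess, "/tmp/ckpt/model")  # assumed path
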
  def testNamingWithOptimizer(self):
     input_value = tf.constant([[3.]])
     model = MyModel()
     # A nuisance Model using the same optimizer. Its slot variables should not
     # go in the checkpoint, since it is never depended on.
     other_model = MyModel()
     optimizer = tf.compat.v1.train.AdamOptimizer(0.001)
     optimizer_step = tf.compat.v1.train.get_or_create_global_step()
     root_trackable = tf.train.Checkpoint(optimizer=optimizer,
                                          model=model,
                                          optimizer_step=optimizer_step)
     if tf.executing_eagerly():
         optimizer.minimize(lambda: model(input_value),
                            global_step=optimizer_step)
         optimizer.minimize(lambda: other_model(input_value),
                            global_step=optimizer_step)
     else:
         train_op = optimizer.minimize(model(input_value),
                                       global_step=optimizer_step)
         optimizer.minimize(other_model(input_value),
                            global_step=optimizer_step)
         self.evaluate(trackable_utils.gather_initializers(root_trackable))
         self.evaluate(train_op)
     named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
         root_trackable).serialize_object_graph()
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
         "model/_second/kernel",
         "model/_named_dense/kernel",
         "model/_named_dense/bias",
         # non-Layer dependency of the model
         "model/_non_layer/a_variable",
         # The optimizer creates two non-slot variables
         "optimizer/beta1_power",
         "optimizer/beta2_power",
         # Slot variables
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
     )
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names
     ]
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual("global_step",
                      named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
         "my_model/dense_1/kernel",
         named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
         "beta1_power",
         named_variables["optimizer/beta1_power" + suffix].full_name)
     self.assertEqual(
         "beta2_power",
         named_variables["optimizer/beta2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
     optimizer_node = serialized_graph.nodes[
         serialized_graph.nodes[0].children[1].node_id]
     self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
     self.assertEqual(
         "beta1_power", serialized_graph.nodes[
             optimizer_node.children[0].node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].full_name)
     # We strip off the :0 suffix, as variable.name-based saving does.
     self.assertEqual(
         "my_model/dense/kernel/Adam",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].checkpoint_key)
     self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
     self.assertEqual(
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].checkpoint_key)
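
The checkpoint keys asserted above can also be inspected on disk. A minimal sketch, not part of the original test; the save path is assumed:

import tensorflow as tf

model = MyModel()
model(tf.constant([[3.]]))            # build the layers so variables exist
root = tf.train.Checkpoint(model=model)
prefix = root.save("/tmp/ckpt/root")  # assumed path
# Keys are object-graph paths plus the "/.ATTRIBUTES/VARIABLE_VALUE" suffix.
for name, shape in tf.train.list_variables(prefix):
    print(name, shape)
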
Example #3
    def testAddVariable(self):
        obj = NonLayerTrackable()
        with self.assertRaisesRegex(ValueError, "do not specify shape"):
            trackable_utils.add_variable(obj,
                                         name="shape_specified_twice",
                                         shape=[],
                                         initializer=1)
        constant_initializer = trackable_utils.add_variable(
            obj, name="constant_initializer", initializer=1)
        with variable_scope.variable_scope("some_variable_scope"):
            ones_initializer = trackable_utils.add_variable(
                obj,
                name="ones_initializer",
                shape=[2],
                initializer=init_ops.ones_initializer(dtype=dtypes.float32))
        bare_initializer = trackable_utils.add_variable(
            obj,
            name="bare_initializer",
            shape=[2, 2],
            dtype=dtypes.float64,
            initializer=init_ops.zeros_initializer)

        # Even in graph mode, there are no naming conflicts between objects, only
        # naming conflicts within an object.
        other_duplicate = resource_variable_ops.ResourceVariable(
            name="duplicate", initial_value=1.)
        duplicate = trackable_utils.add_variable(obj,
                                                 name="duplicate",
                                                 shape=[])
        with self.assertRaisesRegex(ValueError,
                                    "'duplicate'.*already declared"):
            trackable_utils.add_variable(obj, name="duplicate", shape=[])

        self.evaluate(trackable_utils.gather_initializers(obj))
        self.assertEqual("constant_initializer:0", constant_initializer.name)
        self.assertEqual(1, self.evaluate(constant_initializer))
        self.assertEqual("some_variable_scope/ones_initializer:0",
                         ones_initializer.name)
        self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
        self.assertAllEqual([[0., 0.], [0., 0.]],
                            self.evaluate(bare_initializer))
        self.assertEqual("a_variable:0", obj.a_variable.name)
        self.assertEqual("duplicate:0", other_duplicate.name)
        if context.executing_eagerly():
            # When executing eagerly, there's no uniquification of variable names. The
            # checkpoint name will be the same.
            self.assertEqual("duplicate:0", duplicate.name)
        else:
            # The .name attribute may be globally influenced, but the checkpoint name
            # won't be (tested below).
            self.assertEqual("duplicate_1:0", duplicate.name)
        named_variables, _, _ = (
            graph_view.ObjectGraphView(obj).serialize_object_graph())
        expected_checkpoint_names = (
            "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
            "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
            "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
            "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
            "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        )
        six.assertCountEqual(self, expected_checkpoint_names,
                             [v.name for v in named_variables])
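
The duplicate/duplicate_1 assertions make a general point: object-based checkpoint keys come from the attribute path in the object graph, not from the (possibly uniquified) graph variable name. A minimal sketch of the same behavior using only public APIs; the class name and path are illustrative:

import tensorflow as tf

class Holder(tf.Module):
    def __init__(self):
        super().__init__()
        self.duplicate = tf.Variable(1., name="duplicate")

other = tf.Variable(2., name="duplicate")  # reuses the graph name elsewhere
ckpt = tf.train.Checkpoint(obj=Holder())
prefix = ckpt.save("/tmp/ckpt/dup")        # assumed path
# Expect "obj/duplicate/.ATTRIBUTES/VARIABLE_VALUE" regardless of any
# "duplicate_1:0" uniquification of the in-graph name.
print([name for name, _ in tf.train.list_variables(prefix)])
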
Example #4
    def testNamingWithOptimizer(self):
        input_value = constant_op.constant([[3.]])
        model = MyModel()
        # A nuisance Model using the same optimizer. Its slot variables should not
        # go in the checkpoint, since it is never depended on.
        other_model = MyModel()
        optimizer = adam.Adam(0.001)
        step = training_util.get_or_create_global_step()
        root_trackable = trackable_utils.Checkpoint(optimizer=optimizer,
                                                    model=model,
                                                    step=step)

        with backprop.GradientTape() as tape:
            loss = model(input_value)
        variables = model.trainable_variables
        gradients = tape.gradient(loss, variables)
        train_op = control_flow_ops.group(
            optimizer.apply_gradients(zip(gradients, variables)),
            step.assign_add(1))

        with backprop.GradientTape() as tape:
            loss = other_model(input_value)
        variables = other_model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        self.evaluate(trackable_utils.gather_initializers(root_trackable))
        self.evaluate(train_op)
        named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
            root_trackable).serialize_object_graph()
        expected_slot_keys = (
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
            "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
        )
        expected_checkpoint_names = (
            # Created in the root node, so no prefix.
            "step",
            "model/_second/kernel",
            "model/_named_dense/kernel",
            "model/_named_dense/bias",
            # non-Layer dependency of the model
            "model/_non_layer/a_variable",
            "optimizer/learning_rate",
            "optimizer/beta_1",
            "optimizer/beta_2",
            "optimizer/iter",
            "optimizer/decay",
        ) + expected_slot_keys
        suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
        expected_checkpoint_names = [
            name + suffix for name in expected_checkpoint_names
        ]
        named_variables = {v.name: v for v in named_variables}
        six.assertCountEqual(self, expected_checkpoint_names,
                             named_variables.keys())
        # Check that we've mapped to the right variable objects (not exhaustive)
        self.assertEqual("global_step",
                         named_variables["step" + suffix].full_name)
        self.assertEqual(
            "my_model/dense_1/kernel",
            named_variables["model/_second/kernel" + suffix].full_name)
        self.assertEqual(
            "my_model/dense/kernel",
            named_variables["model/_named_dense/kernel" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_1",
            named_variables["optimizer/beta_1" + suffix].full_name)
        self.assertEqual(
            "Adam/beta_2",
            named_variables["optimizer/beta_2" + suffix].full_name)
        # Spot check the generated protocol buffers.
        self.assertEqual("optimizer",
                         serialized_graph.nodes[0].children[1].local_name)
        optimizer_node = serialized_graph.nodes[
            serialized_graph.nodes[0].children[1].node_id]
        children = [node.local_name for node in optimizer_node.children]
        six.assertCountEqual(
            self,
            # hyper variable dependencies
            ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
            children)
        serialized_slot_keys = []
        for slot in optimizer_node.slot_variables:
            for attribute in (serialized_graph.nodes[
                    slot.slot_variable_node_id].attributes):
                serialized_slot_keys.append(attribute.checkpoint_key)
        six.assertCountEqual(self,
                             [key + suffix for key in expected_slot_keys],
                             serialized_slot_keys)
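
For reference, the slots asserted above are reachable through the optimizer's public API once gradients have been applied. A minimal sketch, assuming the optimizer_v2-era tf.keras.optimizers.Adam this test exercises (newer Keras optimizers dropped get_slot):

import tensorflow as tf

dense = tf.keras.layers.Dense(1)
opt = tf.keras.optimizers.Adam(0.001)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(dense(tf.constant([[3.]])))
grads = tape.gradient(loss, dense.trainable_variables)
opt.apply_gradients(zip(grads, dense.trainable_variables))
# "m" and "v" are Adam's first- and second-moment slots for each variable.
print(opt.get_slot(dense.kernel, "m").name)
print(opt.get_slot(dense.kernel, "v").name)
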
Example #5
  def model_fn(features, labels, mode):
    """model_fn for keras Estimator."""
    model = _clone_and_build_model(
        mode=mode,
        keras_model=keras_model,
        custom_objects=custom_objects,
        features=features,
        labels=labels,
        optimizer_config=optimizer_config)
    model_output_names = []
    # We need to make sure that the output names of the last layer in the model
    # are the same for each of the cloned models. This is required for mirrored
    # strategy when we call regroup.
    if distribution_strategy_context.has_strategy():
      for name in model.output_names:
        name = re.compile(r'_\d$').sub('', name)
        model_output_names.append(name)
    else:
      model_output_names = model.output_names

    # Get inputs to EstimatorSpec
    predictions = dict(zip(model_output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Set loss and metric only during train and evaluate.
    if mode is not ModeKeys.PREDICT:
      if mode is ModeKeys.TRAIN:
        model._make_train_function()  # pylint: disable=protected-access
      else:
        model._make_test_function()  # pylint: disable=protected-access
      loss = model.total_loss

      eval_metric_ops = _convert_keras_metrics_to_estimator(model)

    # Set train_op only during train.
    if mode is ModeKeys.TRAIN:
      train_op = model.train_function.updates_op

    if (not model._is_graph_network and
        hasattr(keras_model, '_original_attributes_cache') and
        keras_model._original_attributes_cache is not None):
      # To avoid `model_fn` being destructive for the initial model argument.
      models.in_place_subclassed_model_state_restoration(keras_model)

    scaffold = None
    if save_object_ckpt:
      model._track_trackable(training_util.get_global_step(),
                             'estimator_global_step')
      # Create saver that maps variable names to object-checkpoint keys.
      object_graph = graph_view.ObjectGraphView(model)
      var_list = object_graph.frozen_saveable_objects()
      saver = saver_lib.Saver(var_list=var_list, sharded=True)
      saver._object_restore_saver = trackable_util.frozen_saver(model)
      scaffold = monitored_session.Scaffold(saver=saver)

    return model_fn_lib.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            _DEFAULT_SERVING_KEY:
            export_lib.PredictOutput(predictions)
        },
        scaffold=scaffold
    )
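
A minimal sketch of how a model_fn like this is consumed, assuming the TF1-style tf.estimator API this code targets; input_fn and model_dir are illustrative:

import tensorflow as tf

def input_fn():
    # Shapes/dtypes must match the wrapped keras_model's inputs and outputs.
    features = tf.constant([[1.0]])
    labels = tf.constant([[2.0]])
    return features, labels

estimator = tf.estimator.Estimator(model_fn=model_fn,
                                   model_dir="/tmp/est")  # assumed dir
estimator.train(input_fn=input_fn, steps=10)
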