示例#1
0
    def testExample(self):
        """Saves a checkpoint and checks the dependency names it records.

        Three variables are tracked under the names "x", "y" and "x" again; the
        final assertion expects the repeated request to appear as "x_1".
        """

        class SlotManager(tracking.AutoTrackable):
            """Holds a `UniqueNameTracker` and a `NoDependency`-wrapped list."""

            def __init__(self):
                self.slotdeps = containers.UniqueNameTracker()
                tracker = self.slotdeps
                # (initial value, requested name) in registration order.
                requested = ((3., "x"), (4., "y"), (5., "x"))
                tracked = [
                    tracker.track(
                        resource_variable_ops.ResourceVariable(initial), name)
                    for initial, name in requested
                ]
                # The expected names asserted below contain no "slots" entry.
                self.slots = data_structures.NoDependency(tracked)

        manager = SlotManager()
        self.evaluate([variable.initializer for variable in manager.slots])
        checkpoint = util.Checkpoint(slot_manager=manager)
        prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = checkpoint.save(prefix)

        # Flatten every child reference of every saved node into one list.
        saved_graph = util.object_metadata(save_path)
        dependency_names = [
            reference.local_name
            for saved_node in saved_graph.nodes
            for reference in saved_node.children
        ]
        expected_names = [
            "x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]
        six.assertCountEqual(self, dependency_names, expected_names)
示例#2
0
    def testObjectMetadata(self):
        """Checks that saved object metadata lists the dense layer's kernel.

        The checkpoint is written inside `context.eager_mode()`; the metadata
        is read back outside that scope.
        """
        with context.eager_mode():
            prefix = os.path.join(self.get_temp_dir(), "ckpt")
            layer = core.Dense(1)
            checkpoint = trackable_utils.Checkpoint(dense=layer)
            # The layer is called once before saving; the assertion below
            # expects a "dense/kernel" attribute in the saved metadata.
            layer(constant_op.constant([[1.]]))
            save_path = checkpoint.save(prefix)

        saved_graph = trackable_utils.object_metadata(save_path)
        variable_names = [
            saved_attribute.full_name
            for saved_node in saved_graph.nodes
            for saved_attribute in saved_node.attributes
        ]
        self.assertIn("dense/kernel", variable_names)
示例#3
0
    def testTrainSpinn(self):
        """Trains SPINN on fake SNLI/GloVe data and inspects its outputs.

        Checks that one "train/loss" summary exists per configured epoch and
        that the latest checkpoint holds "global_step" plus every trainer
        variable.
        """
        # Step 1: write the fake SNLI file, then load vocabulary and vectors.
        snli_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = self._create_test_data(snli_dir)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        # The same fake file backs the train, dev and test splits.
        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)

        # Step 2: build a fake config that logs under the temp data dir.
        logdir = os.path.join(self._temp_data_dir, "logdir")
        config = _test_spinn_config(data.WORD_VECTOR_LEN, 4, logdir=logdir)

        # Step 3: run training.
        trainer = spinn.train_or_infer_spinn(
            embed, word2index, train_data, dev_data, test_data, config)

        # Step 4: read the first events file found and collect the
        # "train/loss" values; exactly one per epoch is expected.
        events_pattern = os.path.join(config.logdir, "events.out.*")
        summary_file = glob.glob(events_pattern)[0]
        train_losses = []
        for event in summary_test_util.events_from_file(summary_file):
            values = event.summary.value
            if values and values[0].tag == "train/loss":
                train_losses.append(values[0].simple_value)
        self.assertEqual(config.epochs, len(train_losses))

        # Step 5: a checkpoint must exist and contain the expected variables.
        self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
        latest = checkpoint_management.latest_checkpoint(config.logdir)
        saved_graph = trackable_utils.object_metadata(latest)
        ckpt_variable_names = {
            saved_attribute.full_name
            for saved_node in saved_graph.nodes
            for saved_attribute in saved_node.attributes
        }
        self.assertIn("global_step", ckpt_variable_names)
        for variable in trainer.variables:
            # Keep only the text before the first ":" (whole name if none).
            base_name = variable.name.partition(":")[0]
            self.assertIn(base_name, ckpt_variable_names)
示例#4
0
  def testObjectMetadata(self):
    """Checks that saved object metadata lists the dense layer's kernel.

    Skipped unless `tf.executing_eagerly()` is true.
    """
    if not tf.executing_eagerly():
      self.skipTest("Run in eager mode only.")

    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    layer = core.Dense(1)
    checkpoint = tf.train.Checkpoint(dense=layer)
    # The layer is called once before saving; the assertion below expects a
    # "dense/kernel" attribute in the saved metadata.
    layer(tf.constant([[1.]]))
    save_path = checkpoint.save(prefix)

    saved_graph = trackable_utils.object_metadata(save_path)
    variable_names = [
        saved_attribute.full_name
        for saved_node in saved_graph.nodes
        for saved_attribute in saved_node.attributes
    ]
    self.assertIn("dense/kernel", variable_names)
示例#5
0
  def testTrainSpinn(self):
    """Trains SPINN on fake SNLI/GloVe data and inspects its outputs.

    Checks that one "train/loss" summary exists per configured epoch and that
    the latest checkpoint holds "global_step" plus every trainer variable.
    """
    temp_dir = self._temp_data_dir

    # Step 1: write the fake SNLI file, then load vocabulary and vectors.
    fake_train_file = self._create_test_data(
        os.path.join(temp_dir, "snli/snli_1.0"))
    vocab = data.load_vocabulary(temp_dir)
    word2index, embed = data.load_word_vectors(temp_dir, vocab)

    # The same fake file backs the train, dev and test splits.
    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # Step 2: build a fake config that logs under the temp data dir.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4, logdir=os.path.join(temp_dir, "logdir"))

    # Step 3: run training.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # Step 4: read the first events file found and collect the "train/loss"
    # values; exactly one per epoch is expected.
    event_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
    summary_file = event_files[0]
    train_losses = []
    for event in summary_test_util.events_from_file(summary_file):
      values = event.summary.value
      if not values:
        continue
      if values[0].tag == "train/loss":
        train_losses.append(values[0].simple_value)
    self.assertEqual(config.epochs, len(train_losses))

    # Step 5: a checkpoint must exist and contain the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    latest = checkpoint_management.latest_checkpoint(config.logdir)
    saved_graph = trackable_utils.object_metadata(latest)
    ckpt_variable_names = {
        saved_attribute.full_name
        for saved_node in saved_graph.nodes
        for saved_attribute in saved_node.attributes
    }
    self.assertIn("global_step", ckpt_variable_names)
    for variable in trainer.variables:
      # Keep only the text before the first ":" (whole name if none).
      base_name = variable.name.partition(":")[0]
      self.assertIn(base_name, ckpt_variable_names)
示例#6
0
def dot_graph_from_checkpoint(save_path):
  r"""Renders an object-based checkpoint (from `tf.train.Checkpoint`) as DOT.

  Intended for inspecting checkpoints and debugging loading problems. Each
  saved object becomes a point node; child references become labelled edges
  and slot-variable references become dotted edges.

  Python usage (requires pydot):
  ```python
  import tensorflow as tf
  import pydot

  dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
  parsed, = pydot.graph_from_dot_data(dot_string)
  parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
  ```

  Command line usage:
  ```sh
  python -c "import tensorflow as tf;\
    print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
    | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
  ```

  Args:
    save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
      or `tf.train.latest_checkpoint`.
  Returns:
    A string containing the graph in DOT format.
  """
  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
  object_graph = trackable_utils.object_metadata(save_path)
  shape_map = reader.get_variable_to_shape_map()
  dtype_map = reader.get_variable_to_dtype_map()

  def _escape(name):
    # Only double quotes are escaped, so names fit inside "..." attributes.
    return name.replace('"', '\\"')

  # Node ids that some other node refers to as a slot variable.
  slot_ids = {
      slot_ref.slot_variable_node_id
      for saved_node in object_graph.nodes
      for slot_ref in saved_node.slot_variables
  }

  # DOT fragments are collected here and joined once at the end.
  pieces = ['digraph {\n']
  for node_id, node in enumerate(object_graph.nodes):
    is_variable = (
        len(node.attributes) == 1
        and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY)
    if is_variable:
      if node_id in slot_ids:
        color, tooltip_prefix = 'orange', 'Slot variable'
      else:
        color, tooltip_prefix = 'blue', 'Variable'
      attribute = node.attributes[0]
      key = attribute.checkpoint_key
      pieces.append(
          ('N_%d [shape=point label="" color=%s width=.25'
           ' tooltip="%s %s shape=%s %s"]\n') % (
               node_id,
               color,
               tooltip_prefix,
               _escape(attribute.full_name),
               shape_map[key],
               dtype_map[key].name))
    elif node.slot_variables:
      # A non-variable node that owns slot variables is drawn as an optimizer.
      pieces.append(
          ('N_%d [shape=point label="" width=.25 color=red,'
           'tooltip="Optimizer"]\n') % node_id)
    else:
      pieces.append('N_%d [shape=point label="" width=.25]\n' % node_id)

    for child in node.children:
      pieces.append('N_%d -> N_%d [label="%s"]\n' % (
          node_id, child.node_id, _escape(child.local_name)))

    for slot_ref in node.slot_variables:
      slot_node_id = slot_ref.slot_variable_node_id
      # Owner -> slot variable, labelled with the slot name.
      pieces.append('N_%d -> N_%d [label="%s" style=dotted]\n' % (
          node_id, slot_node_id, _escape(slot_ref.slot_name)))
      # Original variable -> its slot variable.
      pieces.append('N_%d -> N_%d [style=dotted]\n' % (
          slot_ref.original_variable_node_id, slot_node_id))

  pieces.append('}\n')
  return ''.join(pieces)
def dot_graph_from_checkpoint(save_path):
    r"""Renders an object-based checkpoint (from `tf.train.Checkpoint`) as DOT.

    Intended for inspecting checkpoints and debugging loading problems. Each
    saved object becomes a point node; child references become labelled edges
    and slot-variable references become dotted edges.

    Python usage (requires pydot):
    ```python
    import tensorflow as tf
    import pydot

    dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
    parsed, = pydot.graph_from_dot_data(dot_string)
    parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
    ```

    Command line usage:
    ```sh
    python -c "import tensorflow as tf;\
      print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
      | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
    ```

    Args:
      save_path: The checkpoint prefix, as returned by
        `tf.train.Checkpoint.save` or `tf.train.latest_checkpoint`.
    Returns:
      A string containing the graph in DOT format.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    object_graph = trackable_utils.object_metadata(save_path)
    shape_map = reader.get_variable_to_shape_map()
    dtype_map = reader.get_variable_to_dtype_map()

    def _escape(name):
        # Only double quotes are escaped, so names fit inside "..." attributes.
        return name.replace('"', '\\"')

    # Node ids that some other node refers to as a slot variable.
    slot_ids = set()
    for saved_node in object_graph.nodes:
        slot_ids.update(
            slot_ref.slot_variable_node_id
            for slot_ref in saved_node.slot_variables)

    def _node_statement(node_id, node):
        # Returns the DOT statement that declares one node.
        attributes = node.attributes
        if (len(attributes) == 1
                and attributes[0].name == trackable.VARIABLE_VALUE_KEY):
            is_slot = node_id in slot_ids
            color = 'orange' if is_slot else 'blue'
            tooltip_prefix = 'Slot variable' if is_slot else 'Variable'
            attribute = attributes[0]
            return ('N_%d [shape=point label="" color=%s width=.25'
                    ' tooltip="%s %s shape=%s %s"]\n') % (
                        node_id, color, tooltip_prefix,
                        _escape(attribute.full_name),
                        shape_map[attribute.checkpoint_key],
                        dtype_map[attribute.checkpoint_key].name)
        if node.slot_variables:
            # A non-variable node that owns slot variables is drawn as an
            # optimizer.
            return ('N_%d [shape=point label="" width=.25 color=red,'
                    'tooltip="Optimizer"]\n') % node_id
        return 'N_%d [shape=point label="" width=.25]\n' % node_id

    def _edge_statements(node_id, node):
        # Returns the DOT statements for all edges emitted for one node.
        edges = [
            'N_%d -> N_%d [label="%s"]\n' % (
                node_id, child.node_id, _escape(child.local_name))
            for child in node.children
        ]
        for slot_ref in node.slot_variables:
            # Owner -> slot variable, labelled with the slot name.
            edges.append('N_%d -> N_%d [label="%s" style=dotted]\n' % (
                node_id, slot_ref.slot_variable_node_id,
                _escape(slot_ref.slot_name)))
            # Original variable -> its slot variable.
            edges.append('N_%d -> N_%d [style=dotted]\n' % (
                slot_ref.original_variable_node_id,
                slot_ref.slot_variable_node_id))
        return edges

    # DOT fragments are collected here and joined once at the end.
    pieces = ['digraph {\n']
    for node_id, node in enumerate(object_graph.nodes):
        pieces.append(_node_statement(node_id, node))
        pieces.extend(_edge_statements(node_id, node))
    pieces.append('}\n')
    return ''.join(pieces)