def testExample(self):
  """Checks that UniqueNameTracker de-duplicates repeated names.

  Three variables are tracked under the requested names "x", "y" and "x".
  After saving a checkpoint, the recorded dependency names must contain the
  second "x" as "x_1", together with the other expected edges.
  """

  class SlotManager(tracking.AutoTrackable):

    def __init__(self):
      self.slotdeps = containers.UniqueNameTracker()
      tracker = self.slotdeps
      # (initial value, requested name) pairs, tracked in this order. "x" is
      # requested twice on purpose so the tracker has to uniquify it.
      requested = [(3., "x"), (4., "y"), (5., "x")]
      tracked = [
          tracker.track(resource_variable_ops.ResourceVariable(value), name)
          for value, name in requested
      ]
      # Wrapped in NoDependency. Note that "slots" is absent from the
      # expected dependency names asserted below.
      self.slots = data_structures.NoDependency(tracked)

  manager = SlotManager()
  self.evaluate([slot.initializer for slot in manager.slots])
  checkpoint = util.Checkpoint(slot_manager=manager)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  metadata = util.object_metadata(checkpoint.save(prefix))
  # Every child edge name in the saved object graph.
  dependency_names = [
      child.local_name for node in metadata.nodes for child in node.children
  ]
  six.assertCountEqual(
      self, dependency_names,
      ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
def testObjectMetadata(self):
  """Checks that object_metadata reports a Dense kernel by its full name."""
  with context.eager_mode():
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    layer = core.Dense(1)
    ckpt = trackable_utils.Checkpoint(dense=layer)
    # Call the layer once so its variables (including the kernel checked
    # below) exist before saving.
    layer(constant_op.constant([[1.]]))
    graph_proto = trackable_utils.object_metadata(ckpt.save(prefix))
    full_names = [
        attribute.full_name
        for obj in graph_proto.nodes
        for attribute in obj.attributes
    ]
    self.assertIn("dense/kernel", full_names)
def testTrainSpinn(self):
  """Test with fake toy SNLI data and GloVe vectors."""
  # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = self._create_test_data(snli_1_0_dir)

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

  # The same fake file backs the train, dev and test splits.
  train_data = data.SnliData(fake_train_file, word2index)
  dev_data = data.SnliData(fake_train_file, word2index)
  test_data = data.SnliData(fake_train_file, word2index)

  # 2. Create a fake config.
  config = _test_spinn_config(
      data.WORD_VECTOR_LEN, 4,
      logdir=os.path.join(self._temp_data_dir, "logdir"))

  # 3. Test training of a SPINN model.
  trainer = spinn.train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)

  # 4. Load train loss values from the summary file and verify that one
  # "train/loss" value was recorded per epoch. The loss trend itself is not
  # checked.
  summary_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
  # Fail with a descriptive message rather than an IndexError when training
  # wrote no events file.
  self.assertTrue(summary_files,
                  "No summary events file found in %s" % config.logdir)
  events = summary_test_util.events_from_file(summary_files[0])
  train_losses = [
      event.summary.value[0].simple_value for event in events
      if event.summary.value and event.summary.value[0].tag == "train/loss"
  ]
  self.assertEqual(config.epochs, len(train_losses))

  # 5. Verify that checkpoints exist and contain all the expected variables.
  self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
  object_graph = trackable_utils.object_metadata(
      checkpoint_management.latest_checkpoint(config.logdir))
  ckpt_variable_names = {
      attribute.full_name
      for node in object_graph.nodes
      for attribute in node.attributes
  }
  self.assertIn("global_step", ckpt_variable_names)
  for v in trainer.variables:
    # Drop everything from the first ":" onward (e.g. a ":0" suffix). The
    # whole name is kept when no ":" is present.
    variable_name = v.name.partition(":")[0]
    self.assertIn(variable_name, ckpt_variable_names)
def testObjectMetadata(self):
  """Saves a built Dense layer and looks up its kernel in the object graph.

  Skipped unless eager execution is active.
  """
  if not tf.executing_eagerly():
    self.skipTest("Run in eager mode only.")
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  layer = core.Dense(1)
  saver = tf.train.Checkpoint(dense=layer)
  # Call the layer once so the kernel asserted below exists before saving.
  layer(tf.constant([[1.]]))
  metadata = trackable_utils.object_metadata(saver.save(prefix))
  recorded_names = [
      attr.full_name for node in metadata.nodes for attr in node.attributes
  ]
  self.assertIn("dense/kernel", recorded_names)
def testTrainSpinn(self):
  """Test with fake toy SNLI data and GloVe vectors."""
  # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
  snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
  fake_train_file = self._create_test_data(snli_1_0_dir)

  vocab = data.load_vocabulary(self._temp_data_dir)
  word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

  # The same fake file backs the train, dev and test splits.
  train_data = data.SnliData(fake_train_file, word2index)
  dev_data = data.SnliData(fake_train_file, word2index)
  test_data = data.SnliData(fake_train_file, word2index)

  # 2. Create a fake config.
  config = _test_spinn_config(
      data.WORD_VECTOR_LEN, 4,
      logdir=os.path.join(self._temp_data_dir, "logdir"))

  # 3. Test training of a SPINN model.
  trainer = spinn.train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)

  # 4. Load train loss values from the summary file and verify that one
  # "train/loss" value was recorded per epoch. The loss trend itself is not
  # checked.
  summary_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
  # Fail with a descriptive message rather than an IndexError when training
  # wrote no events file.
  self.assertTrue(summary_files,
                  "No summary events file found in %s" % config.logdir)
  events = summary_test_util.events_from_file(summary_files[0])
  train_losses = [
      event.summary.value[0].simple_value for event in events
      if event.summary.value and event.summary.value[0].tag == "train/loss"
  ]
  self.assertEqual(config.epochs, len(train_losses))

  # 5. Verify that checkpoints exist and contain all the expected variables.
  self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
  object_graph = trackable_utils.object_metadata(
      checkpoint_management.latest_checkpoint(config.logdir))
  ckpt_variable_names = {
      attribute.full_name
      for node in object_graph.nodes
      for attribute in node.attributes
  }
  self.assertIn("global_step", ckpt_variable_names)
  for v in trainer.variables:
    # Drop everything from the first ":" onward (e.g. a ":0" suffix). The
    # whole name is kept when no ":" is present.
    variable_name = v.name.partition(":")[0]
    self.assertIn(variable_name, ckpt_variable_names)
def dot_graph_from_checkpoint(save_path):
  r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`).

  Useful for inspecting checkpoints and debugging loading issues.

  Example usage from Python (requires pydot):
  ```python
  import tensorflow as tf
  import pydot

  dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
  parsed, = pydot.graph_from_dot_data(dot_string)
  parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
  ```

  Example command line usage:
  ```sh
  python -c "import tensorflow as tf;\
    print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
    | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
  ```

  Each object-graph node is emitted as a point named `N_<node id>`:
    * blue: a variable (its only attribute is named
      `trackable.VARIABLE_VALUE_KEY`), with a tooltip holding its full name,
      shape and dtype;
    * orange: such a variable that some node lists as a slot variable;
    * red: any other node that owns slot variables (tooltip "Optimizer");
    * uncolored: every other node.
  Solid labeled edges are child dependencies. Dotted edges go from a node to
  each of its slot variables (labeled with the slot name) and from each
  original variable to its slot variable.

  Args:
    save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
      or `tf.train.latest_checkpoint`.
  Returns:
    A graph in DOT format as a string.
  """
  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
  object_graph = trackable_utils.object_metadata(save_path)
  shape_map = reader.get_variable_to_shape_map()
  dtype_map = reader.get_variable_to_dtype_map()

  def _escape(name):
    # Escapes double quotes so the name can sit inside a DOT "..." string.
    # NOTE(review): a name ending in a backslash would still escape the
    # closing quote. Checkpoint names are presumably free of backslashes;
    # confirm this.
    return name.replace('"', '\\"')

  # Ids of every node that some other node references as a slot variable.
  slot_ids = {
      slot_reference.slot_variable_node_id
      for node in object_graph.nodes
      for slot_reference in node.slot_variables
  }

  # Collect fragments and join once at the end, instead of repeatedly
  # concatenating onto one growing string.
  lines = ['digraph {\n']
  for node_id, node in enumerate(object_graph.nodes):
    if (len(node.attributes) == 1 and
        node.attributes[0].name == trackable.VARIABLE_VALUE_KEY):
      if node_id in slot_ids:
        color = 'orange'
        tooltip_prefix = 'Slot variable'
      else:
        color = 'blue'
        tooltip_prefix = 'Variable'
      attribute = node.attributes[0]
      lines.append(('N_%d [shape=point label="" color=%s width=.25'
                    ' tooltip="%s %s shape=%s %s"]\n') % (
                        node_id, color, tooltip_prefix,
                        _escape(attribute.full_name),
                        shape_map[attribute.checkpoint_key],
                        dtype_map[attribute.checkpoint_key].name))
    elif node.slot_variables:
      lines.append(('N_%d [shape=point label="" width=.25 color=red,'
                    'tooltip="Optimizer"]\n') % node_id)
    else:
      lines.append('N_%d [shape=point label="" width=.25]\n' % node_id)
    for reference in node.children:
      lines.append('N_%d -> N_%d [label="%s"]\n' % (
          node_id, reference.node_id, _escape(reference.local_name)))
    for slot_reference in node.slot_variables:
      lines.append('N_%d -> N_%d [label="%s" style=dotted]\n' % (
          node_id, slot_reference.slot_variable_node_id,
          _escape(slot_reference.slot_name)))
      lines.append('N_%d -> N_%d [style=dotted]\n' % (
          slot_reference.original_variable_node_id,
          slot_reference.slot_variable_node_id))
  lines.append('}\n')
  return ''.join(lines)