def test_audio(self):
    """Legacy audio summaries migrate to blob sequences with one-time metadata.

    The same audio summary is written at several steps, and every resulting
    event is migrated with a single shared `initial_metadata` dict. Each
    migrated summary must hold a rank-1 tensor of encoded clips (each
    starting with a RIFF header), and only the first step may carry
    summary metadata.
    """
    logdir = self.get_temp_dir()
    steps = (0, 1, 2)
    # The summary is identical at every step, so build and serialize it once
    # instead of rebuilding it inside the write loop. Four clips (one label
    # each), tensor shape (4, 10, 2).
    audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
    audio_pb = audio_summary.pb(
        "foo",
        audio,
        labels=["one", "two", "three", "four"],
        sample_rate=44100,
        display_name="bar",
        description="baz",
    )
    serialized_summary = audio_pb.SerializeToString()
    with test_util.FileWriter(logdir) as writer:
        for step in steps:
            writer.add_summary(serialized_summary, global_step=step)

    files = os.listdir(logdir)
    self.assertLen(files, 1)
    event_file = os.path.join(logdir, files[0])
    loader = event_file_loader.RawEventFileLoader(event_file)
    input_events = [event_pb2.Event.FromString(x) for x in loader.Load()]

    # One dict is shared across all calls; presumably this is how the
    # migration remembers which tags have already emitted metadata (the
    # assertions below require metadata on the first step only).
    new_events = []
    initial_metadata = {}
    for input_event in input_events:
        new_events.extend(
            self._migrate_event(input_event, initial_metadata=initial_metadata)
        )

    # A leading `file_version` event, then one summary event per step.
    self.assertLen(new_events, 1 + len(steps))
    self.assertEqual(new_events[0].WhichOneof("what"), "file_version")
    first_step = min(steps)
    for step, new_event in zip(steps, new_events[1:]):
        with self.subTest("step %d" % step):
            self.assertLen(new_event.summary.value, 1)
            value = new_event.summary.value[0]
            tensor = tensor_util.make_ndarray(value.tensor)
            self.assertEqual(tensor.shape, (3,))  # 4 clipped to max_outputs=3
            # Every retained clip, not just the first two, must be an
            # encoded blob beginning with a RIFF header.
            for clip in tensor:
                self.assertStartsWith(clip, b"RIFF")
            if step == first_step:
                metadata = value.metadata
                self.assertEqual(
                    metadata.data_class,
                    summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
                )
                self.assertEqual(
                    metadata.plugin_data.plugin_name,
                    audio_metadata.PLUGIN_NAME,
                )
            else:
                self.assertFalse(value.HasField("metadata"))
def test_graph_def(self):
    """A legacy `graph_def` event migrates to a graph blob-sequence summary.

    The migration must return the original event object unchanged, followed
    by a new summary event whose single value holds the serialized
    `GraphDef` as a one-element bytes tensor.
    """
    # Create a `GraphDef` and write it to disk as an event.
    logdir = self.get_temp_dir()
    graph_def = graph_pb2.GraphDef()
    graph_def.node.add(name="alice", op="Person")
    graph_def.node.add(name="bob", op="Person")
    graph_def.node.add(
        name="friendship", op="Friendship", input=["alice", "bob"]
    )
    # Use the writer as a context manager so the event file is flushed and
    # its handle closed before the file is read back below.
    with test_util.FileWriter(logdir) as writer:
        writer.add_graph(graph=None, graph_def=graph_def, global_step=123)

    # Read in the `Event` containing the written `graph_def`.
    files = os.listdir(logdir)
    self.assertLen(files, 1)
    event_file = os.path.join(logdir, files[0])
    self.assertIn("tfevents", event_file)
    loader = event_file_loader.RawEventFileLoader(event_file)
    events = [event_pb2.Event.FromString(x) for x in loader.Load()]
    self.assertLen(events, 2)
    self.assertEqual(events[0].WhichOneof("what"), "file_version")
    self.assertEqual(events[1].WhichOneof("what"), "graph_def")
    old_event = events[1]

    new_events = self._migrate_event(old_event)
    self.assertLen(new_events, 2)
    # The original event is passed through as the very same object.
    self.assertIs(new_events[0], old_event)
    new_event = new_events[1]
    self.assertEqual(new_event.WhichOneof("what"), "summary")
    self.assertLen(new_event.summary.value, 1)
    value = new_event.summary.value[0]
    self.assertEqual(
        value.metadata.data_class,
        summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
    )
    self.assertEqual(
        value.metadata.plugin_data.plugin_name,
        graphs_metadata.PLUGIN_NAME,
    )

    tensor = tensor_util.make_ndarray(value.tensor)
    self.assertEqual(tensor.shape, (1,))
    new_graph_def_bytes = tensor[0]
    self.assertIsInstance(new_graph_def_bytes, bytes)
    # Guard against an empty or truncated blob before parsing it.
    self.assertGreaterEqual(len(new_graph_def_bytes), 16)
    new_graph_def = graph_pb2.GraphDef.FromString(new_graph_def_bytes)
    self.assertProtoEquals(graph_def, new_graph_def)
def _LoaderForTestFile(self, filename):
    """Build a `RawEventFileLoader` for `filename` in this test's temp dir.

    Args:
      filename: Name of the file, relative to `self.get_temp_dir()`.

    Returns:
      An `event_file_loader.RawEventFileLoader` over that path.
    """
    path = os.path.join(self.get_temp_dir(), filename)
    return event_file_loader.RawEventFileLoader(path)