Ejemplo n.º 1
0
    def test_audio(self):
        """Check migration of audio summaries read back from an event file.

        Writes one `audio_summary.pb` summary per step to a fresh log
        directory, loads the raw events back, and passes each one through
        `self._migrate_event` with a shared `initial_metadata` dict. The
        migrated output is expected to be the file-version event followed by
        one summary event per step, with summary metadata present only on the
        first step's value.
        """
        logdir = self.get_temp_dir()
        steps = (0, 1, 2)
        with test_util.FileWriter(logdir) as writer:
            for step in steps:
                # Four clips of shape (10, 2); only the summary itself is
                # written, keyed by `global_step`.
                audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2),
                                   (4, 10, 2))
                audio_pb = audio_summary.pb(
                    "foo",
                    audio,
                    labels=["one", "two", "three", "four"],
                    sample_rate=44100,
                    display_name="bar",
                    description="baz",
                )
                writer.add_summary(audio_pb.SerializeToString(),
                                   global_step=step)

        # The writer is closed at this point, so the single event file can be
        # read back in full.
        files = os.listdir(logdir)
        self.assertLen(files, 1)
        event_file = os.path.join(logdir, files[0])
        loader = event_file_loader.RawEventFileLoader(event_file)
        input_events = [event_pb2.Event.FromString(x) for x in loader.Load()]

        # The same `initial_metadata` dict is passed for every event so that
        # state can carry over from one migrated event to the next.
        new_events = []
        initial_metadata = {}
        for input_event in input_events:
            migrated = self._migrate_event(input_event,
                                           initial_metadata=initial_metadata)
            new_events.extend(migrated)

        # One file-version event plus one summary event per step.
        self.assertLen(new_events, 4)
        self.assertEqual(new_events[0].WhichOneof("what"), "file_version")
        for step in steps:
            with self.subTest("step %d" % step):
                # Offset by one to skip the leading file-version event.
                new_event = new_events[step + 1]
                self.assertLen(new_event.summary.value, 1)
                value = new_event.summary.value[0]
                tensor = tensor_util.make_ndarray(value.tensor)
                self.assertEqual(tensor.shape,
                                 (3, ))  # 4 clipped to max_outputs=3
                self.assertStartsWith(tensor[0], b"RIFF")
                self.assertStartsWith(tensor[1], b"RIFF")
                if step == min(steps):
                    # Only the first step's value is expected to carry
                    # summary metadata.
                    metadata = value.metadata
                    self.assertEqual(
                        metadata.data_class,
                        summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
                    )
                    self.assertEqual(
                        metadata.plugin_data.plugin_name,
                        audio_metadata.PLUGIN_NAME,
                    )
                else:
                    self.assertFalse(value.HasField("metadata"))
Ejemplo n.º 2
0
    def test_graph_def(self):
        """Check migration of a `graph_def` event into a summary event.

        Writes a three-node `GraphDef` to an event file, reads the resulting
        `graph_def` event back, and passes it through `self._migrate_event`.
        The result is expected to be the original event object followed by a
        new summary event whose single tensor element is the serialized
        graph.
        """
        # Create a `GraphDef` and write it to disk as an event.
        logdir = self.get_temp_dir()
        graph_def = graph_pb2.GraphDef()
        graph_def.node.add(name="alice", op="Person")
        graph_def.node.add(name="bob", op="Person")
        graph_def.node.add(
            name="friendship", op="Friendship", input=["alice", "bob"]
        )
        # Use the writer as a context manager so the event file is flushed
        # and closed before it is read back, rather than left open.
        with test_util.FileWriter(logdir) as writer:
            writer.add_graph(graph=None, graph_def=graph_def, global_step=123)

        # Read in the `Event` containing the written `graph_def`.
        files = os.listdir(logdir)
        self.assertLen(files, 1)
        event_file = os.path.join(logdir, files[0])
        self.assertIn("tfevents", event_file)
        loader = event_file_loader.RawEventFileLoader(event_file)
        events = [event_pb2.Event.FromString(x) for x in loader.Load()]
        self.assertLen(events, 2)
        self.assertEqual(events[0].WhichOneof("what"), "file_version")
        self.assertEqual(events[1].WhichOneof("what"), "graph_def")
        old_event = events[1]

        # The original event object must come first (same identity), followed
        # by the newly created summary event.
        new_events = self._migrate_event(old_event)
        self.assertLen(new_events, 2)
        self.assertIs(new_events[0], old_event)
        new_event = new_events[1]

        self.assertEqual(new_event.WhichOneof("what"), "summary")
        self.assertLen(new_event.summary.value, 1)
        tensor = tensor_util.make_ndarray(new_event.summary.value[0].tensor)
        self.assertEqual(
            new_event.summary.value[0].metadata.data_class,
            summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
        )
        self.assertEqual(
            new_event.summary.value[0].metadata.plugin_data.plugin_name,
            graphs_metadata.PLUGIN_NAME,
        )
        self.assertEqual(tensor.shape, (1,))
        new_graph_def_bytes = tensor[0]
        self.assertIsInstance(new_graph_def_bytes, bytes)
        self.assertGreaterEqual(len(new_graph_def_bytes), 16)
        new_graph_def = graph_pb2.GraphDef.FromString(new_graph_def_bytes)

        # The graph must survive the round trip through the tensor unchanged.
        self.assertProtoEquals(graph_def, new_graph_def)
Ejemplo n.º 3
0
 def _LoaderForTestFile(self, filename):
     """Return a `RawEventFileLoader` for `filename` in the test temp dir."""
     path = os.path.join(self.get_temp_dir(), filename)
     return event_file_loader.RawEventFileLoader(path)