def compute_and_check_summary_pb(self, name, audio, max_outputs=3,
                                 display_name=None, description=None,
                                 audio_tensor=None, feed_dict=None):
    """Build a summary through both `summary.op` and `summary.pb` and verify it.

    The protobuf produced directly by `summary.pb` must equal the one
    obtained by evaluating the `summary.op` result via `self.pb_via_op`,
    and every entry in the first column of the resulting tensor must
    start with the bytes `RIFF` (i.e. look like a WAV file). Any
    violation fails the test immediately.

    Args:
      name: Summary name, forwarded to both `summary.op` and `summary.pb`.
      audio: Audio data given to `summary.pb`; also wrapped with
        `tf.constant` for `summary.op` when `audio_tensor` is omitted.
      max_outputs: Forwarded unchanged to both summary functions.
      display_name: Forwarded unchanged to both summary functions.
      description: Forwarded unchanged to both summary functions.
      audio_tensor: Optional tensor to hand to `summary.op` in place of
        `tf.constant(audio)`.
      feed_dict: Forwarded to `self.pb_via_op` when evaluating the op.

    Returns:
      The `Summary` protocol buffer produced by `summary.pb`.
    """
    if audio_tensor is None:
        audio_tensor = tf.constant(audio)

    # Both code paths must receive exactly the same optional arguments.
    shared_kwargs = dict(
        max_outputs=max_outputs,
        display_name=display_name,
        description=description,
    )
    op = summary.op(
        name, audio_tensor, self.samples_per_second, **shared_kwargs)
    expected_pb = summary.pb(
        name, audio, self.samples_per_second, **shared_kwargs)
    actual_pb = self.pb_via_op(op, feed_dict=feed_dict)
    self.assertProtoEquals(expected_pb, actual_pb)

    # Collect every first-column entry that does not start with the
    # RIFF magic bytes; an empty list means all entries passed.
    encoded_clips = tf.make_ndarray(expected_pb.value[0].tensor)[:, 0].tolist()
    invalid_audios = []
    for clip in encoded_clips:
        if not clip.startswith(b'RIFF'):
            invalid_audios.append(clip)
    self.assertFalse(invalid_audios)
    return expected_pb
def test_audio(self):
    """Check migration of audio summaries read back from an event file.

    Writes one audio summary per step to a fresh log directory, reloads
    the raw events, passes each through `self._migrate_event`, and then
    asserts that:

    * four events come out, the first being the `file_version` event;
    * each remaining event carries exactly one summary value whose
      tensor has shape `(3,)` and whose every entry starts with `RIFF`;
    * only the first step's value carries metadata, with the blob
      sequence data class and the audio plugin name.
    """
    logdir = self.get_temp_dir()
    steps = (0, 1, 2)
    with test_util.FileWriter(logdir) as writer:
        for step in steps:
            audio = tf.reshape(
                tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2)
            )
            audio_pb = audio_summary.pb(
                "foo",
                audio,
                labels=["one", "two", "three", "four"],
                sample_rate=44100,
                display_name="bar",
                description="baz",
            )
            writer.add_summary(
                audio_pb.SerializeToString(), global_step=step
            )

    files = os.listdir(logdir)
    self.assertLen(files, 1)
    event_file = os.path.join(logdir, files[0])
    loader = event_file_loader.RawEventFileLoader(event_file)
    input_events = [event_pb2.Event.FromString(x) for x in loader.Load()]

    # The same metadata dict is shared across all migration calls.
    new_events = []
    initial_metadata = {}
    for input_event in input_events:
        migrated = self._migrate_event(
            input_event, initial_metadata=initial_metadata
        )
        new_events.extend(migrated)

    self.assertLen(new_events, 4)
    self.assertEqual(new_events[0].WhichOneof("what"), "file_version")
    for step in steps:
        with self.subTest("step %d" % step):
            # Offset by one to skip the leading `file_version` event.
            new_event = new_events[step + 1]
            self.assertLen(new_event.summary.value, 1)
            value = new_event.summary.value[0]
            tensor = tensor_util.make_ndarray(value.tensor)
            # 4 clipped to max_outputs=3
            self.assertEqual(tensor.shape, (3,))
            # Every clip in the tensor must carry the RIFF header.
            for clip in tensor:
                self.assertStartsWith(clip, b"RIFF")
            if step == min(steps):
                metadata = value.metadata
                self.assertEqual(
                    metadata.data_class,
                    summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
                )
                self.assertEqual(
                    metadata.plugin_data.plugin_name,
                    audio_metadata.PLUGIN_NAME,
                )
            else:
                self.assertFalse(value.HasField("metadata"))
def audio(self, *args, **kwargs):
    """Forward all positional and keyword arguments to `summary.pb`.

    Returns:
      Whatever `summary.pb` returns for the given arguments.
    """
    summary_pb = summary.pb(*args, **kwargs)
    return summary_pb
def test_requires_wav_in_pb(self):
    """`summary.pb` must raise `ValueError` for an unrecognized encoding."""
    expected_message = 'Unknown encoding'
    raises_unknown_encoding = six.assertRaisesRegex(
        self, ValueError, expected_message)
    with raises_unknown_encoding:
        summary.pb('k488', self.stereo, 44100, encoding='pptx')
def test_requires_rank_3_in_pb(self):
    """`summary.pb` must raise `ValueError` when given a rank-2 array."""
    rank_2_audio = np.array([[1, 2, 3], [4, 5, 6]])
    expected_message = 'must have rank 3'
    with six.assertRaisesRegex(self, ValueError, expected_message):
        summary.pb('k488', rank_2_audio, 44100)
def test_requires_wav_in_pb(self):
    """Passing `encoding='pptx'` to `summary.pb` must raise `ValueError`."""
    bad_encoding = 'pptx'
    error_pattern = 'Unknown encoding'
    with six.assertRaisesRegex(self, ValueError, error_pattern):
        summary.pb('k488', self.stereo, 44100, encoding=bad_encoding)
def test_requires_rank_3_in_pb(self):
    """A two-dimensional input to `summary.pb` must raise `ValueError`."""
    error_pattern = 'must have rank 3'
    two_dimensional = np.array([[1, 2, 3], [4, 5, 6]])
    raises_rank_error = six.assertRaisesRegex(
        self, ValueError, error_pattern)
    with raises_rank_error:
        summary.pb('k488', two_dimensional, 44100)