def compute_and_check_summary_pb(self,
                                 name,
                                 audio,
                                 max_outputs=3,
                                 display_name=None,
                                 description=None,
                                 audio_tensor=None,
                                 feed_dict=None):
    """Build an audio summary both ways and check that the results agree.

    The summary is produced once through `summary.op` (evaluated with
    `self.pb_via_op`) and once through `summary.pb`. The two protobufs must
    compare equal, and every encoded clip in the result must begin with the
    `RIFF` header of a WAV file. If either check fails, the test fails
    immediately; otherwise the protobuf is handed back to the caller.

    Args:
      name: Summary name, passed to both `summary.op` and `summary.pb`.
      audio: Audio data given directly to `summary.pb`; also wrapped in
        `tf.constant` for `summary.op` when `audio_tensor` is omitted.
      max_outputs: Forwarded unchanged to both summary functions.
      display_name: Forwarded unchanged to both summary functions.
      description: Forwarded unchanged to both summary functions.
      audio_tensor: Optional tensor to give `summary.op` instead of
        `tf.constant(audio)`.
      feed_dict: Optional feed dictionary forwarded to `self.pb_via_op`.

    Returns:
      The `Summary` protocol buffer produced by `summary.pb`.
    """
    shared_kwargs = dict(max_outputs=max_outputs,
                         display_name=display_name,
                         description=description)
    graph_input = tf.constant(audio) if audio_tensor is None else audio_tensor
    summary_op = summary.op(name, graph_input, self.samples_per_second,
                            **shared_kwargs)
    expected_pb = summary.pb(name, audio, self.samples_per_second,
                             **shared_kwargs)
    evaluated_pb = self.pb_via_op(summary_op, feed_dict=feed_dict)
    self.assertProtoEquals(expected_pb, evaluated_pb)
    # Column 0 of the summary tensor is what gets checked as encoded audio.
    encoded_clips = tf.make_ndarray(expected_pb.value[0].tensor)[:, 0].tolist()
    not_wav = [clip for clip in encoded_clips if not clip.startswith(b'RIFF')]
    self.assertFalse(not_wav)
    return expected_pb
    def test_audio(self):
        """Legacy audio summaries migrate to blob-sequence tensors.

        Writes one audio summary (built from a (4, 10, 2) tensor with four
        labels) per step, reloads the event file, migrates every event with
        `self._migrate_event`, and checks that:
          * a `file_version` event comes first, followed by one event per step;
          * each migrated value is a 3-element tensor whose every entry starts
            with the WAV `RIFF` header;
          * blob-sequence metadata for the audio plugin appears only at the
            first step, and later steps carry no metadata.
        """
        logdir = self.get_temp_dir()
        steps = (0, 1, 2)
        # The summary input is identical at every step, so build it once.
        audio = tf.reshape(tf.linspace(0.0, 100.0, 4 * 10 * 2), (4, 10, 2))
        labels = ["one", "two", "three", "four"]
        with test_util.FileWriter(logdir) as writer:
            for step in steps:
                audio_pb = audio_summary.pb(
                    "foo",
                    audio,
                    labels=labels,
                    sample_rate=44100,
                    display_name="bar",
                    description="baz",
                )
                writer.add_summary(audio_pb.SerializeToString(),
                                   global_step=step)
        files = os.listdir(logdir)
        self.assertLen(files, 1)
        event_file = os.path.join(logdir, files[0])
        loader = event_file_loader.RawEventFileLoader(event_file)
        input_events = [event_pb2.Event.FromString(x) for x in loader.Load()]

        new_events = []
        initial_metadata = {}
        for input_event in input_events:
            migrated = self._migrate_event(input_event,
                                           initial_metadata=initial_metadata)
            new_events.extend(migrated)

        # One leading file_version event plus one migrated event per step.
        self.assertLen(new_events, len(steps) + 1)
        self.assertEqual(new_events[0].WhichOneof("what"), "file_version")
        for step in steps:
            with self.subTest("step %d" % step):
                new_event = new_events[step + 1]
                self.assertLen(new_event.summary.value, 1)
                value = new_event.summary.value[0]
                tensor = tensor_util.make_ndarray(value.tensor)
                # 4 clipped to max_outputs=3
                self.assertEqual(tensor.shape, (3, ))
                # Every surviving clip must be WAV-encoded, not just the first two.
                for blob in tensor:
                    self.assertStartsWith(blob, b"RIFF")
                if step == min(steps):
                    metadata = value.metadata
                    self.assertEqual(
                        metadata.data_class,
                        summary_pb2.DATA_CLASS_BLOB_SEQUENCE,
                    )
                    self.assertEqual(
                        metadata.plugin_data.plugin_name,
                        audio_metadata.PLUGIN_NAME,
                    )
                else:
                    # Metadata is expected only on the first migrated value.
                    self.assertFalse(value.HasField("metadata"))
示例#3
0
 def audio(self, *args, **kwargs):
     """Build an audio summary protobuf by delegating to `summary.pb`.

     All positional and keyword arguments are forwarded unchanged, and
     whatever `summary.pb` returns is returned as-is.
     """
     build_pb = summary.pb
     return build_pb(*args, **kwargs)
示例#4
0
 def test_requires_wav_in_pb(self):
     """`summary.pb` raises ValueError ('Unknown encoding') for 'pptx'."""
     unsupported_encoding = 'pptx'
     expected_message = 'Unknown encoding'
     with six.assertRaisesRegex(self, ValueError, expected_message):
         summary.pb('k488', self.stereo, 44100,
                    encoding=unsupported_encoding)
示例#5
0
 def test_requires_rank_3_in_pb(self):
     """`summary.pb` raises ValueError ('must have rank 3') for 2-D input."""
     rank_2_audio = np.array([[1, 2, 3], [4, 5, 6]])
     with six.assertRaisesRegex(self, ValueError, 'must have rank 3'):
         summary.pb('k488', rank_2_audio, 44100)
示例#6
0
 def test_requires_wav_in_pb(self):
   """Passing a non-WAV encoding to `summary.pb` must raise ValueError."""
   error_pattern = 'Unknown encoding'
   with six.assertRaisesRegex(self, ValueError, error_pattern):
     bad_kwargs = {'encoding': 'pptx'}
     summary.pb('k488', self.stereo, 44100, **bad_kwargs)
示例#7
0
 def test_requires_rank_3_in_pb(self):
   """A rank-2 array is rejected by `summary.pb` with a rank error."""
   rank_error = six.assertRaisesRegex(self, ValueError, 'must have rank 3')
   with rank_error:
     two_dim = np.array([[1, 2, 3], [4, 5, 6]])
     summary.pb('k488', two_dim, 44100)