Example #1
0
    def call(self, z_pitch, z_vel, z=None):
        """Decodes MIDI-scale latents into a dictionary of output signals.

        Args:
          z_pitch: Tensor containing encoded pitch in MIDI scale. [batch, time, 1].
          z_vel: Tensor containing encoded velocity in MIDI scale. [batch, time, 1].
            Currently unused: only the pitch is fed to the network.
          z: Optional additional non-MIDI latent tensor. [batch, time, n_z].
            When provided, `self.net` receives `[z_pitch, z]` instead of
            `z_pitch` alone.

        Returns:
          A dictionary to feed into a processor group. Its entries come from
          splitting the dense output by `self.output_splits`, plus an
          'f0_hz' entry computed from 'f0_midi' via `core.midi_to_hz`.
        """
        # pylint: disable=unused-argument
        # TODO(jesse): Allow velocity, e.g. by feeding
        # tf.concat([z_pitch, z_vel], axis=-1) instead of z_pitch alone.
        net_input = z_pitch if z is None else [z_pitch, z]
        hidden = self.net(net_input)

        # Normalization is optional; skip it when no norm layer is configured.
        if self.norm is not None:
            hidden = self.norm(hidden)

        projected = self.dense_out(hidden)
        outputs = nn.split_to_dict(projected, self.output_splits)

        if self.f0_residual:
            # The predicted 'f0_midi' is treated as an offset from the input pitch.
            outputs['f0_midi'] += z_pitch

        outputs['f0_hz'] = core.midi_to_hz(outputs['f0_midi'])
        return outputs
Example #2
0
    def call(self, conditioning):
        """Updates conditioning with dictionary of decoder outputs.

        Args:
          conditioning: Container of input signals (presumably a dict; it must
            support `.update`). It is passed to `self.decode` and then updated
            in place with the decoded signals.

        Returns:
          The same `conditioning` object, now also holding the decoder outputs.

        Raises:
          ValueError: If splitting the decoder output does not yield a dict.
        """
        decoded = self.decode(conditioning)
        signals = nn.split_to_dict(decoded, self.output_splits)

        # Guard clause: refuse to merge anything that is not a dict of signals.
        if not isinstance(signals, dict):
            raise ValueError('Decoder must output a dictionary of signals.')

        conditioning.update(signals)
        return conditioning
Example #3
0
    def test_output_is_correct(self):
        """split_to_dict slices the last axis into named chunks of given widths."""
        tensor_splits = (('x1', 1), ('x2', 2), ('x3', 3))

        # Each named chunk is a constant float32 block filled with its own width,
        # so a misplaced slice boundary shows up as a wrong value or shape.
        expected = {
            name: np.full((2, 3, width), float(width), dtype=np.float32)
            for name, width in tensor_splits
        }
        stacked = tf.constant(np.concatenate(list(expected.values()), axis=2))

        output = nn.split_to_dict(stacked, tensor_splits)

        self.assertSetEqual({'x1', 'x2', 'x3'}, set(output.keys()))
        for name, _ in tensor_splits:
            self.assertAllEqual(expected[name], output.get(name))