def call(self, z_pitch, z_vel, z=None):
    """Forward pass for the MIDI decoder.

    Args:
        z_pitch: Tensor containing encoded pitch in MIDI scale. [batch, time, 1].
        z_vel: Tensor containing encoded velocity in MIDI scale. [batch, time, 1].
            Currently ignored; only pitch (plus optional `z`) feeds the network.
        z: Additional non-MIDI latent tensor. [batch, time, n_z]

    Returns:
        A dictionary to feed into a processor group. It holds one entry per
        name in `self.output_splits`, plus 'f0_hz', which is derived from the
        'f0_midi' entry.
    """
    # pylint: disable=unused-argument
    # TODO(jesse): Allow velocity, e.g. tf.concat([z_pitch, z_vel], axis=-1).
    # The network takes a single tensor when there is no extra latent, and a
    # [pitch, latent] list otherwise.
    net_input = z_pitch if z is None else [z_pitch, z]
    hidden = self.net(net_input)

    if self.norm is not None:
        hidden = self.norm(hidden)

    outputs = nn.split_to_dict(self.dense_out(hidden), self.output_splits)

    if self.f0_residual:
        # The network predicts an offset relative to the input pitch.
        outputs['f0_midi'] += z_pitch

    outputs['f0_hz'] = core.midi_to_hz(outputs['f0_midi'])
    return outputs
def call(self, conditioning):
    """Decode `conditioning` and merge the named decoder outputs into it.

    Args:
        conditioning: Dictionary passed to `self.decode`. It is updated in
            place with the decoder outputs.

    Returns:
        The same `conditioning` object, now also holding one entry per name
        in `self.output_splits`.

    Raises:
        ValueError: If splitting the decoder output does not yield a dict.
    """
    decoded = nn.split_to_dict(self.decode(conditioning), self.output_splits)
    if not isinstance(decoded, dict):
        raise ValueError('Decoder must output a dictionary of signals.')
    conditioning.update(decoded)
    return conditioning
def test_output_is_correct(self):
    """split_to_dict slices the last axis into named chunks in spec order."""
    tensor_splits = (('x1', 1), ('x2', 2), ('x3', 3))
    # Each chunk is filled with a distinct constant (1.0, 2.0, 3.0) so a
    # misplaced slice boundary would be detected.
    expected = {
        name: np.full((2, 3, width), fill, dtype=np.float32)
        for (name, width), fill in zip(tensor_splits, (1.0, 2.0, 3.0))
    }
    stacked = tf.constant(
        np.concatenate([expected[name] for name, _ in tensor_splits], axis=2))

    output = nn.split_to_dict(stacked, tensor_splits)

    self.assertSetEqual(set(['x1', 'x2', 'x3']), set(output.keys()))
    for name, _ in tensor_splits:
        self.assertAllEqual(expected[name], output.get(name))