def decode(z, batch, hparams, is_training=True, reuse=False):
    """Autoencoder decoder network.

    Pitch embeddings obtained from `batch` are concatenated onto `z` along
    axis 3. The result is passed through a stack of transpose convolutions
    (each with leaky ReLU and batch normalization), followed by a final
    1x1 convolution with a sigmoid activation under scope "mag".

    Args:
      z: Tensor. The latent variables.
      batch: NSynthReader batch for pitch information.
      hparams: HParams. Hyperparameters (unused).
      is_training: bool. Whether batch normalization should be computed in
        training mode. Defaults to True.
      reuse: bool. Whether the variable scope should be reused. Defaults to
        False.

    Returns:
      The output of the decoder, i.e. a synthetic x computed from z.
    """
    del hparams  # Part of the signature, but not consulted here.

    # One row per transpose-convolution layer, applied top to bottom:
    # (kernel size, stride, output channels, variable scope name).
    layer_specs = (
        ([1, 1], [1, 1], 1024, "0"),
        ([4, 4], [2, 2], 512, "1"),
        ([4, 4], [2, 2], 512, "2"),
        ([4, 4], [2, 2], 256, "3"),
        ([4, 4], [2, 2], 256, "4"),
        ([4, 4], [2, 2], 256, "5"),
        ([4, 4], [2, 2], 128, "6"),
        ([4, 4], [2, 2], 128, "7"),
        ([5, 5], [2, 2], 128, "8"),
        ([5, 5], [2, 1], 128, "8_1"),
    )

    with tf.variable_scope("decoder", reuse=reuse):
        pitch_code = utils.pitch_embeddings(batch, reuse=reuse)
        net = tf.concat([z, pitch_code], 3)

        for kernel, stride, depth, layer_scope in layer_specs:
            net = utils.conv2d(
                net,
                kernel,
                stride,
                depth,
                is_training,
                activation_fn=utils.leaky_relu(),
                transpose=True,
                batch_norm=True,
                scope=layer_scope)

        # Output layer: regular (non-transpose) 1x1 convolution to a single
        # channel, sigmoid activation, no batch normalization.
        xhat = utils.conv2d(
            net,
            [1, 1],
            [1, 1],
            1,
            is_training,
            activation_fn=tf.nn.sigmoid,
            batch_norm=False,
            scope="mag")

    return xhat
def decode(z, batch, hparams, is_training=True, reuse=False):
    """Autoencoder decoder network.

    Builds, inside the "decoder" variable scope, a chain of transpose
    convolutions over the latent code joined with pitch embeddings
    (concatenated on axis 3), and ends with a sigmoid 1x1 convolution.

    Args:
      z: Tensor. The latent variables.
      batch: NSynthReader batch for pitch information.
      hparams: HParams. Hyperparameters (unused).
      is_training: bool. Whether batch normalization should be computed in
        training mode. Defaults to True.
      reuse: bool. Whether the variable scope should be reused. Defaults to
        False.

    Returns:
      The output of the decoder, i.e. a synthetic x computed from z.
    """
    del hparams  # Unused; kept only so the signature stays the same.

    def deconv_block(inputs, kernel, stride, depth, name):
        """Transpose conv + batch norm + leaky ReLU under scope `name`."""
        return utils.conv2d(
            inputs, kernel, stride, depth, is_training,
            activation_fn=utils.leaky_relu(),
            transpose=True,
            batch_norm=True,
            scope=name)

    with tf.variable_scope("decoder", reuse=reuse):
        pitch_code = utils.pitch_embeddings(batch, reuse=reuse)
        net = tf.concat([z, pitch_code], 3)

        net = deconv_block(net, [1, 1], [1, 1], 1024, "0")
        net = deconv_block(net, [4, 4], [2, 2], 512, "1")
        net = deconv_block(net, [4, 4], [2, 2], 512, "2")
        net = deconv_block(net, [4, 4], [2, 2], 256, "3")
        net = deconv_block(net, [4, 4], [2, 2], 256, "4")
        net = deconv_block(net, [4, 4], [2, 2], 256, "5")
        net = deconv_block(net, [4, 4], [2, 2], 128, "6")
        net = deconv_block(net, [4, 4], [2, 2], 128, "7")
        net = deconv_block(net, [5, 5], [2, 2], 128, "8")
        # Last transpose layer uses an asymmetric stride of [2, 1].
        net = deconv_block(net, [5, 5], [2, 1], 128, "8_1")

        # Final projection to one channel: plain 1x1 convolution with a
        # sigmoid activation and no batch normalization.
        return utils.conv2d(
            net, [1, 1], [1, 1], 1, is_training,
            activation_fn=tf.nn.sigmoid,
            batch_norm=False,
            scope="mag")