def additive_coupling(name, x, x_mask, inverse, split_dim, identity_first, init,
                      decoder_self_attention_bias=None, **kwargs):
  """Additive coupling transform layer.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: Forward or inverse pass.
    split_dim: which dimension to split
      (time, channel_continuous, channel_alternate).
    identity_first: True means the first half remains constant. False for 2nd.
    init: init.
    decoder_self_attention_bias: bias.
    **kwargs: additional arguments. Contains hparams, encoder_output and
      encoder_decoder_attention_bias.

  Returns:
    z: data transformed by the additive coupling layer. shape=[B, L, C]
    logabsdet: Log absolute determinant Jacobian; identically 0.0 since a
      pure shift is volume-preserving. scalar Tensor.
  """
  hparams = kwargs["hparams"]
  batch_size, length, n_channels = common_layers.shape_list(x)
  assert hparams.scale_width > 0.0 and hparams.scale_width < 1.0

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    x_id, x_tr, _, n_transform, bias, mask = gops.split_coupling(
        x, x_mask, split_dim, identity_first, decoder_self_attention_bias)
    z_id = x_id

    # The shift is predicted from the identity half only, which keeps the
    # transform trivially invertible.
    loc = gops.transformer_decoder_block(
        "theta_tr",
        n_layers=hparams.n_layers_transform_params,
        x=x_id,
        x_mask=mask,
        output_size=n_transform,
        init=init,
        decoder_self_attention_bias=bias,
        **kwargs)

    if not inverse:
      z_tr = x_tr + loc
    else:
      z_tr = x_tr - loc
    # Additive coupling has unit Jacobian determinant.
    logabsdet = tf.constant(0.0, dtype=tf.float32)

    tf.summary.histogram("_loc", tf.boolean_mask(loc, mask))
    result = gops.join_coupling(z_id, z_tr, split_dim, identity_first)
    result = tf.reshape(result, [batch_size, length, n_channels])
  return result, logabsdet
def affine_coupling(name, x, x_mask, inverse, split_dim, identity_first, init,
                    decoder_self_attention_bias=None, **kwargs):
  """Affine coupling transform layer.

  Args:
    name: variable scope.
    x: 3-D Tensor, shape=[B, L, C].
    x_mask: 2-D Tensor, shape=[B, L].
    inverse: Forward or inverse pass.
    split_dim: which dimension to split
      (time, channel_continuous, channel_alternate).
    identity_first: True means the first half remains constant. False for 2nd.
    init: init.
    decoder_self_attention_bias: bias.
    **kwargs: additional arguments. Contains hparams, encoder_output and
      encoder_decoder_attention_bias.

  Returns:
    z: data transformed by the affine coupling layer. shape=[B, L, C]
    logabsdets: Log absolute determinant Jacobian. shape=[B]
  """
  hparams = kwargs["hparams"]
  batch_size, length, n_channels = common_layers.shape_list(x)
  assert 0.0 < hparams.scale_width < 1.0

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    x_id, x_tr, _, n_transform, bias, mask = gops.split_coupling(
        x, x_mask, split_dim, identity_first, decoder_self_attention_bias)
    z_id = x_id

    # A single decoder stack predicts both the shift and the unconstrained
    # scale, stacked along the channel axis.
    theta = gops.transformer_decoder_block(
        "theta_tr",
        n_layers=hparams.n_layers_transform_params,
        x=x_id,
        x_mask=mask,
        output_size=n_transform * 2,
        init=init,
        decoder_self_attention_bias=bias,
        **kwargs)
    loc, unconstrained_scale = tf.split(theta, 2, axis=-1)
    # The +2.0 shift biases the sigmoid so the scale starts near 1.
    scale = tf.sigmoid(unconstrained_scale + 2.0)

    z_tr = (x_tr + loc) * scale if not inverse else x_tr / scale - loc

    # Sum log|scale| over the transformed length/channel positions.  [B]
    logabsdet = gops.reduce_sum_over_lc(tf.log(scale), mask)
    if inverse:
      logabsdet *= -1

    tf.summary.histogram("_loc", tf.boolean_mask(loc, mask))
    tf.summary.histogram("_scale", tf.boolean_mask(scale, mask))
    result = gops.join_coupling(z_id, z_tr, split_dim, identity_first)
    result = tf.reshape(result, [batch_size, length, n_channels])
  return result, logabsdet