def reverse(self, inputs, **kwargs):
    """Reverse pass for the inverse bipartite transformation.

    Positions selected by ``self.mask`` are copied through unchanged. The
    remaining positions are shifted, and also multiplied when the wrapped
    layer emits a scale. The shift/scale parameters are predicted by
    ``self.layer`` from the masked inputs only.

    Args:
      inputs: Tensor-convertible input. The mask is reshaped to
        ``[1, ..., 1, -1, 1]``, so inputs must have rank >= 2 and the mask
        runs along the second-to-last axis. Presumably one-hot with shape
        ``[..., length, vocab_size]`` -- TODO confirm against callers.
      **kwargs: Forwarded unchanged to ``self.layer``.

    Returns:
      The masked inputs plus the transformed values at unmasked positions.

    Raises:
      ValueError: If the last dimension of ``self.layer``'s output is
        neither ``2 * self.vocab_size`` nor ``self.vocab_size``.
    """
    if not self.built:
        self._maybe_build(inputs)
    inputs = tf.convert_to_tensor(inputs)
    dtype = inputs.dtype

    # Every axis before the last two is a batch axis. The mask is laid along
    # the second-to-last axis and broadcast across all other axes.
    num_batch_axes = inputs.shape.ndims - 2
    mask = tf.reshape(
        tf.cast(self.mask, dtype), [1] * num_batch_axes + [-1, 1])
    passthrough = mask * inputs

    params = self.layer(passthrough, **kwargs)
    num_params = params.shape[-1]
    if num_params == 2 * self.vocab_size:
        # The layer predicts a shift and a scale, stacked on the last axis.
        raw_loc, raw_scale = tf.split(params, 2, axis=-1)
        scale = tf.cast(
            utils.one_hot_argmax(raw_scale, self.temperature), dtype)
        scaled_inputs = utils.one_hot_multiply(inputs, scale)
    elif num_params == self.vocab_size:
        # The layer predicts only a shift; the inputs are left unscaled.
        raw_loc = params
        scaled_inputs = inputs
    else:
        raise ValueError(
            'Output of layer does not have compatible dimensions.')

    loc = tf.cast(utils.one_hot_argmax(raw_loc, self.temperature), dtype)
    shifted = utils.one_hot_add(loc, scaled_inputs)
    # Only positions outside the mask take the transformed value.
    return passthrough + (1. - mask) * shifted
def reverse(self, inputs, **kwargs):
    """Reverse pass returning the inverse autoregressive transformation.

    Makes a single call to ``self.layer`` on the full inputs. It then
    combines the predicted shift (and optional scale) with the inputs using
    ``utils.one_hot_multiply`` and ``utils.one_hot_add``.

    Args:
      inputs: Tensor-convertible input. It is converted with
        ``tf.convert_to_tensor`` so array-like values (e.g. nested lists)
        are accepted, matching the bipartite flow's ``reverse``. Presumably
        one-hot with last dimension ``vocab_size`` -- TODO confirm against
        callers.
      **kwargs: Forwarded unchanged to ``self.layer``.

    Returns:
      ``utils.one_hot_add(loc, scaled_inputs)``, where ``scaled_inputs`` is
      ``inputs`` multiplied by the predicted scale when the layer emits one.

    Raises:
      ValueError: If the last dimension of ``self.layer``'s output is
        neither ``2 * self.vocab_size`` nor ``self.vocab_size``.
    """
    # Convert before building: both the build step and the ``.dtype`` reads
    # below need a tensor. A raw Python list has no ``dtype``.
    inputs = tf.convert_to_tensor(inputs)
    if not self.built:
        self._maybe_build(inputs)

    net = self.layer(inputs, **kwargs)
    if net.shape[-1] == 2 * self.vocab_size:
        # The layer predicts a shift and a scale, stacked on the last axis.
        loc, scale = tf.split(net, 2, axis=-1)
        scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                        inputs.dtype)
        scaled_inputs = utils.one_hot_multiply(inputs, scale)
    elif net.shape[-1] == self.vocab_size:
        # The layer predicts only a shift; the inputs are left unscaled.
        loc = net
        scaled_inputs = inputs
    else:
        raise ValueError('Output of layer does not have compatible dimensions.')
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    outputs = utils.one_hot_add(loc, scaled_inputs)
    return outputs