def _initial_call(self, new_inputs, length, **kwargs):
  """Returns Tensor of shape [..., 1, vocab_size].

  Args:
    new_inputs: Tensor of shape [..., vocab_size], the new input whose
      output is generated.
    length: Length of the final desired sequence.
    **kwargs: Optional keyword arguments to layer.
  """
  inputs = new_inputs[..., tf.newaxis, :]
  # TODO(trandustin): To handle variable lengths, extend MADE to subset its
  # input and output layer weights rather than pad inputs.
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs, paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
  temperature = 1.  # The initial timestep uses a fixed unit temperature.
  logits = self.layer(padded_inputs, **kwargs)
  logits = logits[..., 0:1, :]
  logits = tf.reshape(
      logits,
      logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
  # Temper the logits before the Sinkhorn projection, matching reverse() and
  # _per_timestep_call(); at temperature 1. this is a no-op.
  soft = utils.sinkhorn(logits / temperature)
  hard = tf.cast(utils.soft_to_hard_permutation(soft), inputs.dtype)
  hard = tf.reshape(hard, logits.shape)
  # Inverse of a permutation matrix is its transpose.
  # inputs is [batch_size, 1, vocab_size].
  # hard is [batch_size, 1, vocab_size, vocab_size].
  outputs = tf.matmul(inputs[..., tf.newaxis, :],
                      hard, transpose_b=True)[..., 0, :]
  return outputs
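
# A minimal sketch of the Sinkhorn operator assumed to back `utils.sinkhorn`
# (Mena et al., 2018): repeated row and column normalization in log space
# projects the [..., vocab_size, vocab_size] logits onto an approximately
# doubly stochastic matrix. `sinkhorn_sketch` is an illustrative name, not
# the actual `utils` API; it reuses the module's `tf` import.
def sinkhorn_sketch(log_alpha, n_iters=20):
  for _ in range(n_iters):
    # Normalize rows, then columns, in log space for numerical stability.
    log_alpha -= tf.reduce_logsumexp(log_alpha, axis=-1, keepdims=True)
    log_alpha -= tf.reduce_logsumexp(log_alpha, axis=-2, keepdims=True)
  return tf.exp(log_alpha)
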
def reverse(self, inputs, **kwargs):
  """Reverse pass returning the inverse autoregressive transformation."""
  if not self.built:
    self._maybe_build(inputs)
  logits = self.layer(inputs, **kwargs)
  logits = tf.reshape(
      logits,
      logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
  soft = utils.sinkhorn(logits / self.temperature, n_iters=20)
  hard = tf.cast(utils.soft_to_hard_permutation(soft), inputs.dtype)
  hard = tf.reshape(hard, logits.shape)
  # The forward pass right-multiplies by the transposed permutation matrix,
  # so right-multiplying by the permutation matrix itself inverts it.
  outputs = tf.matmul(inputs[..., tf.newaxis, :], hard)[..., 0, :]
  return outputs
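
# `utils.soft_to_hard_permutation` rounds the doubly stochastic `soft` matrix
# to a hard permutation. A per-row argmax, sketched below, is a minimal
# stand-in: under ties or a poorly converged `soft` it can produce duplicate
# columns (not a true permutation), which real implementations resolve with a
# matching/assignment step. Illustrative only, not the actual `utils` code.
def soft_to_hard_sketch(soft):
  vocab_size = soft.shape[-1]
  return tf.one_hot(tf.argmax(soft, axis=-1), depth=vocab_size,
                    dtype=soft.dtype)
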
def _per_timestep_call(self,
                       current_outputs,
                       new_inputs,
                       length,
                       timestep,
                       **kwargs):
  """Returns Tensor of shape [..., timestep+1, vocab_size].

  Args:
    current_outputs: Tensor of shape [..., timestep, vocab_size], the
      sequence Tensor generated so far.
    new_inputs: Tensor of shape [..., vocab_size], the new input whose
      output is generated given current_outputs.
    length: Length of the final desired sequence.
    timestep: Current timestep.
    **kwargs: Optional keyword arguments to layer.
  """
  inputs = tf.concat([current_outputs,
                      new_inputs[..., tf.newaxis, :]],
                     axis=-2)
  # TODO(trandustin): To handle variable lengths, extend MADE to subset its
  # input and output layer weights rather than pad inputs.
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs,
      paddings=[[0, 0]] * batch_ndims + [[0, length - timestep - 1], [0, 0]])
  logits = self.layer(padded_inputs, **kwargs)
  logits = logits[..., :(timestep + 1), :]
  logits = tf.reshape(
      logits,
      logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
  soft = utils.sinkhorn(logits / self.temperature)
  hard = tf.cast(utils.soft_to_hard_permutation(soft), inputs.dtype)
  hard = tf.reshape(hard, logits.shape)
  # Inverse of a permutation matrix is its transpose.
  # inputs is [batch_size, timestep + 1, vocab_size].
  # hard is [batch_size, timestep + 1, vocab_size, vocab_size].
  new_outputs = tf.matmul(inputs[..., tf.newaxis, :],
                          hard, transpose_b=True)[..., 0, :]
  outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
  if not tf.executing_eagerly():
    outputs.set_shape([None] * batch_ndims + [timestep + 1, self.vocab_size])
  return outputs
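
# A tiny check of the matmul convention used above, with a hand-written
# permutation matrix standing in for `hard`: the forward passes right-multiply
# by the transpose (the inverse permutation), so reverse()'s right-multiply by
# the matrix itself undoes them. Values are illustrative only.
def _check_permutation_round_trip():
  hard = tf.constant([[0., 1., 0.],
                      [0., 0., 1.],
                      [1., 0., 0.]])
  x = tf.one_hot(2, depth=3)[tf.newaxis, :]  # One-hot token, shape [1, 3].
  y = tf.matmul(x, hard, transpose_b=True)   # As in the forward passes.
  x_back = tf.matmul(y, hard)                # As in reverse().
  tf.debugging.assert_near(x, x_back)        # Round-trips exactly.
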