Example #1
    def _initial_call(self, new_inputs, length, **kwargs):
        """Returns Tensor of shape [..., 1, vocab_size].

        Args:
          new_inputs: Tensor of shape [..., vocab_size], the new input to generate
            its output.
          length: Length of final desired sequence.
          **kwargs: Optional keyword arguments to layer.
        """
        inputs = new_inputs[..., tf.newaxis, :]
        # TODO(trandustin): To handle variable lengths, extend MADE to subset its
        # input and output layer weights rather than pad inputs.
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(
            inputs,
            paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
        temperature = 1.
        logits = self.layer(padded_inputs / temperature, **kwargs)
        logits = logits[..., 0:1, :]
        logits = tf.reshape(
            logits,
            logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
        soft = utils.sinkhorn(logits)
        hard = tf.cast(utils.soft_to_hard_permutation(soft), inputs.dtype)
        hard = tf.reshape(hard, logits.shape)
        # Inverse of permutation matrix is its transpose.
        # inputs is [..., 1, vocab_size] here (only the first timestep).
        # hard is [..., 1, vocab_size, vocab_size].
        outputs = tf.matmul(inputs[..., tf.newaxis, :], hard,
                            transpose_b=True)[..., 0, :]
        return outputs
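
For context, utils.sinkhorn above converts the [vocab_size, vocab_size] logits into an approximately doubly stochastic matrix. Below is a minimal sketch of the standard Sinkhorn operator, assuming alternating row/column normalization in log space (in the spirit of Gumbel-Sinkhorn networks); the library's actual implementation may differ.

import tensorflow as tf

def sinkhorn_sketch(log_alpha, n_iters=20):
  """Alternately normalizes rows and columns in log space so that
  exp(log_alpha) approaches a doubly stochastic matrix."""
  for _ in range(n_iters):
    # Rows sum to 1 after this step.
    log_alpha -= tf.reduce_logsumexp(log_alpha, axis=-1, keepdims=True)
    # Columns sum to 1 after this step.
    log_alpha -= tf.reduce_logsumexp(log_alpha, axis=-2, keepdims=True)
  return tf.exp(log_alpha)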
Example #2
    def reverse(self, inputs, **kwargs):
        """Reverse pass returning the inverse autoregressive transformation."""
        if not self.built:
            self._maybe_build(inputs)

        logits = self.layer(inputs, **kwargs)
        logits = tf.reshape(
            logits,
            logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
        soft = utils.sinkhorn(logits / self.temperature, n_iters=20)
        hard = utils.soft_to_hard_permutation(soft)
        hard = tf.reshape(hard, logits.shape)
        # Recover the permutation by right-multiplying by the permutation matrix.
        outputs = tf.matmul(inputs[..., tf.newaxis, :], hard)[..., 0, :]
        return outputs
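
utils.soft_to_hard_permutation then rounds the doubly stochastic soft matrix to an exact permutation matrix. Here is a minimal sketch assuming a per-row argmax is sufficient; the real routine may instead solve a linear assignment problem, which also guarantees that no two rows pick the same column.

import tensorflow as tf

def soft_to_hard_permutation_sketch(soft):
  """Rounds a (near-)doubly-stochastic matrix to one-hot rows via argmax.
  With ties or a poorly converged input, two rows can select the same
  column, so this is only an illustration of the idea."""
  vocab_size = tf.shape(soft)[-1]
  return tf.one_hot(tf.argmax(soft, axis=-1), depth=vocab_size, dtype=soft.dtype)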
Example #3
  def _per_timestep_call(self,
                         current_outputs,
                         new_inputs,
                         length,
                         timestep,
                         **kwargs):
    """Returns Tensor of shape [..., timestep+1, vocab_size].

    Args:
      current_outputs: Tensor of shape [..., timestep, vocab_size], the so-far
        generated sequence Tensor.
      new_inputs: Tensor of shape [..., vocab_size], the new input to generate
        its output given current_outputs.
      length: Length of final desired sequence.
      timestep: Current timestep.
      **kwargs: Optional keyword arguments to layer.
    """
    inputs = tf.concat([current_outputs,
                        new_inputs[..., tf.newaxis, :]], axis=-2)
    # TODO(trandustin): To handle variable lengths, extend MADE to subset its
    # input and output layer weights rather than pad inputs.
    batch_ndims = inputs.shape.ndims - 2
    padded_inputs = tf.pad(
        inputs,
        paddings=[[0, 0]] * batch_ndims + [[0, length - timestep - 1], [0, 0]])
    logits = self.layer(padded_inputs, **kwargs)
    logits = logits[..., :(timestep+1), :]
    logits = tf.reshape(
        logits,
        logits.shape[:-1].concatenate([self.vocab_size, self.vocab_size]))
    soft = utils.sinkhorn(logits / self.temperature)
    hard = tf.cast(utils.soft_to_hard_permutation(soft), inputs.dtype)
    hard = tf.reshape(hard, logits.shape)
    # Inverse of permutation matrix is its transpose.
    # inputs is [batch_size, timestep + 1, vocab_size].
    # hard is [batch_size, timestep + 1, vocab_size, vocab_size].
    new_outputs = tf.matmul(inputs[..., tf.newaxis, :],
                            hard,
                            transpose_b=True)[..., 0, :]
    outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
    if not tf.executing_eagerly():
      outputs.set_shape([None] * batch_ndims + [timestep+1, self.vocab_size])
    return outputs
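
Taken together, the _initial_call in Example #1 seeds the sequence and the _per_timestep_call above extends it one position at a time. A hypothetical driver loop (generate_sketch and its argument layout are assumptions, not part of the library) could look like:

def generate_sketch(flow, new_inputs, length, **kwargs):
  """Runs the autoregressive pass one timestep at a time.

  new_inputs is assumed to have shape [..., length, vocab_size], holding the
  one-hot input for every position of the sequence.
  """
  outputs = flow._initial_call(new_inputs[..., 0, :], length, **kwargs)
  for timestep in range(1, length):
    outputs = flow._per_timestep_call(
        outputs, new_inputs[..., timestep, :], length, timestep, **kwargs)
  return outputs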