Example #1
  def call(self, inputs, **kwargs):
    """Forward pass for bipartite generation."""
    inputs = tf.convert_to_tensor(inputs)
    batch_ndims = inputs.shape.ndims - 2
    # Positions where mask == 1 pass through unchanged and condition the
    # transform applied to the remaining (mask == 0) positions.
    mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                      [1] * batch_ndims + [-1, 1])
    masked_inputs = mask * inputs
    net = self.layer(masked_inputs, **kwargs)
    if net.shape[-1] == 2 * self.vocab_size:
      # Location-and-scale transform: (x - loc) * scale^{-1} (mod vocab_size).
      loc, scale = tf.split(net, 2, axis=-1)
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                      inputs.dtype)
      inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
      shifted_inputs = utils.one_hot_minus(inputs, loc)
      masked_outputs = (1. - mask) * utils.one_hot_multiply(shifted_inputs,
                                                            inverse_scale)
    elif net.shape[-1] == self.vocab_size:
      # Location-only transform: (x - loc) (mod vocab_size).
      loc = net
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      masked_outputs = (1. - mask) * utils.one_hot_minus(inputs, loc)
    else:
      raise ValueError('Output of layer does not have compatible dimensions.')
    outputs = masked_inputs + masked_outputs
    return outputs
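
The utils helpers above perform modular arithmetic on one-hot vectors. The snippet below is a minimal NumPy sketch of that arithmetic for hard one-hot inputs, written purely for illustration; the names mirror the utils functions but this is not the library's implementation, and it assumes vocab_size is prime so every nonzero scale has a multiplicative inverse.

# Illustrative sketch only (not the edward2 utils implementation), restricted
# to hard one-hot vectors. Assumes vocab_size is prime so every nonzero scale
# is invertible mod vocab_size.
import numpy as np

def one_hot(x, vocab_size):
  return np.eye(vocab_size)[x]

def one_hot_minus(inputs, shift):
  # (inputs - shift) mod vocab_size, computed on the integer labels.
  vocab_size = inputs.shape[-1]
  return one_hot((inputs.argmax(-1) - shift.argmax(-1)) % vocab_size, vocab_size)

def one_hot_multiply(inputs, scale):
  # (inputs * scale) mod vocab_size.
  vocab_size = inputs.shape[-1]
  return one_hot((inputs.argmax(-1) * scale.argmax(-1)) % vocab_size, vocab_size)

def multiplicative_inverse(scale, vocab_size):
  # Fermat's little theorem: s^(vocab_size - 2) = s^(-1) (mod prime vocab_size).
  return one_hot(pow(int(scale.argmax(-1)), vocab_size - 2, vocab_size), vocab_size)

vocab_size = 5
x, loc, scale = one_hot(3, vocab_size), one_hot(4, vocab_size), one_hot(2, vocab_size)
shifted = one_hot_minus(x, loc)                          # (3 - 4) % 5 = 4
out = one_hot_multiply(shifted, multiplicative_inverse(scale, vocab_size))
print(out.argmax(-1))                                    # (4 * 3) % 5 = 2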
Example #2
  def _initial_call(self, new_inputs, length, **kwargs):
    """Returns Tensor of shape [..., 1, vocab_size].

    Args:
      new_inputs: Tensor of shape [..., vocab_size], the new input to generate
        its output.
      length: Length of final desired sequence.
      **kwargs: Optional keyword arguments to layer.
    """
    inputs = new_inputs[..., tf.newaxis, :]
    # TODO(trandustin): To handle variable lengths, extend MADE to subset its
    # input and output layer weights rather than pad inputs.
    batch_ndims = inputs.shape.ndims - 2
    padded_inputs = tf.pad(
        inputs, paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
    net = self.layer(padded_inputs, **kwargs)
    if net.shape[-1] == 2 * self.vocab_size:
      loc, scale = tf.split(net, 2, axis=-1)
      loc = loc[..., 0:1, :]
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      scale = scale[..., 0:1, :]
      scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                      inputs.dtype)
      inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
      shifted_inputs = utils.one_hot_minus(inputs, loc)
      outputs = utils.one_hot_multiply(shifted_inputs, inverse_scale)
    elif net.shape[-1] == self.vocab_size:
      loc = net
      loc = loc[..., 0:1, :]
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      outputs = utils.one_hot_minus(inputs, loc)
    else:
      raise ValueError('Output of layer does not have compatible dimensions.')
    return outputs
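
For intuition, the padding step in _initial_call can be run in isolation. The toy snippet below (values chosen arbitrarily) shows how a single timestep is zero-padded out to the full sequence length before being fed to the layer.

# Toy illustration of the zero-padding in _initial_call: one real timestep is
# padded to `length` timesteps so a fixed-length layer (e.g. MADE) can run on it.
import tensorflow as tf

vocab_size, length = 4, 3
new_inputs = tf.one_hot(2, depth=vocab_size)       # shape [vocab_size]
inputs = new_inputs[tf.newaxis, tf.newaxis, :]     # shape [1, 1, vocab_size]
batch_ndims = inputs.shape.ndims - 2               # 1 batch dimension
padded_inputs = tf.pad(
    inputs, paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
print(padded_inputs.shape)                         # (1, 3, 4)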
Example #3
  def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                         **kwargs):
    """Returns Tensor of shape [..., timestep+1, vocab_size].

    Args:
      current_outputs: Tensor of shape [..., timestep, vocab_size], the so-far
        generated sequence Tensor.
      new_inputs: Tensor of shape [..., vocab_size], the new input to generate
        its output given current_outputs.
      length: Length of final desired sequence.
      timestep: Current timestep.
      **kwargs: Optional keyword arguments to layer.
    """
    inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                       axis=-2)
    # TODO(trandustin): To handle variable lengths, extend MADE to subset its
    # input and output layer weights rather than pad inputs.
    batch_ndims = inputs.shape.ndims - 2
    # Pad with zeros up to the full sequence length the layer expects.
    padded_inputs = tf.pad(
        inputs,
        paddings=[[0, 0]] * batch_ndims + [[0, length - timestep - 1], [0, 0]])
    net = self.layer(padded_inputs, **kwargs)
    if net.shape[-1] == 2 * self.vocab_size:
      loc, scale = tf.split(net, 2, axis=-1)
      loc = loc[..., :(timestep + 1), :]
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      scale = scale[..., :(timestep + 1), :]
      scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                      inputs.dtype)
      inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
      shifted_inputs = utils.one_hot_minus(inputs, loc)
      new_outputs = utils.one_hot_multiply(shifted_inputs, inverse_scale)
    elif net.shape[-1] == self.vocab_size:
      loc = net
      loc = loc[..., :(timestep + 1), :]
      loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
      new_outputs = utils.one_hot_minus(inputs, loc)
    else:
      raise ValueError('Output of layer does not have compatible dimensions.')
    # Keep the previously generated timesteps and append only the newest one.
    outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
    if not tf.executing_eagerly():
      outputs.set_shape([None] * batch_ndims + [timestep + 1, self.vocab_size])
    return outputs
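
Taken together, _initial_call and _per_timestep_call can be chained to generate a whole sequence left to right. The sketch below shows that wiring under stated assumptions: `generate`, `flow` (an object exposing the two methods above), and `inputs` are hypothetical names, and the plain Python loop merely stands in for whatever control flow the surrounding class actually uses.

# Hedged sketch of left-to-right generation with the two methods above.
# `flow` and `inputs` (shape [..., length, vocab_size]) are assumed to exist.
def generate(flow, inputs, **kwargs):
  length = inputs.shape[-2]
  # The first output depends only on the first input timestep.
  outputs = flow._initial_call(inputs[..., 0, :], length, **kwargs)
  # Every later output is produced conditioned on the outputs generated so far.
  for timestep in range(1, length):
    outputs = flow._per_timestep_call(
        outputs, inputs[..., timestep, :], length, timestep, **kwargs)
  return outputs  # shape [..., length, vocab_size]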