コード例 #1
0
    def reverse(self, inputs, **kwargs):
        """Reverse pass for the inverse bipartite transformation.

        The layer is conditioned only on the masked part of the input
        (``mask * inputs``). Its output supplies a location and, optionally,
        a scale. These are applied to the whole input, and the result is
        kept only where ``1 - mask`` is non-zero:

            outputs = mask * inputs
                      + (1 - mask) * one_hot_add(loc, scaled_inputs)

        NOTE(review): this reads as a pass-through/transform split only if
        ``self.mask`` holds 0/1 values -- confirm where the mask is built.

        Args:
            inputs: Tensor-like input, converted with ``tf.convert_to_tensor``.
                ``self.mask`` is reshaped to ``[1, ..., 1, -1, 1]`` and so
                lines up with the second-to-last axis of ``inputs``
                (presumably ``[..., length, vocab_size]`` -- TODO confirm).
            **kwargs: Forwarded unchanged to ``self.layer``.

        Returns:
            The blended tensor described above.

        Raises:
            ValueError: If the last dimension of the layer output is neither
                ``2 * self.vocab_size`` nor ``self.vocab_size``.
        """
        if not self.built:
            self._maybe_build(inputs)

        x = tf.convert_to_tensor(inputs)
        dtype = x.dtype

        # Place the mask on the second-to-last axis, broadcasting over all
        # leading batch axes and over the trailing axis.
        leading_ones = [1] * (x.shape.ndims - 2)
        keep = tf.reshape(tf.cast(self.mask, dtype), leading_ones + [-1, 1])
        kept = keep * x

        params = self.layer(kept, **kwargs)
        width = params.shape[-1]

        if width == 2 * self.vocab_size:
            # Location and scale logits are concatenated on the last axis.
            shift_logits, scale_logits = tf.split(params, 2, axis=-1)
            scale = tf.cast(
                utils.one_hot_argmax(scale_logits, self.temperature), dtype)
            transformed = utils.one_hot_multiply(x, scale)
        elif width == self.vocab_size:
            # Shift-only variant: the input is not rescaled.
            shift_logits = params
            transformed = x
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')

        shift = tf.cast(
            utils.one_hot_argmax(shift_logits, self.temperature), dtype)
        updated = (1. - keep) * utils.one_hot_add(shift, transformed)
        return kept + updated
コード例 #2
0
ファイル: discrete_flows.py プロジェクト: seyi/edward2
  def reverse(self, inputs, **kwargs):
    """Reverse pass returning the inverse autoregressive transformation.

    Calls `self.layer` once over the whole input. The layer output is used
    as a location and, optionally, a scale, and both are applied to every
    position at once:

      outputs = one_hot_add(loc, one_hot_multiply(inputs, scale))

    When the layer emits only a location, the form is
    `one_hot_add(loc, inputs)`.

    Args:
      inputs: Tensor-like input, converted with `tf.convert_to_tensor`.
        Python lists and NumPy arrays are therefore accepted as well as
        Tensors. Presumably one-hot along the last axis with
        `self.vocab_size` classes -- TODO confirm against callers.
      **kwargs: Forwarded unchanged to `self.layer`.

    Returns:
      The result of `utils.one_hot_add(loc, scaled_inputs)`.

    Raises:
      ValueError: If the last dimension of the layer output is neither
        `2 * self.vocab_size` nor `self.vocab_size`.
    """
    if not self.built:
      self._maybe_build(inputs)

    # Convert before reading `inputs.dtype` below. Without this, a plain
    # Python (nested) list input fails with AttributeError at the first
    # `tf.cast`. This is a no-op for inputs that are already Tensors.
    inputs = tf.convert_to_tensor(inputs)
    net = self.layer(inputs, **kwargs)
    if net.shape[-1] == 2 * self.vocab_size:
      # Location and scale logits are concatenated on the last axis.
      loc, scale = tf.split(net, 2, axis=-1)
      scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                      inputs.dtype)
      scaled_inputs = utils.one_hot_multiply(inputs, scale)
    elif net.shape[-1] == self.vocab_size:
      # Shift-only variant: the input is not rescaled.
      loc = net
      scaled_inputs = inputs
    else:
      raise ValueError('Output of layer does not have compatible dimensions.')
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    outputs = utils.one_hot_add(loc, scaled_inputs)
    return outputs