Example #1
    def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                           **kwargs):
        """Returns Tensor of shape [..., timestep+1, vocab_size].

        Args:
          current_outputs: Tensor of shape [..., timestep, vocab_size], the
            sequence Tensor generated so far.
          new_inputs: Tensor of shape [..., vocab_size], the new input whose
            output is generated given current_outputs.
          length: Length of final desired sequence.
          timestep: Current timestep.
          **kwargs: Optional keyword arguments to layer.
        """
        inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                           axis=-2)
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - timestep - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            loc = loc[..., :(timestep + 1), :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)
            new_outputs = utils.one_hot_minus(inputs, loc)
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]],
                            axis=-2)
        if not tf.executing_eagerly():
            outputs.set_shape([None] * batch_ndims +
                              [timestep + 1, self.vocab_size])
        return outputs
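
A hedged sketch of how this method is typically driven, paired with the `_initial_call` shown in later examples. The names `flow` and `seq` are illustrative assumptions, not identifiers from the source:

    # Hypothetical autoregressive sampling loop: seed with the first input,
    # then feed each partial output back in, one timestep at a time.
    outputs = flow._initial_call(seq[..., 0, :], length)
    for t in range(1, length):
        outputs = flow._per_timestep_call(outputs, seq[..., t, :], length, t)
    # outputs: [..., length, vocab_size]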
Example #2
    def reverse(self, inputs, **kwargs):
        """Reverse pass returning the inverse autoregressive transformation."""
        if not self.built:
            self._maybe_build(inputs)

        net = self.layer(inputs, **kwargs)
        if net.shape[-1] == self.subset_K:  # layer predicts only the shuffled subset
            loc = net
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)

        # operate on subset:
        x = inputs
        x = tf.gather(x, self.shuffling, axis=-1)
        x1 = x[..., :self.subset_K]
        # add the shift back on the transformed subset only
        x1 = utils.one_hot_add(loc, x1)
        x2 = x[..., self.subset_K:]
        x = tf.concat([x1, x2], axis=-1)
        x = tf.gather(x, self.inverted_shuffling, axis=-1)
        outputs = x

        return outputs
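
The subset logic above relies on `inverted_shuffling` being the inverse permutation of `shuffling`, so that un-shuffling after the transform restores the original category order. A minimal NumPy sketch of that invariant (the check itself is an assumption for illustration, not library code):

    import numpy as np

    K = 8
    shuffling = np.random.permutation(K)
    inverted_shuffling = np.argsort(shuffling)  # inverse permutation
    x = np.arange(K)
    assert (x[shuffling][inverted_shuffling] == x).all()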
Example #3
    def _initial_call(self, new_inputs, length, **kwargs):
        inputs = new_inputs[..., tf.newaxis, :]
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.subset_K:  # layer predicts only the shuffled subset
            loc = net
            loc = loc[..., 0:1, :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)

            # operate on subset:
            x = inputs
            x = tf.gather(x, self.shuffling, axis=-1)
            x1 = x[..., :self.subset_K]
            # shift the transformed subset only (the full-vocabulary version
            # is outputs = utils.one_hot_minus(inputs, loc))
            x1 = utils.one_hot_minus(x1, loc)
            x2 = x[..., self.subset_K:]
            x = tf.concat([x1, x2], axis=-1)
            x = tf.gather(x, self.inverted_shuffling, axis=-1)
            outputs = x

        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        return outputs
Example #4
    def log_prob(self, sample, eps_prob=1e-31):
        B = self.components.probs.shape[-2]
        component_probs = self.components.probs

        shift = utils.one_hot_argmax(self.logits, self.temperature)
        sample = utils.one_hot_add(sample[:, :, None, :], shift[None, :, :, :])

        prob = tf.reduce_sum(component_probs * sample + eps_prob,
                             -1)  # sum over categories => n x N x B
        log_prob = tf.math.log(prob) + np.log(1. / B)  # n x N x B
        log_prob = tf.math.reduce_logsumexp(
            log_prob, -1)  # sum over B mixture components  => n x N
        return tf.reduce_sum(log_prob, -1)  # sum over N
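
A shape walk-through of `log_prob`, assuming `sample` is one-hot with shape [n, N, K] and `self.logits`/the component probabilities have shape [N, B, K]:

    # sample[:, :, None, :]          -> [n, N, 1, K]
    # shift[None, :, :, :]           -> [1, N, B, K]
    # one_hot_add(...)               -> [n, N, B, K]  (shifted one-hot samples)
    # reduce_sum over K              -> [n, N, B]     (prob under each component)
    # + log(1/B), logsumexp over B   -> [n, N]        (uniform mixture weights)
    # reduce_sum over N              -> [n]           (independent positions)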
Example #5
    def sample_extm(self, n=1, *args, **kwargs):
        B = self.components.probs.shape[-2]
        dtype = self.dtype

        shift = utils.one_hot_argmax(self.logits,
                                     self.temperature)  # N x B x K
        sample = self.components.sample(n)  # n x N x B x K
        sample = utils.one_hot_minus(sample, shift)

        selected_flows = (np.arange(n) + np.random.randint(B)
                          ) % B  # spread the n samples evenly across the B flows
        mask = tf.one_hot(selected_flows, B, dtype=dtype)
        mask = mask[:, None, :, None]  # n x 1 x B x 1
        sample = tf.reduce_sum(sample * mask, -2)  #  n x N x K
        return sample, mask
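
The `selected_flows` line assigns the n samples round-robin across the B mixture components, starting from a random offset. A small illustration (the concrete values are an assumption for concreteness):

    import numpy as np
    n, B = 5, 3
    offset = 1  # stands in for np.random.randint(B)
    print((np.arange(n) + offset) % B)  # [1 2 0 1 2]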
Example #6
    def reverse(self, inputs, **kwargs):
        """Reverse pass returning the inverse autoregressive transformation."""
        if not self.built:
            self._maybe_build(inputs)

        net = self.layer(inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            scaled_inputs = inputs  # shift-only flow: no scale is applied
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)
        outputs = utils.one_hot_add(loc, scaled_inputs)
        return outputs
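
`one_hot_add` and `one_hot_minus` implement addition and subtraction modulo the vocabulary size in one-hot space, which is why this `reverse` undoes the forward pass. A simplified index-based NumPy sketch of that arithmetic (the real utils operate on relaxed one-hot tensors; this is only an illustrative assumption):

    import numpy as np

    def one_hot_shift_np(x, shift, sign=+1):
        # argmax(result) == (argmax(x) + sign * argmax(shift)) % K
        K = x.shape[-1]
        idx = (np.argmax(x, -1) + sign * np.argmax(shift, -1)) % K
        return np.eye(K)[idx]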
Example #7
    def call(self, inputs, **kwargs):
        """Forward pass for bipartite generation."""
        inputs = tf.convert_to_tensor(inputs)
        batch_ndims = inputs.shape.ndims - 2
        mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                          [1] * batch_ndims + [-1, 1])
        masked_inputs = mask * inputs
        net = self.layer(masked_inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)
            masked_outputs = (1. - mask) * utils.one_hot_minus(inputs, loc)
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        outputs = masked_inputs + masked_outputs
        return outputs
Example #8
    def reverse(self, inputs, **kwargs):
        """Reverse pass for the inverse bipartite transformation."""
        if not self.built:
            self._maybe_build(inputs)

        inputs = tf.convert_to_tensor(inputs)
        batch_ndims = inputs.shape.ndims - 2
        mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                          [1] * batch_ndims + [-1, 1])
        masked_inputs = mask * inputs
        net = self.layer(masked_inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            scaled_inputs = inputs  # shift-only flow: no scale is applied
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)
        masked_outputs = (1. - mask) * utils.one_hot_add(loc, scaled_inputs)
        outputs = masked_inputs + masked_outputs
        return outputs
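
Both bipartite passes assume `self.mask` is a binary vector over sequence positions: masked positions pass through unchanged and condition the layer, while unmasked positions get transformed. A minimal even/odd coupling mask (an illustrative choice, not taken from the source):

    import tensorflow as tf
    length = 6
    mask = tf.cast(tf.range(length) % 2, tf.float32)  # [0. 1. 0. 1. 0. 1.]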
Example #9
    def _initial_call(self, new_inputs, length, **kwargs):
        """Returns Tensor of shape [..., 1, vocab_size].

        Args:
          new_inputs: Tensor of shape [..., vocab_size], the new input whose
            output is generated.
          length: Length of final desired sequence.
          **kwargs: Optional keyword arguments to layer.
        """
        inputs = new_inputs[..., tf.newaxis, :]
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            loc = loc[..., 0:1, :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)
            outputs = utils.one_hot_minus(inputs, loc)
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        return outputs
Example #10
    def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                           **kwargs):
        inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                           axis=-2)
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - timestep - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.subset_K:  # layer predicts only the shuffled subset
            loc = net
            loc = loc[..., :(timestep + 1), :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)

            # operate on subset:
            x = inputs
            x = tf.gather(x, self.shuffling, axis=-1)
            x1 = x[..., :self.subset_K]
            # shift the transformed subset only (the full-vocabulary version
            # is new_outputs = utils.one_hot_minus(inputs, loc))
            x1 = utils.one_hot_minus(x1, loc)
            x2 = x[..., self.subset_K:]
            x = tf.concat([x1, x2], axis=-1)
            x = tf.gather(x, self.inverted_shuffling, axis=-1)
            new_outputs = x

        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]],
                            axis=-2)
        if not tf.executing_eagerly():
            outputs.set_shape([None] * batch_ndims +
                              [timestep + 1, self.vocab_size])
        return outputs

    def call(self, x):
        # forward pass: subtract the shift, then multiply by the modular
        # inverse of the scale (all arithmetic is mod K in one-hot space)
        shift = utils.one_hot_argmax(self.logits, self.temperature)
        shifted_inputs = utils.one_hot_minus(x, shift)
        scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
        inverse_scale = utils.multiplicative_inverse(scale, self.K)
        return utils.one_hot_multiply(shifted_inputs, inverse_scale)

    def reverse_static(x, logits, temperature):
        # standalone shift-only inverse: add the shift back
        shift = utils.one_hot_argmax(logits, temperature)
        return utils.one_hot_add(x, shift)

    def call_static(x, logits, temperature):
        # standalone shift-only forward: subtract the shift
        shift = utils.one_hot_argmax(logits, temperature)
        return utils.one_hot_minus(x, shift)

    def reverse(self, x):
        # inverse of call above: multiply by the scale, then add the shift
        scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
        scaled_inputs = utils.one_hot_multiply(x, scale)
        shift = utils.one_hot_argmax(self.logits, self.temperature)
        return utils.one_hot_add(shift, scaled_inputs)
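
As the temperature approaches zero, the argmax shifts and scales become exact one-hot vectors, so `reverse` composed with `call` should be the identity. A hedged sanity sketch (assumption: `layer` is an instance exposing the two methods above and `x` is a batch of hard one-hot vectors over K categories):

    y = layer.call(x)          # forward: (x - shift) * scale^{-1} (mod K)
    x_rec = layer.reverse(y)   # inverse: y * scale + shift (mod K)
    # with hard one-hot shift and scale, x_rec == x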