Code example #1
    def _initial_call(self, new_inputs, length, **kwargs):
        """Returns Tensor of shape [..., 1, vocab_size] (subset variant)."""
        inputs = new_inputs[..., tf.newaxis, :]
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.subset_K:  # layer predicts shifts for the subset only
            loc = net
            loc = loc[..., 0:1, :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)

            # Operate on the shuffled subset only: permute the vocabulary
            # axis, shift the first subset_K symbols, leave the rest
            # untouched, then undo the permutation.
            x = tf.gather(inputs, self.shuffling, axis=-1)
            x1 = utils.one_hot_minus(x[..., :self.subset_K], loc)
            x2 = x[..., self.subset_K:]
            x = tf.concat([x1, x2], axis=-1)
            outputs = tf.gather(x, self.inverted_shuffling, axis=-1)

        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        return outputs
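
The subset trick above hinges on shuffling being a permutation of the vocabulary axis and inverted_shuffling its inverse, so the two tf.gather calls cancel. A minimal sketch of constructing such a pair (the construction is an illustrative assumption, not taken from the source):

import numpy as np
import tensorflow as tf

vocab_size = 8
shuffling = np.random.permutation(vocab_size)
inverted_shuffling = np.argsort(shuffling)  # inverse permutation

# Gathering by the permutation and then by its inverse is the identity:
x = tf.one_hot([2, 5], vocab_size)
roundtrip = tf.gather(tf.gather(x, shuffling, axis=-1),
                      inverted_shuffling, axis=-1)
assert np.allclose(x.numpy(), roundtrip.numpy())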
Code example #2
 def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                        **kwargs):
     """Returns Tensor of shape [..., timestep+1, vocab_size].
 Args:
   current_outputs: Tensor of shape [..., timestep, vocab_size], the so-far
     generated sequence Tensor.
   new_inputs: Tensor of shape [..., vocab_size], the new input to generate
     its output given current_outputs.
   length: Length of final desired sequence.
   timestep: Current timestep.
   **kwargs: Optional keyword arguments to layer.
 """
     inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                        axis=-2)
     batch_ndims = inputs.shape.ndims - 2
     padded_inputs = tf.pad(inputs,
                            paddings=[[0, 0]] * batch_ndims +
                            [[0, length - timestep - 1], [0, 0]])
     net = self.layer(padded_inputs, **kwargs)
     if net.shape[-1] == self.vocab_size:
         loc = net
         loc = loc[..., :(timestep + 1), :]
         loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                       inputs.dtype)
         new_outputs = utils.one_hot_minus(inputs, loc)
     else:
         raise ValueError(
             'Output of layer does not have compatible dimensions.')
     outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]],
                         axis=-2)
     if not tf.executing_eagerly():
         outputs.set_shape([None] * batch_ndims +
                           [timestep + 1, self.vocab_size])
     return outputs
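
For context, a hedged sketch of the outer generation loop that would drive _initial_call and _per_timestep_call; the flow object and the generate driver are illustrative assumptions, not from the source:

def generate(flow, new_inputs, length, **kwargs):
    # new_inputs: [..., length, vocab_size]. Outputs grow one timestep at
    # a time, starting from _initial_call and extending via
    # _per_timestep_call.
    outputs = flow._initial_call(new_inputs[..., 0, :], length, **kwargs)
    for timestep in range(1, length):
        outputs = flow._per_timestep_call(outputs,
                                          new_inputs[..., timestep, :],
                                          length, timestep, **kwargs)
    return outputs  # [..., length, vocab_size]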
Code example #3
    def sample_extm(self, n=1, *args, **kwargs):
        """Draws n shifted samples and routes each through one of the B flows."""
        B = self.components.probs.shape[-2]
        dtype = self.dtype

        shift = utils.one_hot_argmax(self.logits,
                                     self.temperature)  # N x B x K
        sample = self.components.sample(n)  # n x N x B x K
        sample = utils.one_hot_minus(sample, shift)

        # Cycle through the B flows from a random offset so the n samples
        # are spread (almost) equally across all of them.
        selected_flows = (np.arange(n) + np.random.randint(B)) % B
        mask = tf.one_hot(selected_flows, B, dtype=dtype)
        mask = mask[:, None, :, None]  # n x 1 x B x 1
        sample = tf.reduce_sum(sample * mask, -2)  #  n x N x K
        return sample, mask
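
utils.one_hot_argmax is used throughout these examples. A common straight-through formulation consistent with its usage here, written as an assumption rather than copied from the library: the forward pass yields an exact one-hot of the argmax, while gradients flow through a temperature-scaled softmax.

import tensorflow as tf

def one_hot_argmax_sketch(inputs, temperature):
    # Forward: hard one-hot of the argmax along the last axis.
    # Backward: gradients of softmax(inputs / temperature)
    # (straight-through estimator).
    vocab_size = inputs.shape[-1]
    hard = tf.one_hot(tf.argmax(inputs, axis=-1), vocab_size,
                      dtype=inputs.dtype)
    soft = tf.nn.softmax(inputs / temperature, axis=-1)
    return soft + tf.stop_gradient(hard - soft)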
Code example #4
    def test_one_hot_minus(self):
        """Test one_hot_minus (if max value position moves by -shift)."""
        K = 8
        vals = np.array([1, 3, 7])  # tf.one_hot requires integer indices
        shifts = np.array([7, 3, 1])
        sums = one_hot.one_hot_minus(tf.one_hot(vals, K),
                                     tf.one_hot(shifts, K))

        # check that each row is one-hot: exactly one 1, remaining zeros
        # (compare elementwise against same-shape arrays)
        self.assertAllEqual(tf.reduce_sum(sums, -1), np.ones(3),
                            "row sum == 1")
        self.assertAllEqual(tf.reduce_max(sums, -1), np.ones(3),
                            "max cell value in each row == 1")
        self.assertAllEqual(tf.reduce_min(sums, -1), np.zeros(3),
                            "min cell value in each row == 0")
        # check if results are correct
        self.assertTrue((np.argmax(sums, -1) == (vals - shifts) % K).all(),
                        "correct results")
Code example #5
 def call(self, inputs, **kwargs):
     """Forward pass for bipartite generation."""
     inputs = tf.convert_to_tensor(inputs)
     batch_ndims = inputs.shape.ndims - 2
     mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                       [1] * batch_ndims + [-1, 1])
     masked_inputs = mask * inputs
     net = self.layer(masked_inputs, **kwargs)
     if net.shape[-1] == self.vocab_size:
         loc = net
         loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                       inputs.dtype)
         masked_outputs = (1. - mask) * utils.one_hot_minus(inputs, loc)
     else:
         raise ValueError(
             'Output of layer does not have compatible dimensions.')
     outputs = masked_inputs + masked_outputs
     return outputs
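
Positions with mask == 1 pass through unchanged and condition the shift applied to the complementary positions, which is what makes the transform invertible. A sketch of the reverse pass, assuming the same utils module and a one_hot_add counterpart to one_hot_minus (an assumption; names are illustrative):

def bipartite_reverse_sketch(inputs, mask, layer, temperature, **kwargs):
    # The pass-through half is unchanged by call, so loc can be recomputed
    # from it; the shift on the complementary half is then undone.
    batch_ndims = inputs.shape.ndims - 2
    mask = tf.reshape(tf.cast(mask, inputs.dtype),
                      [1] * batch_ndims + [-1, 1])
    masked_inputs = mask * inputs
    net = layer(masked_inputs, **kwargs)
    loc = tf.cast(utils.one_hot_argmax(net, temperature), inputs.dtype)
    return masked_inputs + (1. - mask) * utils.one_hot_add(inputs, loc)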
Code example #6
 def _initial_call(self, new_inputs, length, **kwargs):
     """Returns Tensor of shape [..., 1, vocab_size].
 Args:
   new_inputs: Tensor of shape [..., vocab_size], the new input to generate
     its output.
   length: Length of final desired sequence.
   **kwargs: Optional keyword arguments to layer.
 """
     inputs = new_inputs[..., tf.newaxis, :]
     batch_ndims = inputs.shape.ndims - 2
     padded_inputs = tf.pad(inputs,
                            paddings=[[0, 0]] * batch_ndims +
                            [[0, length - 1], [0, 0]])
     net = self.layer(padded_inputs, **kwargs)
     if net.shape[-1] == self.vocab_size:
         loc = net
         loc = loc[..., 0:1, :]
         loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                       inputs.dtype)
         outputs = utils.one_hot_minus(inputs, loc)
     else:
         raise ValueError(
             'Output of layer does not have compatible dimensions.')
     return outputs
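
The shift applied by utils.one_hot_minus is modular: for inputs hot at position i and a shift hot at position s over vocabulary size K, the result is hot at (i - s) mod K (this is exactly what the test in example #4 asserts). A non-differentiable sketch of just that semantics; the library version is built from differentiable ops, so this only illustrates the arithmetic:

import tensorflow as tf

def one_hot_minus_sketch(inputs, shift):
    # inputs, shift: one-hot tensors of shape [..., K].
    K = inputs.shape[-1]
    i = tf.argmax(inputs, axis=-1)
    s = tf.argmax(shift, axis=-1)
    return tf.one_hot((i - s) % K, K, dtype=inputs.dtype)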
Code example #7
    def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                           **kwargs):
        """Returns Tensor of shape [..., timestep+1, vocab_size] (subset variant)."""
        inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                           axis=-2)
        batch_ndims = inputs.shape.ndims - 2
        padded_inputs = tf.pad(inputs,
                               paddings=[[0, 0]] * batch_ndims +
                               [[0, length - timestep - 1], [0, 0]])
        net = self.layer(padded_inputs, **kwargs)
        if net.shape[-1] == self.subset_K:  # layer predicts shifts for the subset only
            loc = net
            loc = loc[..., :(timestep + 1), :]
            loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                          inputs.dtype)

            # Operate on the shuffled subset only: permute the vocabulary
            # axis, shift the first subset_K symbols, leave the rest
            # untouched, then undo the permutation.
            x = tf.gather(inputs, self.shuffling, axis=-1)
            x1 = utils.one_hot_minus(x[..., :self.subset_K], loc)
            x2 = x[..., self.subset_K:]
            x = tf.concat([x1, x2], axis=-1)
            new_outputs = tf.gather(x, self.inverted_shuffling, axis=-1)

        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]],
                            axis=-2)
        if not tf.executing_eagerly():
            outputs.set_shape([None] * batch_ndims +
                              [timestep + 1, self.vocab_size])
        return outputs

 def call(self, x):
     """Inverse affine flow: computes (x - shift) * scale^{-1} (mod K)."""
     shift = utils.one_hot_argmax(self.logits, self.temperature)
     shifted_inputs = utils.one_hot_minus(x, shift)
     scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
     inverse_scale = utils.multiplicative_inverse(scale, self.K)
     return utils.one_hot_multiply(shifted_inputs, inverse_scale)

 def call_static(x, logits, temperature):
     """Stateless shift-only flow: subtracts the argmax shift of logits from x."""
     shift = utils.one_hot_argmax(logits, temperature)
     return utils.one_hot_minus(x, shift)
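
The affine inverse above relies on utils.multiplicative_inverse(scale, self.K) producing the one-hot of scale^{-1} mod K, which exists whenever the scale value is coprime to K: for K = 5 and scale 2, the inverse is 3 because 2 * 3 = 6 ≡ 1 (mod 5). A scalar sketch of that semantics, not the library's implementation:

import tensorflow as tf

def multiplicative_inverse_sketch(scale, K):
    # scale: one-hot vector of shape [K], hot at value a with gcd(a, K) == 1.
    a = int(tf.argmax(scale, axis=-1))
    inv = pow(a, -1, K)  # modular inverse, Python 3.8+
    return tf.one_hot(inv, K)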