Code Example #1
    def reverse(self, inputs, **kwargs):
        """Reverse pass returning the inverse autoregressive transformation."""
        if not self.built:
            self._maybe_build(inputs)

        net = self.layer(inputs, **kwargs)
        if net.shape[-1] == self.subset_K:
            loc = net
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        # Turn the predicted logits into a one-hot shift vector.
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)

        # Operate only on the first subset_K entries of the shuffled last axis.
        x = tf.gather(inputs, self.shuffling, axis=-1)
        x1 = x[..., :self.subset_K]
        x1 = utils.one_hot_add(loc, x1)  # modular shift of the subset
        x2 = x[..., self.subset_K:]
        x = tf.concat([x1, x2], axis=-1)
        # Undo the shuffling to restore the original ordering.
        outputs = tf.gather(x, self.inverted_shuffling, axis=-1)

        return outputs
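The subset trick above hinges on inverted_shuffling undoing shuffling, i.e. the second tf.gather applies the inverse permutation of the first. A minimal sketch of that round trip, assuming only that the inverse permutation is the argsort of the permutation (how shuffling is actually constructed is not shown in the snippet):

import tensorflow as tf

# Minimal sketch: permute the last axis and undo it with the argsort of the
# permutation, which is exactly the inverse permutation.
K = 6
shuffling = tf.random.shuffle(tf.range(K))
inverted_shuffling = tf.argsort(shuffling)

x = tf.one_hot([2, 5, 0], K)                                 # toy one-hot rows
shuffled = tf.gather(x, shuffling, axis=-1)                  # permute last axis
restored = tf.gather(shuffled, inverted_shuffling, axis=-1)  # undo it
assert bool(tf.reduce_all(restored == x))                    # identity round trip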
Code Example #2
    def log_prob(self, sample, eps_prob=1e-31):
        """Log-probability of one-hot samples under the shifted mixture."""
        B = self.components.probs.shape[-2]  # number of mixture components
        component_probs = self.components.probs

        # Shift each sample by the learned one-hot shift, broadcasting the
        # n x N x 1 x K samples against the N x B x K shifts.
        shift = utils.one_hot_argmax(self.logits, self.temperature)
        sample = utils.one_hot_add(sample[:, :, None, :], shift[None, :, :, :])

        prob = tf.reduce_sum(component_probs * sample + eps_prob,
                             -1)  # sum over categories => n x N x B
        # Uniform mixture weights: log p(x) = logsumexp_b [log p(x | b) + log(1/B)].
        log_prob = tf.math.log(prob) + np.log(1. / B)  # n x N x B
        log_prob = tf.math.reduce_logsumexp(
            log_prob, -1)  # logsumexp over B mixture components => n x N
        return tf.reduce_sum(log_prob, -1)  # sum over N dimensions
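A hedged toy check of the uniform-mixture arithmetic used in log_prob above (toy per-component log-likelihoods only; no class state, no one-hot shifts): with uniform weights 1/B, the mixture log-probability is the logsumexp over components of log p(x|b) + log(1/B).

import numpy as np
import tensorflow as tf

B = 3
# Toy per-component log-likelihoods for 2 samples and B = 3 components.
log_p_x_given_b = tf.constant([[-1.0, -2.0, -0.5],
                               [-3.0, -0.1, -2.5]])            # n x B
mixture_log_prob = tf.math.reduce_logsumexp(
    log_p_x_given_b + np.log(1. / B), axis=-1)                 # n

# The same quantity computed directly in probability space.
direct = tf.math.log(tf.reduce_sum(tf.exp(log_p_x_given_b), -1) / B)
assert np.allclose(mixture_log_prob.numpy(), direct.numpy())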
Code Example #3
    def test_one_hot_add(self):
        """Test that one_hot_add shifts the hot position by the shift, mod K."""
        K = 8
        vals = np.array([1, 3, 7])
        shifts = np.array([7, 3, 1])
        sums = one_hot.one_hot_add(tf.one_hot(vals, K), tf.one_hot(shifts, K))

        # Each result row should still be a valid one-hot vector:
        # exactly one 1, all remaining entries 0.
        self.assertAllEqual(tf.reduce_sum(sums, -1), [1., 1., 1.],
                            "row sum == 1")
        self.assertAllEqual(tf.reduce_max(sums, -1), [1., 1., 1.],
                            "max cell value in each row == 1")
        self.assertAllEqual(tf.reduce_min(sums, -1), [0., 0., 0.],
                            "min cell value in each row == 0")
        # The hot index should equal (vals + shifts) mod K.
        self.assertTrue((np.argmax(sums, -1) == (vals + shifts) % K).all(),
                        "correct results")
Code Example #4
    def reverse(self, inputs, **kwargs):
        """Reverse pass returning the inverse autoregressive transformation."""
        if not self.built:
            self._maybe_build(inputs)

        net = self.layer(inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            scaled_inputs = inputs
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)
        outputs = utils.one_hot_add(loc, scaled_inputs)
        return outputs
Code Example #5
    def reverse(self, inputs, **kwargs):
        """Reverse pass for the inverse bipartite transformation."""
        if not self.built:
            self._maybe_build(inputs)

        inputs = tf.convert_to_tensor(inputs)
        batch_ndims = inputs.shape.ndims - 2
        mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                          [1] * batch_ndims + [-1, 1])
        # Dimensions where mask == 1 pass through unchanged and condition the
        # shift; dimensions where mask == 0 are transformed.
        masked_inputs = mask * inputs
        net = self.layer(masked_inputs, **kwargs)
        if net.shape[-1] == self.vocab_size:
            loc = net
            scaled_inputs = inputs
        else:
            raise ValueError(
                'Output of layer does not have compatible dimensions.')
        loc = tf.cast(utils.one_hot_argmax(loc, self.temperature),
                      inputs.dtype)
        masked_outputs = (1. - mask) * utils.one_hot_add(loc, scaled_inputs)
        outputs = masked_inputs + masked_outputs
        return outputs
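A hedged toy illustration of the masking arithmetic in the bipartite reverse above (hypothetical shapes, tf.roll standing in for the one-hot shift): positions with mask == 1 are copied through unchanged, positions with mask == 0 receive the transformed values.

import tensorflow as tf

# D = 4 positions, vocab size K = 5, no extra batch dimensions.
mask = tf.reshape(tf.constant([1., 1., 0., 0.]), [-1, 1])  # D x 1, broadcasts over K
inputs = tf.one_hot([0, 1, 2, 3], 5)                       # D x K toy one-hot input
shifted = tf.roll(inputs, shift=1, axis=-1)                # stand-in for one_hot_add

outputs = mask * inputs + (1. - mask) * shifted
# Rows 0-1 equal `inputs`; rows 2-3 equal the shifted values.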
def reverse_static(x, logits, temperature):
    """Shift one-hot inputs by the one-hot argmax of `logits`."""
    shift = utils.one_hot_argmax(logits, temperature)
    return utils.one_hot_add(x, shift)

    def reverse(self, x):
        """Reverse pass: one-hot multiply by the scale, then one-hot add the shift."""
        scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
        scaled_inputs = utils.one_hot_multiply(x, scale)
        shift = utils.one_hot_argmax(self.logits, self.temperature)
        return utils.one_hot_add(shift, scaled_inputs)
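And a hedged usage sketch for reverse_static (toy logits; assuming, as the snippets above do, that one_hot_argmax yields a one-hot vector at the winning index):

import tensorflow as tf

x = tf.one_hot([1], 4)                        # hot index 1, K = 4
logits = tf.constant([[0.1, 0.2, 3.0, 0.3]])  # argmax is index 2
y = reverse_static(x, logits, temperature=0.1)
# Expected hot index of y: (1 + 2) % 4 == 3.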