def reverse(self, inputs, **kwargs):
    """Reverse pass: shift the leading ``subset_K`` slots of shuffled inputs.

    The wrapped layer is applied to ``inputs`` and its logits are hardened with
    ``utils.one_hot_argmax``. The last axis of ``inputs`` is permuted by
    ``self.shuffling``. The first ``subset_K`` entries are shifted with
    ``utils.one_hot_add`` and the remaining entries pass through unchanged.
    The result is then permuted back with ``self.inverted_shuffling``
    (presumably the inverse permutation of ``self.shuffling``; confirm where
    it is built).

    Raises:
      ValueError: if the layer output's last dimension is not ``subset_K``.
    """
    if not self.built:
        self._maybe_build(inputs)

    logits = self.layer(inputs, **kwargs)
    # Written as `not (... == ...)` so the test is exactly the negation of an
    # equality check (TF1 Dimension comparisons may yield None).
    if not (logits.shape[-1] == self.subset_K):
        raise ValueError(
            'Output of layer does not have compatible dimensions.')
    shift = tf.cast(utils.one_hot_argmax(logits, self.temperature),
                    inputs.dtype)

    # Permute the last axis so the slots to be transformed come first.
    permuted = tf.gather(inputs, self.shuffling, axis=-1)
    head = permuted[..., :self.subset_K]
    tail = permuted[..., self.subset_K:]
    # Only the head is shifted; the tail is carried over as-is.
    merged = tf.concat([utils.one_hot_add(shift, head), tail], axis=-1)
    return tf.gather(merged, self.inverted_shuffling, axis=-1)
def log_prob(self, sample, eps_prob=1e-31):
    """Return the log-likelihood of ``sample`` under a uniform shifted mixture.

    There are ``B`` components (``self.components.probs.shape[-2]``), each
    weighted ``1 / B``. The sample is shifted by a hard ``one_hot_argmax`` of
    ``self.logits`` before it is scored against each component.

    Args:
      sample: one-hot tensor, indexed as ``sample[:, :, None, :]``. The shape
        comments below suggest ``[n, N, K]`` (n samples, N positions,
        K categories); TODO confirm against callers.
      eps_prob: small constant added once to each per-component probability
        to keep ``log`` away from zero.

    Returns:
      Per-sample log-probabilities, summed over the N positions.
    """
    component_probs = self.components.probs
    # Number of mixture components; the mixture weights are uniform (1 / B).
    num_components = component_probs.shape[-2]
    shift = utils.one_hot_argmax(self.logits, self.temperature)
    # Broadcast every sample against every component's shift.
    # NOTE(review): the shift is *added* here; if samples are drawn by adding
    # the same shift to component draws, the inverse would subtract it.
    # Confirm the intended direction.
    shifted = utils.one_hot_add(sample[:, :, None, :], shift[None, :, :, :])
    # Probability under each component => n x N x B. eps_prob is added after
    # summing over categories so the floor does not scale with K.
    prob = tf.reduce_sum(component_probs * shifted, -1) + eps_prob
    log_prob = tf.math.log(prob) + np.log(1. / num_components)  # n x N x B
    # Marginalize over the B mixture components => n x N.
    log_prob = tf.math.reduce_logsumexp(log_prob, -1)
    return tf.reduce_sum(log_prob, -1)  # sum over N
def test_one_hot_add(self):
    """Test one_hot_add: each row's hot index moves to (val + shift) mod K."""
    K = 8
    # tf.one_hot only accepts integer indices; float arrays raise a TypeError.
    vals = np.array([1, 3, 7])
    shifts = np.array([7, 3, 1])
    sums = one_hot.one_hot_add(tf.one_hot(vals, K), tf.one_hot(shifts, K))
    # assertAllEqual checks shapes, so compare the per-row reductions against
    # arrays with one entry per row rather than a bare scalar.
    ones = np.ones_like(vals)
    zeros = np.zeros_like(vals)
    # Check there is exactly one 1 per row and the remaining cells are zeros.
    self.assertAllEqual(tf.reduce_sum(sums, -1), ones, "row sum==1")
    self.assertAllEqual(tf.reduce_max(sums, -1), ones,
                        "max cell value in each row==1")
    self.assertAllEqual(tf.reduce_min(sums, -1), zeros,
                        "min cell value in each row==0")
    # Check the hot index landed where modular addition says it should.
    self.assertAllEqual(np.argmax(sums, -1), (vals + shifts) % K,
                        "correct results")
def reverse(self, inputs, **kwargs):
    """Invert the autoregressive flow by adding a hard location to ``inputs``.

    The wrapped layer must emit exactly ``vocab_size`` logits on its last
    axis. These are discretized with ``utils.one_hot_argmax``, cast to the
    dtype of ``inputs``, and combined with ``inputs`` via
    ``utils.one_hot_add``.

    Raises:
      ValueError: if the layer's last dimension differs from ``vocab_size``.
    """
    if not self.built:
        self._maybe_build(inputs)

    logits = self.layer(inputs, **kwargs)
    # Exact negation of the equality test (TF1 Dimension `==` may be None).
    if not (logits.shape[-1] == self.vocab_size):
        raise ValueError(
            'Output of layer does not have compatible dimensions.')
    shift = tf.cast(utils.one_hot_argmax(logits, self.temperature),
                    inputs.dtype)
    return utils.one_hot_add(shift, inputs)
def reverse(self, inputs, **kwargs):
    """Invert the bipartite flow: shift only the positions not kept by the mask.

    Positions selected by ``self.mask`` are copied through unchanged and are
    also the only information the wrapped layer sees. The remaining positions
    are shifted with ``utils.one_hot_add`` by the hard argmax of the layer's
    logits.

    Raises:
      ValueError: if the layer's last dimension differs from ``vocab_size``.
    """
    if not self.built:
        self._maybe_build(inputs)

    inputs = tf.convert_to_tensor(inputs)
    # Reshape the mask so it spans the second-to-last axis of `inputs` and
    # broadcasts over every leading axis and the last axis.
    lead_dims = inputs.shape.ndims - 2
    mask = tf.cast(self.mask, inputs.dtype)
    mask = tf.reshape(mask, [1] * lead_dims + [-1, 1])
    kept = mask * inputs

    logits = self.layer(kept, **kwargs)
    # Exact negation of the equality test (TF1 Dimension `==` may be None).
    if not (logits.shape[-1] == self.vocab_size):
        raise ValueError(
            'Output of layer does not have compatible dimensions.')
    shift = tf.cast(utils.one_hot_argmax(logits, self.temperature),
                    inputs.dtype)
    shifted = utils.one_hot_add(shift, inputs)
    return kept + (1. - mask) * shifted
def reverse_static(x, logits, temperature):
    """Shift one-hot ``x`` by the hard (argmax) location encoded in ``logits``.

    ``logits`` is discretized with ``utils.one_hot_argmax`` at ``temperature``
    and added to ``x`` with ``utils.one_hot_add``. No layer or instance state
    is involved.
    """
    return utils.one_hot_add(x, utils.one_hot_argmax(logits, temperature))
def reverse(self, x):
    """Scale then shift one-hot ``x`` using hard argmaxes of the learned logits.

    Computes ``one_hot_add(shift, one_hot_multiply(x, scale))``, where
    ``scale`` and ``shift`` are ``utils.one_hot_argmax`` of
    ``self.logits_scale`` and ``self.logits`` at ``self.temperature``.
    """
    hard_scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
    hard_shift = utils.one_hot_argmax(self.logits, self.temperature)
    return utils.one_hot_add(hard_shift, utils.one_hot_multiply(x, hard_scale))