def _initial_call(self, new_inputs, length, **kwargs):
  """Returns Tensor of shape [..., 1, vocab_size].

  Args:
    new_inputs: Tensor of shape [..., vocab_size], the new input to generate
      its output.
    length: Length of final desired sequence.
    **kwargs: Optional keyword arguments to layer.
  """
  inputs = new_inputs[..., tf.newaxis, :]
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs, paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
  net = self.layer(padded_inputs, **kwargs)
  if net.shape[-1] == self.subset_K:
    loc = net[..., 0:1, :]
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    # Operate on a subset of dimensions: shuffle, shift the first subset_K
    # coordinates by -loc, keep the rest unchanged, then undo the shuffle.
    x = tf.gather(inputs, self.shuffling, axis=-1)
    x1 = utils.one_hot_minus(x[..., :self.subset_K], loc)
    x2 = x[..., self.subset_K:]
    x = tf.concat([x1, x2], axis=-1)
    outputs = tf.gather(x, self.inverted_shuffling, axis=-1)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  return outputs
def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                       **kwargs):
  """Returns Tensor of shape [..., timestep+1, vocab_size].

  Args:
    current_outputs: Tensor of shape [..., timestep, vocab_size], the
      so-far generated sequence Tensor.
    new_inputs: Tensor of shape [..., vocab_size], the new input to generate
      its output given current_outputs.
    length: Length of final desired sequence.
    timestep: Current timestep.
    **kwargs: Optional keyword arguments to layer.
  """
  inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                     axis=-2)
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs,
      paddings=[[0, 0]] * batch_ndims + [[0, length - timestep - 1], [0, 0]])
  net = self.layer(padded_inputs, **kwargs)
  if net.shape[-1] == self.vocab_size:
    loc = net[..., :(timestep + 1), :]
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    new_outputs = utils.one_hot_minus(inputs, loc)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
  if not tf.executing_eagerly():
    outputs.set_shape([None] * batch_ndims + [timestep + 1, self.vocab_size])
  return outputs
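# Shape-bookkeeping sketch for _per_timestep_call above (illustrative only,
# not part of the class; assumes batch_ndims == 1): padding keeps the layer
# input at the full sequence length so a masked autoregressive network can be
# queried one timestep at a time.
import tensorflow as tf

length, vocab_size, timestep = 5, 8, 2
current_outputs = tf.zeros([4, timestep, vocab_size])
new_inputs = tf.zeros([4, vocab_size])
inputs = tf.concat([current_outputs, new_inputs[:, tf.newaxis, :]], axis=-2)
padded = tf.pad(inputs, [[0, 0], [0, length - timestep - 1], [0, 0]])
assert padded.shape == (4, length, vocab_size)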
def sample_extm(self, n=1, *args, **kwargs):
  """Draws n samples, each routed through exactly one of the B flows."""
  B = self.components.probs.shape[-2]
  dtype = self.dtype
  shift = utils.one_hot_argmax(self.logits, self.temperature)  # N x B x K
  sample = self.components.sample(n)  # n x N x B x K
  sample = utils.one_hot_minus(sample, shift)
  # Allocate samples equally among all B flows, starting at a random offset.
  selected_flows = (np.arange(n) + np.random.randint(B)) % B
  mask = tf.one_hot(selected_flows, B, dtype=dtype)
  mask = mask[:, None, :, None]  # n x 1 x B x 1
  sample = tf.reduce_sum(sample * mask, -2)  # n x N x K
  return sample, mask
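# Standalone sketch of the round-robin flow assignment used in sample_extm
# (illustrative): n samples are spread equally across B flows starting at a
# random offset, so each flow receives n // B samples whenever B divides n.
import numpy as np

n, B = 6, 3
selected_flows = (np.arange(n) + np.random.randint(B)) % B
assert all(np.bincount(selected_flows, minlength=B) == n // B)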
def test_one_hot_minus(self):
  """Tests one_hot_minus (the max-value position moves by -shift mod K)."""
  K = 8
  vals = np.array([1, 3, 7])
  shifts = np.array([7, 3, 1])
  sums = one_hot.one_hot_minus(tf.one_hot(vals, K), tf.one_hot(shifts, K))
  # Check that each row is a valid one-hot vector: exactly one 1, rest zeros.
  ones = np.ones(len(vals))
  self.assertAllEqual(tf.reduce_sum(sums, -1), ones, "row sum == 1")
  self.assertAllEqual(tf.reduce_max(sums, -1), ones,
                      "max cell value in each row == 1")
  self.assertAllEqual(tf.reduce_min(sums, -1), np.zeros(len(vals)),
                      "min cell value in each row == 0")
  # Check that the hot position moved to (vals - shifts) mod K.
  self.assertTrue((np.argmax(sums, -1) == (vals - shifts) % K).all(),
                  "correct results")
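# Reference semantics that test_one_hot_minus checks, written in plain NumPy.
# one_hot_minus_reference is a hypothetical helper, not the library
# implementation (which operates on relaxed one-hot tensors): for one-hot
# vectors encoding a and s over Z_K, the result is the one-hot of (a - s) % K.
import numpy as np

def one_hot_minus_reference(inputs, shift):
  K = inputs.shape[-1]
  a = np.argmax(inputs, axis=-1)
  s = np.argmax(shift, axis=-1)
  return np.eye(K)[(a - s) % K]

K = 8
# (1 - 7) % 8 == 2, matching the argmax check in the test above.
assert np.argmax(one_hot_minus_reference(np.eye(K)[1], np.eye(K)[7])) == 2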
def call(self, inputs, **kwargs):
  """Forward pass for bipartite generation."""
  inputs = tf.convert_to_tensor(inputs)
  batch_ndims = inputs.shape.ndims - 2
  mask = tf.reshape(tf.cast(self.mask, inputs.dtype),
                    [1] * batch_ndims + [-1, 1])
  masked_inputs = mask * inputs
  net = self.layer(masked_inputs, **kwargs)
  if net.shape[-1] == self.vocab_size:
    loc = tf.cast(utils.one_hot_argmax(net, self.temperature), inputs.dtype)
    # Masked positions pass through unchanged; their complement is shifted
    # by -loc, which is computed from the masked half only.
    masked_outputs = (1. - mask) * utils.one_hot_minus(inputs, loc)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  outputs = masked_inputs + masked_outputs
  return outputs
def _initial_call(self, new_inputs, length, **kwargs):
  """Returns Tensor of shape [..., 1, vocab_size].

  Args:
    new_inputs: Tensor of shape [..., vocab_size], the new input to generate
      its output.
    length: Length of final desired sequence.
    **kwargs: Optional keyword arguments to layer.
  """
  inputs = new_inputs[..., tf.newaxis, :]
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs, paddings=[[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]])
  net = self.layer(padded_inputs, **kwargs)
  if net.shape[-1] == self.vocab_size:
    loc = net[..., 0:1, :]
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    outputs = utils.one_hot_minus(inputs, loc)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  return outputs
def _per_timestep_call(self, current_outputs, new_inputs, length, timestep,
                       **kwargs):
  """Returns Tensor of shape [..., timestep+1, vocab_size]."""
  inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                     axis=-2)
  batch_ndims = inputs.shape.ndims - 2
  padded_inputs = tf.pad(
      inputs,
      paddings=[[0, 0]] * batch_ndims + [[0, length - timestep - 1], [0, 0]])
  net = self.layer(padded_inputs, **kwargs)
  if net.shape[-1] == self.subset_K:
    loc = net[..., :(timestep + 1), :]
    loc = tf.cast(utils.one_hot_argmax(loc, self.temperature), inputs.dtype)
    # Operate on a subset of dimensions: shuffle, shift the first subset_K
    # coordinates by -loc, keep the rest unchanged, then undo the shuffle.
    x = tf.gather(inputs, self.shuffling, axis=-1)
    x1 = utils.one_hot_minus(x[..., :self.subset_K], loc)
    x2 = x[..., self.subset_K:]
    x = tf.concat([x1, x2], axis=-1)
    new_outputs = tf.gather(x, self.inverted_shuffling, axis=-1)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
  if not tf.executing_eagerly():
    outputs.set_shape([None] * batch_ndims + [timestep + 1, self.vocab_size])
  return outputs
def call(self, x):
  """Applies the inverse affine map scale**-1 * (x - shift) over Z_K."""
  shift = utils.one_hot_argmax(self.logits, self.temperature)
  shifted_inputs = utils.one_hot_minus(x, shift)
  scale = utils.one_hot_argmax(self.logits_scale, self.temperature)
  inverse_scale = utils.multiplicative_inverse(scale, self.K)
  return utils.one_hot_multiply(shifted_inputs, inverse_scale)
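# Plain-integer sketch of the arithmetic that call above inverts
# (illustrative, assuming the forward map is y = (scale * x + shift) % K with
# gcd(scale, K) == 1): x is recovered as inverse_scale * (y - shift) % K,
# which is what the one-hot operations compute elementwise.
K, scale, shift = 5, 3, 4
inverse_scale = pow(scale, -1, K)  # modular inverse (Python 3.8+)
x = 2
y = (scale * x + shift) % K
assert (inverse_scale * (y - shift)) % K == x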
def call_static(x, logits, temperature):
  """Shifts one-hot inputs x by the (relaxed) argmax of logits."""
  shift = utils.one_hot_argmax(logits, temperature)
  return utils.one_hot_minus(x, shift)
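# Hypothetical usage sketch for call_static (assumes this repo's utils module
# is importable and that one_hot_minus broadcasts the shift over a batch):
#   x = tf.one_hot([1, 3, 7], depth=8)           # batch of one-hot inputs
#   logits = tf.random.normal([8])               # learnable shift logits
#   y = call_static(x, logits, temperature=0.1)
# At low temperature, y approaches a hard shift of each input by
# argmax(logits) mod 8.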