def call(self, inputs, **kwargs):
  """Runs the bipartite flow's forward transformation.

  Positions selected by `self.mask` pass through unchanged; the complementary
  positions are transformed conditioned on the masked half via the coupling
  layer's predicted shift (and optionally scale).
  """
  inputs = tf.convert_to_tensor(inputs)
  batch_ndims = inputs.shape.ndims - 2
  mask_shape = [1] * batch_ndims + [-1, 1]
  mask = tf.reshape(tf.cast(self.mask, inputs.dtype), mask_shape)
  masked_inputs = mask * inputs
  net = self.layer(masked_inputs, **kwargs)
  if net.shape[-1] == 2 * self.vocab_size:
    # Shift-and-scale transform: the layer predicts both components.
    shift, scale = tf.split(net, 2, axis=-1)
    shift = tf.cast(utils.one_hot_argmax(shift, self.temperature),
                    inputs.dtype)
    scale = tf.cast(utils.one_hot_argmax(scale, self.temperature),
                    inputs.dtype)
    inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
    shifted = utils.one_hot_minus(inputs, shift)
    masked_outputs = (1. - mask) * utils.one_hot_multiply(shifted,
                                                          inverse_scale)
  elif net.shape[-1] == self.vocab_size:
    # Shift-only transform.
    shift = tf.cast(utils.one_hot_argmax(net, self.temperature),
                    inputs.dtype)
    masked_outputs = (1. - mask) * utils.one_hot_minus(inputs, shift)
  else:
    raise ValueError(
        'Output of layer does not have compatible dimensions.')
  # Recombine the untouched half with the transformed half.
  return masked_inputs + masked_outputs
def _initial_call(self, new_inputs, length, **kwargs):
  """Computes the first timestep's output.

  Args:
    new_inputs: Tensor of shape [..., vocab_size], the new input whose
      output is to be generated.
    length: Length of the final desired sequence.
    **kwargs: Optional keyword arguments forwarded to the layer.

  Returns:
    Tensor of shape [..., 1, vocab_size].
  """
  inputs = new_inputs[..., tf.newaxis, :]
  # TODO(trandustin): To handle variable lengths, extend MADE to subset its
  # input and output layer weights rather than pad inputs.
  batch_ndims = inputs.shape.ndims - 2
  paddings = [[0, 0]] * batch_ndims + [[0, length - 1], [0, 0]]
  net = self.layer(tf.pad(inputs, paddings=paddings), **kwargs)
  if net.shape[-1] == 2 * self.vocab_size:
    # Shift-and-scale transform; only the first timestep's params are used.
    shift, scale = tf.split(net, 2, axis=-1)
    shift = tf.cast(
        utils.one_hot_argmax(shift[..., 0:1, :], self.temperature),
        inputs.dtype)
    scale = tf.cast(
        utils.one_hot_argmax(scale[..., 0:1, :], self.temperature),
        inputs.dtype)
    inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
    outputs = utils.one_hot_multiply(
        utils.one_hot_minus(inputs, shift), inverse_scale)
  elif net.shape[-1] == self.vocab_size:
    # Shift-only transform.
    shift = tf.cast(
        utils.one_hot_argmax(net[..., 0:1, :], self.temperature),
        inputs.dtype)
    outputs = utils.one_hot_minus(inputs, shift)
  else:
    raise ValueError('Output of layer does not have compatible dimensions.')
  return outputs
def _per_timestep_call(self,
                       current_outputs,
                       new_inputs,
                       length,
                       timestep,
                       **kwargs):
  """Appends one generated timestep to the running output sequence.

  Args:
    current_outputs: Tensor of shape [..., timestep, vocab_size], the
      sequence generated so far.
    new_inputs: Tensor of shape [..., vocab_size], the new input whose
      output is generated given `current_outputs`.
    length: Length of the final desired sequence.
    timestep: Current timestep.
    **kwargs: Optional keyword arguments forwarded to the layer.

  Returns:
    Tensor of shape [..., timestep+1, vocab_size].
  """
  inputs = tf.concat([current_outputs, new_inputs[..., tf.newaxis, :]],
                     axis=-2)
  # TODO(trandustin): To handle variable lengths, extend MADE to subset its
  # input and output layer weights rather than pad inputs.
  batch_ndims = inputs.shape.ndims - 2
  paddings = ([[0, 0]] * batch_ndims
              + [[0, length - timestep - 1], [0, 0]])
  net = self.layer(tf.pad(inputs, paddings=paddings), **kwargs)
  if net.shape[-1] == 2 * self.vocab_size:
    # Shift-and-scale transform over the first timestep+1 positions.
    shift, scale = tf.split(net, 2, axis=-1)
    shift = tf.cast(
        utils.one_hot_argmax(shift[..., :(timestep + 1), :],
                             self.temperature),
        inputs.dtype)
    scale = tf.cast(
        utils.one_hot_argmax(scale[..., :(timestep + 1), :],
                             self.temperature),
        inputs.dtype)
    inverse_scale = utils.multiplicative_inverse(scale, self.vocab_size)
    new_outputs = utils.one_hot_multiply(
        utils.one_hot_minus(inputs, shift), inverse_scale)
  elif net.shape[-1] == self.vocab_size:
    # Shift-only transform over the first timestep+1 positions.
    shift = tf.cast(
        utils.one_hot_argmax(net[..., :(timestep + 1), :],
                             self.temperature),
        inputs.dtype)
    new_outputs = utils.one_hot_minus(inputs, shift)
  else:
    raise ValueError(
        'Output of layer does not have compatible dimensions.')
  # Keep everything generated so far and append only the newest position.
  outputs = tf.concat([current_outputs, new_outputs[..., -1:, :]], axis=-2)
  if not tf.executing_eagerly():
    # In graph mode the concat loses the static sequence length; restore it.
    outputs.set_shape([None] * batch_ndims + [timestep + 1, self.vocab_size])
  return outputs