def _create_gates(self, inputs, memory):
    """Create input and forget gates for this step using `inputs` and `memory`.

    Args:
      inputs: Tensor input.
      memory: The current state of memory.

    Returns:
      input_gate: An LSTM-like insert gate.
      forget_gate: An LSTM-like forget gate.
    """
    # We'll create the input and forget gates at once. Hence, calculate double
    # the gate size.
    num_gates = 2 * self._calculate_gate_size()

    memory = tf.tanh(memory)
    inputs = basic.BatchFlatten()(inputs)
    gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
    gate_inputs = tf.expand_dims(gate_inputs, axis=1)
    gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
    gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
    input_gate, forget_gate = gates

    input_gate = tf.sigmoid(input_gate + self._input_bias)
    forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

    return input_gate, forget_gate
def _build(self, inputs, memory, treat_input_as_matrix=False):
    """Adds relational memory to the TensorFlow graph.

    Args:
      inputs: Tensor input.
      memory: Memory output from the previous time step.
      treat_input_as_matrix: Optional, whether to treat `inputs` as a sequence
        of matrices. Defaults to False, in which case the input is flattened
        into a vector.

    Returns:
      output: This time step's output.
      next_memory: The next version of memory to use.
    """
    if treat_input_as_matrix:
        inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
        inputs_reshape = basic.BatchApply(
            basic.Linear(self._mem_size), n_dims=2)(inputs)
    else:
        inputs = basic.BatchFlatten()(inputs)
        inputs = basic.Linear(self._mem_size)(inputs)
        inputs_reshape = tf.expand_dims(inputs, 1)

    memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
    next_memory = self._attend_over_memory(memory_plus_input)

    n = inputs_reshape.get_shape().as_list()[1]
    next_memory = next_memory[:, :-n, :]

    if self._gate_style == 'unit' or self._gate_style == 'memory':
        self._input_gate, self._forget_gate = self._create_gates(
            inputs_reshape, memory)
        next_memory = self._input_gate * tf.tanh(next_memory)
        next_memory += self._forget_gate * memory

    output = basic.BatchFlatten()(next_memory)
    return output, next_memory
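# --- Illustration (not part of the module) ---
# A minimal numpy sketch of the LSTM-style gating that `_create_gates` produces
# and `_build` applies to the attended memory above. Names and gate shapes here
# are hypothetical; in the real module the gate logits come from learned Linear
# layers over the inputs and the tanh of the previous memory.
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_memory_update(memory, attended_memory, input_gate_logits,
                        forget_gate_logits, input_bias=0.0, forget_bias=1.0):
    """memory, attended_memory: [B, N, W]; gate logits broadcastable to them."""
    input_gate = _sigmoid(input_gate_logits + input_bias)
    forget_gate = _sigmoid(forget_gate_logits + forget_bias)
    # The new memory blends the attended candidate with the previous memory.
    return input_gate * np.tanh(attended_memory) + forget_gate * memory

# Example: one batch element (B=1), two memory slots (N=2), word size W=3.
memory = np.zeros((1, 2, 3))
candidate = np.ones((1, 2, 3))
updated = gated_memory_update(memory, candidate,
                              input_gate_logits=np.zeros((1, 2, 1)),
                              forget_gate_logits=np.zeros((1, 2, 1)))
print(updated.shape)  # -> (1, 2, 3)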
def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size
    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denoted as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape(
        [mem_slots, self._num_heads, qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    q *= qkv_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
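# --- Illustration (not part of the module) ---
# A minimal numpy sketch of the multi-head scaled dot-product attention that
# `_multihead_attention` builds with Sonnet ops, following the same shape
# bookkeeping ([B, H, N, ...] -> [B, N, H * V]). The learned Linear projection,
# LayerNorm, and the module's qkv_size ** -0.5 scaling are omitted; the sketch
# scales by the key dimension instead, as in the original Transformer paper.
import numpy as np

def _softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multihead_attention_sketch(q, k, v):
    """q, k: [B, H, N, K]; v: [B, H, N, V] -> [B, N, H * V]."""
    scale = q.shape[-1] ** -0.5
    logits = np.einsum('bhik,bhjk->bhij', q * scale, k)  # [B, H, N, N]
    weights = _softmax(logits)                           # attend over slots
    out = np.einsum('bhij,bhjv->bhiv', weights, v)       # [B, H, N, V]
    b, h, n, vs = out.shape
    return out.transpose(0, 2, 1, 3).reshape(b, n, h * vs)  # [B, N, H * V]

# Example with B=2, H=3 heads, N=4 memory slots, K=V=5.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(2, 3, 4, 5)) for _ in range(3))
print(multihead_attention_sketch(q, k, v).shape)  # -> (2, 4, 15)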
def default_mlp(hidden_sizes, activate_final=False, init_std=2., **kwargs):
    """Standard batch-applied MLP for transformer modules."""
    init = {
        'w': tf.variance_scaling_initializer(init_std, distribution='normal')
    }
    mlp = snt_mlp.MLP(
        hidden_sizes,
        activate_final=activate_final,
        use_dropout=True,
        initializers=init,
        **kwargs)
    return basic.BatchApply(mlp)
def _attend_over_memory(self, memory):
    """Perform multiheaded attention over `memory`.

    Args:
      memory: Current relational memory.

    Returns:
      The attended-over memory.
    """
    attention_mlp = basic.BatchApply(
        mlp.MLP([self._mem_size] * self._attention_mlp_layers))
    for _ in range(self._num_blocks):
        attended_memory = self._multihead_attention(memory)

        # Add a skip connection to the multiheaded attention's input.
        memory = basic.BatchApply(layer_norm.LayerNorm())(
            memory + attended_memory)

        # Add a skip connection to the attention_mlp's input.
        memory = basic.BatchApply(layer_norm.LayerNorm())(
            attention_mlp(memory) + memory)

    return memory
def _build(self,
           inputs,
           state=None,
           condition=None,
           is_training=True,
           final_layer_key_value_inputs=None):
    """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional list of length num_layers of tensors of shape
        [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on. The shape is
        [batch_size, dim_size].
      is_training: If true, dropout is applied.
      final_layer_key_value_inputs: optional Tensor of shape
        [batch_size, num_steps, dim_size] to be used as the key and value for
        the final multi-head attention layer. Useful when the tower is a
        Seq2Seq decoder and it can attend to encoder outputs.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
    # inputs: [B, N, F]
    if final_layer_key_value_inputs is not None and state is not None and len(
        state) == (self._num_layers - 1):
        raise ValueError('When final_layer_key_value_inputs is set, exclude '
                         'the state of the last layer.')

    if condition is not None:
        condition_tile = tf.tile(
            tf.expand_dims(condition, 1), [1, tf.shape(inputs)[1], 1])
        inputs = tf.concat([inputs, condition_tile], -1)

    # Map inputs to be of `embedding_size` dimension.
    if inputs.get_shape().as_list()[-1] != self._embedding_size:
        inputs = default_mlp([self._embedding_size], activate_final=True)(
            inputs,
            is_training=is_training,
            dropout_keep_prob=1 - self._dropout_rate)

    if state is None:
        memory_sizes = [0]
    elif isinstance(state[0], CompressedMemoryState):
        cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
        em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
        memory_sizes = [cm_mem_size, em_mem_size]
    else:
        memory_sizes = [max([_memory_size(s) for s in state])]
    chunk_size = inputs.get_shape().as_list()[1]
    self._positional_encodings = []
    # Creates positional encodings for different memory types.
    for i, memory_size in enumerate(memory_sizes):
        seq_len = chunk_size + memory_size
        key_positions = get_position_encodings(
            sequence_length=seq_len,
            hidden_size=inputs.get_shape().as_list()[2],
            clamp_value=self._clamp_time_range,
        )
        if is_training:
            key_positions = tf.nn.dropout(key_positions,
                                          rate=self._dropout_rate)
        key_positions = tf.cast(key_positions, dtype=inputs.dtype)
        query_positions = key_positions[:, -chunk_size:, :]
        self._positional_encodings.append((key_positions, query_positions))

    if self._causal:
        self._mask = create_mask(inputs, state, self._same_attention_length)

    layer_i_inputs = inputs
    attention_states = []
    key_value_inputs = None
    for i in range(self._num_layers):
        with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
            multihead_attention, object_mlp = self.get_sublayers(is_training)

            # Multihead attention with residuals.
            state_i = None if state is None else state[i]
            if (i == self._num_layers - 1
                    and final_layer_key_value_inputs is not None):
                # When final_layer_key_value_inputs is set, the final layer of
                # attention uses it as the key & value, thus no need for state.
                key_value_inputs = final_layer_key_value_inputs
                state_i = None

            attention_outputs, attention_state = multihead_attention(
                layer_i_inputs,
                state=state_i,
                is_training=is_training,
                dropout_keep_prob=1. - self._dropout_rate,
                key_value_inputs=key_value_inputs)
            attention_states.append(attention_state)

            # Feed-forward with residuals.
            output = object_mlp(
                attention_outputs,
                is_training=is_training,
                dropout_keep_prob=1 - self._dropout_rate)
            layer_i_inputs = output

    if self._output_size is not None:
        output = basic.BatchApply(
            basic.Linear(self._output_size, use_bias=False))(output)

    return output, attention_states
def _build(self,
           inputs,
           query_inputs=None,
           state=None,
           is_training=False,
           dropout_keep_prob=0.5,
           key_value_inputs=None):
    """Calculates multi-layer self attention.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs
        used as the query, key, and value to the attention layer.
      query_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. Query inputs to the attention layer. Set when
        query_inputs is different from the inputs argument.
      state: optional CompressedMemoryState or a Tensor of shape
        [batch_size, memory_size, dim_size] concatenated to the inputs. Set
        when attending to the memory from previous steps.
      is_training: if currently training.
      dropout_keep_prob: the probability of keeping an attention weight when
        dropout is applied.
      key_value_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. It is used as the key and value of the multihead
        attention. Set when the key and value are different from the inputs
        argument.

    Returns:
      output: the result Tensor of shape
        [batch_size, num_steps, output_dim_size].
      attention_state: named tuple of AttentionState.
    """
    if key_value_inputs is not None and state is not None:
        raise ValueError('Only one of key_value_inputs and state is needed.')
    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if key_value_inputs is not None:
        k_inputs = key_value_inputs
        v_inputs = k_inputs
    elif state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [state.compressed_memory,
                                 state.episodic_memory]
        else:
            state_memory_list = [state]

        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B
    batch_size = tf.shape(inputs)[0]
    # Chunk_size denoted by N
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        if len(self._positional_encodings) != 1:
            raise ValueError(
                'Absolute positional encodings only supported for 1 memory. '
                'Found %i.' % len(self._positional_encodings))
        key_positions, query_positions = self._positional_encodings[0]
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable(
            'r_w_bias', [1, self._num_heads, 1, self._key_size],
            dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(self._positional_encodings):
            key_positions, query_positions = positional_encodings
            if key_positions.get_shape().as_list()[-1] != att_size:
                # Crop to layer mem size.
                key_positions = key_positions[:, -att_size:]
            is_final = i == len(self._positional_encodings) - 1
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(
                key_positions, name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix, [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(
                q + r_r_bias, relative_keys, transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:
                # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)

    # [B, L, H, V], where V is value_size
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape(
        [query_size, embedding_size])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(
        queries=q,
        keys=k,
        values=v,
        weights=weights,
        logits=content_logits,
        embeddings=inputs,
        read_words=output)
    return output, attention_state
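# --- Illustration (not part of the module) ---
# `rel_shift` is defined elsewhere in this codebase. As an assumption, the
# numpy sketch below shows the standard Transformer-XL relative-shift trick
# (pad a zero column, reshape, slice) that such a function typically
# implements; the actual implementation here may differ in details.
import numpy as np

def rel_shift_sketch(position_logits):
    """position_logits: [B, H, Q, K] -> same shape, rows realigned so column j
    corresponds to the intended relative distance for each query row."""
    b, h, q, k = position_logits.shape
    # Prepend a zero column on the last axis, then reshape to shift each row.
    padded = np.concatenate(
        [np.zeros_like(position_logits[..., :1]), position_logits], axis=-1)
    reshaped = padded.reshape(b, h, k + 1, q)
    return reshaped[:, :, 1:, :].reshape(b, h, q, k)

x = np.arange(6, dtype=float).reshape(1, 1, 2, 3)
print(rel_shift_sketch(x)[0, 0])
# [[1. 2. 0.]
#  [3. 4. 5.]]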
def _layer_norm(inputs):
    if inputs.get_shape().ndims > 2:
        return basic.BatchApply(snt_ln.LayerNorm())(inputs)
    else:
        return snt_ln.LayerNorm()(inputs)
def _build(self, memory, query, memory_mask=None):
    """Perform a differentiable read.

    Args:
      memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of
        dtype float32. This represents, for each example and memory slot, a
        single embedding to attend over.
      query: [batch_size, query_word_size]-shaped Tensor of dtype float32.
        Represents, for each example, a single embedding representing a query.
      memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype
        bool. An entry of False indicates that a memory slot should not enter
        the resulting weighted sum. If None, all memory is used.

    Returns:
      An AttentionOutput instance containing:
        read: [batch_size, memory_word_size]-shaped Tensor of dtype float32.
          This represents, for each example, a weighted sum of the contents of
          the memory.
        weights: [batch_size, memory_size]-shaped Tensor of dtype float32.
          This represents, for each example and memory slot, the attention
          weights used to compute the read.
        weight_logits: [batch_size, memory_size]-shaped Tensor of dtype
          float32. This represents, for each example and memory slot, the
          logits of the attention weights, that is, `weights` is calculated by
          taking the softmax of the weight logits.

    Raises:
      UnderspecifiedError: if memory_word_size or query_word_size cannot be
        inferred.
      IncompatibleShapeError: if memory, query, memory_mask, or output of
        attention_logit_mod do not match expected shapes.
    """
    if len(memory.get_shape()) != 3:
        raise base.IncompatibleShapeError(
            "memory must have shape "
            "[batch_size, memory_size, memory_word_size].")
    if len(query.get_shape()) != 2:
        raise base.IncompatibleShapeError(
            "query must have shape [batch_size, query_word_size].")
    if memory_mask is not None and len(memory_mask.get_shape()) != 2:
        raise base.IncompatibleShapeError(
            "memory_mask must have shape [batch_size, memory_size].")

    # Ensure final dimensions are defined, else the attention logit module will
    # be unable to infer input size when constructing variables.
    inferred_memory_word_size = memory.get_shape()[2].value
    inferred_query_word_size = query.get_shape()[1].value
    if inferred_memory_word_size is None or inferred_query_word_size is None:
        raise base.UnderspecifiedError(
            "memory_word_size and query_word_size must be known at graph "
            "construction time.")

    memory_shape = tf.shape(memory)
    batch_size = memory_shape[0]
    memory_size = memory_shape[1]

    query_shape = tf.shape(query)
    query_batch_size = query_shape[0]

    # Transform query to have same number of words as memory.
    #
    # expanded_query: [batch_size, memory_size, query_word_size].
    expanded_query = tf.tile(
        tf.expand_dims(query, axis=1), [1, memory_size, 1])

    # Compute attention weights for each memory slot.
    #
    # attention_weight_logits: [batch_size, memory_size]
    with tf.control_dependencies(
        [tf.assert_equal(batch_size, query_batch_size)]):
        concatenated_embeddings = tf.concat(
            values=[memory, expanded_query], axis=2)

        batch_apply_attention_logit = basic.BatchApply(
            self._attention_logit_mod,
            n_dims=2,
            name="batch_apply_attention_logit")
        attention_weight_logits = batch_apply_attention_logit(
            concatenated_embeddings)

    # Note: basic.BatchApply() will automatically reshape the [batch_size *
    # memory_size, 1]-shaped result of self._attention_logit_mod(...) into a
    # [batch_size, memory_size, 1]-shaped Tensor. If
    # self._attention_logit_mod(...) returns something with more dimensions,
    # then attention_weight_logits will have extra dimensions, too.
    if len(attention_weight_logits.get_shape()) != 3:
        raise base.IncompatibleShapeError(
            "attention_weight_logits must be a rank-3 Tensor. Are you sure "
            "that attention_logit_mod() returned a "
            "[batch_size * memory_size, 1]-shaped Tensor?")

    # Remove final length-1 dimension.
    attention_weight_logits = tf.squeeze(attention_weight_logits, [2])

    # Mask out ignored memory slots by assigning them very small logits.
    # Ensures that every example has at least one valid memory slot, else
    # we'd end up averaging all memory slots equally.
    if memory_mask is not None:
        num_remaining_memory_slots = tf.reduce_sum(
            tf.cast(memory_mask, dtype=tf.int32), axis=[1])
        with tf.control_dependencies(
            [tf.assert_positive(num_remaining_memory_slots)]):
            finfo = np.finfo(np.float32)
            kept_indices = tf.cast(memory_mask, dtype=tf.float32)
            ignored_indices = tf.cast(
                tf.logical_not(memory_mask), dtype=tf.float32)
            lower_bound = (finfo.max * kept_indices +
                           finfo.min * ignored_indices)
            attention_weight_logits = tf.minimum(attention_weight_logits,
                                                 lower_bound)

    # attended_memory: [batch_size, memory_word_size].
    attention_weight = tf.reshape(
        tf.nn.softmax(attention_weight_logits),
        shape=[batch_size, memory_size, 1])
    # The multiplication is elementwise and relies on broadcasting the weights
    # across memory_word_size. Then we sum across the memory slots.
    attended_memory = tf.reduce_sum(memory * attention_weight, axis=[1])

    # Infer shape of result as much as possible.
    inferred_batch_size, _, inferred_memory_word_size = (
        memory.get_shape().as_list())
    attended_memory.set_shape(
        [inferred_batch_size, inferred_memory_word_size])

    return AttentionOutput(
        read=attended_memory,
        weights=tf.squeeze(attention_weight, [2]),
        weight_logits=attention_weight_logits)
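# --- Illustration (not part of the module) ---
# A minimal numpy sketch of the masked, weighted read computed above: masked
# slots are clamped to a very negative logit so the softmax effectively ignores
# them, then the memory is summed with the resulting weights.
import numpy as np

def _softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_read_sketch(memory, weight_logits, memory_mask=None):
    """memory: [B, M, W]; weight_logits: [B, M]; memory_mask: [B, M] bool."""
    if memory_mask is not None:
        finfo = np.finfo(np.float32)
        lower_bound = np.where(memory_mask, finfo.max, finfo.min)
        weight_logits = np.minimum(weight_logits, lower_bound)
    weights = _softmax(weight_logits)                   # [B, M]
    read = (memory * weights[..., None]).sum(axis=1)    # [B, W]
    return read, weights

memory = np.arange(6, dtype=np.float32).reshape(1, 3, 2)
logits = np.zeros((1, 3), dtype=np.float32)
mask = np.array([[True, True, False]])
read, weights = attention_read_sketch(memory, logits, mask)
print(weights)  # the masked slot receives ~0 weight
print(read)     # weighted sum over the two unmasked slots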
def _build(self, inputs, state=None, condition=None, is_training=True):
    """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional tensor of shape [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on. The shape is
        [batch_size, dim_size].
      is_training: If true, dropout is applied.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
    # inputs: [B, N, F]
    if condition is not None:
        condition_tile = tf.tile(
            tf.expand_dims(condition, 1), [1, tf.shape(inputs)[1], 1])
        inputs = tf.concat([inputs, condition_tile], -1)

    if state is None:
        memory_sizes = [0]
    elif isinstance(state[0], CompressedMemoryState):
        cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
        em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
        memory_sizes = [cm_mem_size, em_mem_size]
    else:
        memory_sizes = [max([_memory_size(s) for s in state])]
    chunk_size = inputs.get_shape().as_list()[1]
    self._positional_encodings = []
    # Creates positional encodings for different memory types.
    for i, memory_size in enumerate(memory_sizes):
        seq_len = chunk_size + memory_size
        key_positions = get_position_encodings(
            sequence_length=seq_len,
            hidden_size=inputs.get_shape().as_list()[2],
            clamp_value=self._clamp_time_range,
        )
        if is_training:
            key_positions = tf.nn.dropout(key_positions,
                                          rate=self._dropout_rate)
        key_positions = tf.cast(key_positions, dtype=inputs.dtype)
        query_positions = key_positions[:, -chunk_size:, :]
        self._positional_encodings.append((key_positions, query_positions))

    if self._causal:
        self._mask = create_mask(inputs, state, self._same_attention_length)

    layer_i_inputs = inputs
    attention_states = []
    for i in range(self._num_layers):
        with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
            multihead_attention, object_mlp = self.get_sublayers(is_training)

            # Multihead attention with residuals.
            state_i = None if state is None else state[i]
            attention_outputs, attention_state = multihead_attention(
                layer_i_inputs,
                state=state_i,
                is_training=is_training,
                dropout_keep_prob=1. - self._dropout_rate)
            attention_states.append(attention_state)

            # Feed-forward with residuals.
            output = object_mlp(
                attention_outputs,
                is_training=is_training,
                dropout_keep_prob=1 - self._dropout_rate)
            layer_i_inputs = output

    if self._output_size is not None:
        output = basic.BatchApply(
            basic.Linear(self._output_size, use_bias=False))(output)

    return output, attention_states
def _build(self,
           inputs,
           query_inputs=None,
           state=None,
           is_training=False,
           dropout_keep_prob=0.5):
    """Calculates multihead self attention; see the documented variant above."""
    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [state.compressed_memory,
                                 state.episodic_memory]
        else:
            state_memory_list = [state]

        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B
    batch_size = tf.shape(inputs)[0]
    # Chunk_size denoted by N
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        key_positions, query_positions = self._positional_encodings
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable(
            'r_w_bias', [1, self._num_heads, 1, self._key_size],
            dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(self._positional_encodings):
            key_positions, query_positions = positional_encodings
            if key_positions.get_shape().as_list()[-1] != att_size:
                # Crop to layer mem size.
                key_positions = key_positions[:, -att_size:]
            is_final = i == len(self._positional_encodings) - 1
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(
                key_positions, name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix, [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(
                q + r_r_bias, relative_keys, transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:
                # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)

    # [B, L, H, V], where V is value_size
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape(
        [query_size, embedding_size])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(
        queries=q,
        keys=k,
        values=v,
        weights=weights,
        logits=content_logits,
        embeddings=inputs,
        read_words=output)
    return output, attention_state