def _attend_over_memory(self, memory):
    """Runs `self._num_blocks` rounds of attention and MLP over `memory`.

    Each round applies multi-head attention and then an MLP. Each step is
    wrapped in a residual (skip) connection followed by a freshly
    constructed, batch-applied LayerNorm. The MLP module is built once and
    reused by every round.

    Args:
      memory: Current relational memory.

    Returns:
      The attended-over memory.
    """
    hidden_sizes = [self._mem_size] * self._attention_mlp_layers
    shared_mlp = basic.BatchApply(mlp.MLP(hidden_sizes))

    for _ in range(self._num_blocks):
        # Residual around the multi-head attention, then layer norm.
        attended = self._multihead_attention(memory)
        attention_norm = basic.BatchApply(layer_norm.LayerNorm())
        memory = attention_norm(memory + attended)

        # Residual around the shared MLP, then layer norm. The norm module is
        # constructed before the MLP is connected, as in the original ordering.
        mlp_norm = basic.BatchApply(layer_norm.LayerNorm())
        memory = mlp_norm(shared_mlp(memory) + memory)

    return memory
def _multihead_attention(self, memory):
    """Multi-head dot-product self-attention across the memory slots.

    Modelled on the attention mechanism of "Attention is All You Need"
    (https://arxiv.org/abs/1706.03762). The steps are:

    1. A batch-applied `basic.Linear` layer, followed by a LayerNorm, projects
       every slot to concatenated per-head queries, keys and values.
    2. Each head attends over all N slots (no masking).
    3. The per-head outputs are flattened back together.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size
    num_heads = self._num_heads

    # Per-head width of [query | key | value].
    head_width = 2 * key_size + value_size
    # Total projection width across all heads (denoted F).
    projected = basic.BatchApply(basic.Linear(head_width * num_heads))(memory)
    projected = basic.BatchApply(layer_norm.LayerNorm())(projected)

    num_slots = memory.get_shape().as_list()[1]  # Denoted N.

    # [B, N, F] -> [B, N, H, F/H]
    split_heads = basic.BatchReshape(
        [num_slots, num_heads, head_width])(projected)
    # [B, N, H, F/H] -> [B, H, N, F/H]
    split_heads = tf.transpose(split_heads, [0, 2, 1, 3])
    queries, keys, values = tf.split(
        split_heads, [key_size, key_size, value_size], -1)

    # NOTE(review): the paper scales by 1/sqrt(d_k), i.e. key_size ** -0.5;
    # this scales by the full per-head width (2 * key_size + value_size)
    # instead. Kept unchanged because altering it changes the outputs of
    # existing models -- confirm intent.
    queries *= head_width ** -0.5
    logits = tf.matmul(queries, keys, transpose_b=True)  # [B, H, N, N]
    attention = tf.nn.softmax(logits)

    attended = tf.matmul(attention, values)  # [B, H, N, V]
    # [B, H, N, V] -> [B, N, H, V]
    attended = tf.transpose(attended, [0, 2, 1, 3])
    # [B, N, H, V] -> [B, N, H * V]
    return basic.BatchFlatten(preserve_dims=2)(attended)
def _build(self, inputs, prev_state, is_training=True, test_local_stats=True):
    """Connects the LSTM module into the graph.

    If this is not the first time the module has been connected to the graph,
    the Tensors provided as inputs and state must have the same final
    dimension. This ensures the existing variables are the correct size for
    their corresponding multiplications. The batch size may differ for each
    connection.

    Peephole connections, when enabled, are added to the gate
    pre-activations:

    * The input and forget gates receive `w_i_diag * prev_cell` and
      `w_f_diag * prev_cell`.
    * The output gate receives `w_o_diag * new_cell` (the freshly computed
      cell state, before any cell batch norm).

    Args:
      inputs: Tensor of size `[batch_size, input_size]`.
      prev_state: Tuple (prev_hidden, prev_cell). If batch norm is enabled
        and `max_unique_stats > 1`, it is instead
        (prev_hidden, prev_cell, time_step). Here, prev_hidden and prev_cell
        are tensors of size `[batch_size, hidden_size]`, and time_step is
        used to indicate the current RNN step.
      is_training: Boolean indicating whether we are in training mode (as
        opposed to testing mode). Passed to the batch norm modules. Note: to
        use this you must wrap the cell via the `with_batch_norm_control`
        function.
      test_local_stats: Boolean indicating whether to use local batch
        statistics in test mode. See the `BatchNorm` documentation for more
        on this.

    Returns:
      A tuple (output, next_state). 'output' is a Tensor of size
      `[batch_size, hidden_size]`. 'next_state' is a tuple
      (next_hidden, next_cell), or (next_hidden, next_cell, time_step + 1),
      where next_hidden and next_cell have size `[batch_size, hidden_size]`.

    Raises:
      ValueError: If connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match
        previous invocations.
    """
    if self._max_unique_stats == 1:
        prev_hidden, prev_cell = prev_state
        time_step = None
    else:
        prev_hidden, prev_cell, time_step = prev_state

    self._create_gate_variables(inputs.get_shape(), inputs.dtype)
    self._create_batch_norm_variables(inputs.dtype)

    # pylint false positive: calling module of same file;
    # pylint: disable=not-callable

    if self._use_batch_norm_h or self._use_batch_norm_x:
        # Keep the hidden and input products separate so each can be
        # batch-normalized independently before they are summed.
        gates_h = tf.matmul(prev_hidden, self._w_h)
        gates_x = tf.matmul(inputs, self._w_x)
        if self._use_batch_norm_h:
            gates_h = self._gamma_h * self._batch_norm_h(
                gates_h, time_step, is_training, test_local_stats)
        if self._use_batch_norm_x:
            gates_x = self._gamma_x * self._batch_norm_x(
                gates_x, time_step, is_training, test_local_stats)
        gates = gates_h + gates_x
    else:
        # Parameters of gates are concatenated into one multiply for efficiency.
        inputs_and_hidden = tf.concat([inputs, prev_hidden], 1)
        gates = tf.matmul(inputs_and_hidden, self._w_xh)
        if self._use_layer_norm:
            gates = layer_norm.LayerNorm()(gates)

    gates += self._b

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = tf.split(value=gates, num_or_size_splits=4, axis=1)

    if self._use_peepholes:  # diagonal connections
        self._create_peephole_variables(inputs.dtype)
        f += self._w_f_diag * prev_cell
        i += self._w_i_diag * prev_cell

    forget_mask = tf.sigmoid(f + self._forget_bias)
    new_cell = forget_mask * prev_cell + tf.sigmoid(i) * tf.tanh(j)

    if self._use_peepholes:
        # Output-gate peephole: the output gate sees the new cell state,
        # mirroring how f and i see prev_cell above. (Adding
        # w_o_diag * cell_output to cell_output itself, as before, only
        # rescaled the cell and never fed it into the output gate.)
        o += self._w_o_diag * new_cell

    cell_output = new_cell
    if self._use_batch_norm_c:
        cell_output = (self._beta_c
                       + self._gamma_c * self._batch_norm_c(
                           cell_output, time_step, is_training,
                           test_local_stats))

    new_hidden = tf.tanh(cell_output) * tf.sigmoid(o)

    if self._max_unique_stats == 1:
        return new_hidden, (new_hidden, new_cell)
    return new_hidden, (new_hidden, new_cell, time_step + 1)
def _layer_norm(inputs):
    """Applies a newly constructed `snt_ln.LayerNorm` module to `inputs`.

    Inputs of static rank 2 or less are normalized directly. Higher-rank
    inputs are routed through `basic.BatchApply`, presumably so the extra
    leading dimensions are folded into the batch (TODO confirm).

    Args:
      inputs: Tensor to normalize; its static rank must be known.

    Returns:
      The layer-normalized tensor.
    """
    if inputs.get_shape().ndims <= 2:
        return snt_ln.LayerNorm()(inputs)
    return basic.BatchApply(snt_ln.LayerNorm())(inputs)