def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    A single linear layer followed by layer normalisation produces the
    concatenated queries, keys and values for every head. The result is split
    per head, scaled dot-product attention is applied, and the heads are
    flattened back together.

    Args:
      memory: Memory tensor to perform attention on. Its static second
        dimension is read as the number of memory slots.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size

    # Per-head width of the concatenated [query, key, value] projection.
    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denote as F.
    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape(
        [mem_slots, self._num_heads, qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    # Scaled dot-product attention divides the logits by sqrt(d_k), where d_k
    # is the size of a single key vector. The dot product below only sums over
    # `key_size` elements, so that is the right scale -- not the width of the
    # concatenated q/k/v projection.
    q *= key_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
def _affine_grid_warper_inverse(inputs):
    """Assembles network to compute inverse affine transformation.

    Each `inputs` row potentially contains [a, b, tx, c, d, ty]
    corresponding to an affine matrix:

      A = [a, b, tx],
          [c, d, ty]

    We want to generate a tensor containing the coefficients of the
    corresponding inverse affine transformation in a constraints-aware
    fashion.
    Calling M:

      M = [a, b]
          [c, d]

    the affine matrix for the inverse transform is:

       A_in = [M^(-1), M^(-1) * [-tx, -ty]^T]

    where

      M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                [-c,  a]

    This function reads `self.constraints`, `self.output_shape` and
    `self.source_shape`; `self` is not a parameter, so it must be provided by
    the enclosing scope.

    Args:
      inputs: Tensor containing a batch of transformation parameters. Only
        the unconstrained coefficients are present, as consecutive columns.

    Returns:
      A tensorflow graph performing the inverse affine transformation
      parametrized by the input coefficients.
    """
    batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
    constant_shape = tf.concat([batch_size, tf.convert_to_tensor((1,))], 0)

    # Running column index into `inputs`; advanced once per free coefficient.
    index = iter(range(6))

    def get_variable(constraint):
        """Returns a [batch, 1] column for one affine coefficient."""
        if constraint is None:
            # Free coefficient: consume the next column of `inputs`.
            # The builtin next() works on both Python 2 and Python 3, whereas
            # the iterator method `.next()` only exists on Python 2.
            i = next(index)
            return inputs[:, i:i+1]
        else:
            # Constrained coefficient: broadcast the fixed value over the batch.
            return tf.fill(constant_shape, tf.constant(constraint,
                                                       dtype=inputs.dtype))

    constraints = chain.from_iterable(self.constraints)
    a, b, tx, c, d, ty = (get_variable(constr) for constr in constraints)

    # Closed-form inverse of the 2x2 linear part.
    det = a * d - b * c
    a_inv = d / det
    b_inv = -b / det
    c_inv = -c / det
    d_inv = a / det

    m_inv = basic.BatchReshape(
        [2, 2])(tf.concat([a_inv, b_inv, c_inv, d_inv], 1))

    # Inverse translation is -M^(-1) * [tx, ty]^T; the sign is applied below.
    txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)

    txy_inv = basic.BatchFlatten()(tf.matmul(m_inv, txy))
    tx_inv = txy_inv[:, 0:1]
    ty_inv = txy_inv[:, 1:2]

    inverse_gw_inputs = tf.concat(
        [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

    # The inverse warper maps in the opposite direction, hence the swapped
    # source/output shapes.
    agw = AffineGridWarper(self.output_shape, self.source_shape)

    return agw(inverse_gw_inputs)  # pylint: disable=not-callable
def _build(self, inputs):
    """Connects the grid warper to a batch of transformation parameters.

    For every output dimension the warped coordinate is assembled from up to
    three precomputed pieces stored in `self._psi`: a matrix multiplied with
    the free parameters, a constant contribution of the constrained
    parameters, and a final additive term.

    Args:
      inputs: Tensor containing a batch of transformation parameters.

    Returns:
      A batch of warped grids.

    Raises:
      Error: If the input tensor size is not consistent with the constraints
        passed at construction time.
    """
    np_dtype = inputs.dtype.as_numpy_dtype
    batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)

    provided_params = inputs.get_shape()[1]
    expected_params = self._constraints.num_free_params
    if provided_params != expected_params:
        raise base.Error('Input size is not consistent with constraint '
                         'definition: {} parameters expected, {} provided.'
                         .format(self._constraints.num_free_params,
                                 provided_params))

    # `self._psi` holds three consecutive groups of one entry per dimension.
    num_dims = len(self._psi) // 3
    number_of_points = np.prod(self._output_shape)

    def tile_over_batch(constant):
        """Repeats a precomputed numpy array along new batch/row axes."""
        multiples = tf.concat(
            [batch_size,
             tf.constant(1, shape=(1,)),
             tf.ones_like(constant.shape)],
            0)
        expanded = constant.reshape((1, 1) + constant.shape)
        return tf.tile(expanded, multiples)

    warped_grid = []
    next_free_column = 0
    for i in xrange(num_dims):
        if self._psi[i] is None:
            # Every coefficient of this dimension is constrained, so the whole
            # product was precomputed and only needs replicating per example.
            warped_coord = tile_over_batch(
                self._psi[num_dims + i].astype(np_dtype))
        else:
            # Some coefficients are free: multiply them with the grid features
            # in batch mode.
            grid_coord = self._psi[i].astype(np_dtype)
            num_active_vars = self._psi[i].shape[0]
            active_vars = basic.SliceByDim(
                [1], [next_free_column], [num_active_vars])(inputs)
            warped_coord = tf.expand_dims(
                tf.matmul(active_vars, grid_coord), 1)
            next_free_column += num_active_vars

            offset = self._psi[num_dims + i]
            if offset is not None:
                # Contribution of the constrained coefficients of this row,
                # precomputed at construction time.
                warped_coord += tile_over_batch(offset.astype(np_dtype))

        warped_coord += self._psi[i + 2 * num_dims]
        # The tiling multiples are only known at run time, so the static shape
        # has to be stated explicitly for shape inference.
        warped_coord.set_shape([None, 1, number_of_points])
        warped_grid.append(warped_coord)

    # Give each coordinate tensor the requested output shape plus a trailing
    # singleton axis, then join them along that axis.
    grid_shape = self._output_shape + (1,)
    reshaped = [basic.BatchReshape(grid_shape)(coord) for coord in warped_grid]
    return tf.concat(reshaped, len(grid_shape))
def _build(self, inputs, query_inputs=None, state=None, is_training=False, dropout_keep_prob=0.5, key_value_inputs=None):
    """Calculates multi-layer self attention.

    Queries come from `query_inputs` if given, otherwise from `inputs`. Keys
    and values come from `key_value_inputs` if given, otherwise from `inputs`
    with any `state` memories concatenated in front along axis 1.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs
        used as the query, key, and value to the attention layer.
      query_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. Query inputs to the attention layer. Set when
        query_inputs is different from the inputs argument.
      state: optional CompressedMemoryState or a Tensor of shape [batch_size,
        memory_size, dim_size] concatenated to the inputs. Set when attend to
        the memory from previous steps.
      is_training: if currently training. Dropout on the attention weights is
        only applied when this is true.
      dropout_keep_prob: value passed as the second argument of
        `tf.nn.dropout` on the attention weights (a keep probability going by
        its name, not a drop rate).
      key_value_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. It is used as the key and value of the multihead
        attention. Set when the key and value are different from the inputs
        argument.

    Returns:
      output: the result Tensor of shape
        [batch_size, num_steps, output_dim_size].
      attention_state: named tuple of AttentionState.

    Raises:
      ValueError: if both `key_value_inputs` and `state` are given, or if
        absolute positional encodings are used with a number of encodings
        other than one.
    """
    if key_value_inputs is not None and state is not None:
        raise ValueError(
            'Only one of the key_value_input and state is needed.')
    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if key_value_inputs is not None:
        k_inputs = key_value_inputs
        v_inputs = k_inputs
    elif state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [
                state.compressed_memory, state.episodic_memory
            ]
        else:
            state_memory_list = [state]

        # Memories are placed before the current inputs along axis 1.
        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B
    batch_size = tf.shape(inputs)[0]
    # Chunk_size denoted by N
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        if len(self._positional_encodings) != 1:
            raise ValueError(
                'Absolute positional encodings only supported for 1 memory. '
                'Found %i.' % len(self._positional_encodings))
        key_positions, query_positions = self._positional_encodings[0]
        # Absolute encodings are added to keys and queries only; `v_inputs`
        # keeps referring to the tensor without positions.
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable('r_w_bias',
                                   [1, self._num_heads, 1, self._key_size],
                                   dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(
                self._positional_encodings):
            # Only the key encodings are used in this branch.
            key_positions, query_positions = positional_encodings
            # NOTE(review): the check reads the last static dimension while the
            # crop slices axis 1; these coincide only for rank-2 encodings --
            # confirm the intended axis.
            if key_positions.get_shape().as_list()[-1] != att_size:
                key_positions = key_positions[:, -att_size:]  # Crop to layer mem size
            is_final = i == len(self._positional_encodings) - 1
            # The final encoding gets the unsuffixed variable names.
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(key_positions,
                                                  name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix,
                [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(q + r_r_bias,
                                        relative_keys,
                                        transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:  # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        # Keep only the trailing `att_size` columns when the stored mask is
        # wider or narrower than the attended sequence.
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        # The mask is additive; `content_logits` stays unmasked because this
        # rebinds `logits` to a new tensor.
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)
    # [B, L, H, V], where V is value_size
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape([query_size, embedding_size
                                          ])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(
        basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(queries=q,
                                     keys=k,
                                     values=v,
                                     weights=weights,
                                     logits=content_logits,
                                     embeddings=inputs,
                                     read_words=output)
    return output, attention_state
def _build(self, inputs, query_inputs=None, state=None, is_training=False, dropout_keep_prob=0.5):
    """Calculates multi-head attention over the inputs and optional memory.

    Queries come from `query_inputs` if given, otherwise from `inputs`. Keys
    and values come from `inputs`, with any `state` memories concatenated in
    front along axis 1.

    Args:
      inputs: Tensor used for keys and values, and for queries when
        `query_inputs` is None. Axis 0 is read as the batch and axis 1 as the
        sequence (chunk) dimension.
      query_inputs: optional Tensor used for the queries instead of `inputs`.
      state: optional CompressedMemoryState, or a Tensor, that is concatenated
        in front of `inputs` along axis 1 to form the keys and values.
      is_training: if true, dropout is applied to the attention weights.
      dropout_keep_prob: value passed as the second argument of
        `tf.nn.dropout` on the attention weights (a keep probability going by
        its name, not a drop rate).

    Returns:
      output: Tensor with the attended values after the final linear layer,
        reshaped to [batch, query_size, value_size * num_heads].
      attention_state: AttentionState holding queries, keys, values, weights,
        the unmasked content logits, the inputs and the output.

    Raises:
      ValueError: if absolute positional encodings are given as a sequence of
        (key, query) pairs that does not contain exactly one pair.
    """
    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [
                state.compressed_memory, state.episodic_memory
            ]
        else:
            state_memory_list = [state]

        # Memories are placed before the current inputs along axis 1.
        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B
    batch_size = tf.shape(inputs)[0]
    # Chunk_size denoted by N
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        absolute_encodings = self._positional_encodings
        # The relative-position branch below iterates this attribute as a
        # sequence of (key, query) pairs. Accept that layout here as well, so
        # a one-pair sequence no longer fails to unpack; a bare (key, query)
        # pair is still unpacked directly as before.
        if isinstance(absolute_encodings[0], (list, tuple)):
            if len(absolute_encodings) != 1:
                raise ValueError(
                    'Absolute positional encodings only supported for 1 '
                    'memory. Found %i.' % len(absolute_encodings))
            absolute_encodings = absolute_encodings[0]
        key_positions, query_positions = absolute_encodings
        # Absolute encodings are added to keys and queries only; `v_inputs`
        # keeps referring to the tensor without positions.
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable('r_w_bias',
                                   [1, self._num_heads, 1, self._key_size],
                                   dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(
                self._positional_encodings):
            # Only the key encodings are used in this branch.
            key_positions, query_positions = positional_encodings
            # NOTE(review): the check reads the last static dimension while the
            # crop slices axis 1; these coincide only for rank-2 encodings --
            # confirm the intended axis.
            if key_positions.get_shape().as_list()[-1] != att_size:
                key_positions = key_positions[:, -att_size:]  # Crop to layer mem size
            is_final = i == len(self._positional_encodings) - 1
            # The final encoding gets the unsuffixed variable names.
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(key_positions,
                                                  name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix,
                [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(q + r_r_bias,
                                        relative_keys,
                                        transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:  # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        # Keep only the trailing `att_size` columns when the stored mask does
        # not match the attended sequence length.
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        # The mask is additive; `content_logits` stays unmasked because this
        # rebinds `logits` to a new tensor.
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)
    # [B, L, H, V], where V is value_size
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape([query_size, embedding_size
                                          ])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(
        basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(queries=q,
                                     keys=k,
                                     values=v,
                                     weights=weights,
                                     logits=content_logits,
                                     embeddings=inputs,
                                     read_words=output)
    return output, attention_state