def _create_gates(self, inputs, memory):
    """Create input and forget gates for this step using `inputs` and `memory`.

    Args:
      inputs: Tensor input.
      memory: The current state of memory.

    Returns:
      input_gate: An LSTM-like insert gate.
      forget_gate: An LSTM-like forget gate.
    """
    # We'll create the input and forget gates at once. Hence, calculate double
    # the gate size.
    num_gates = 2 * self._calculate_gate_size()

    memory = tf.tanh(memory)
    inputs = basic.BatchFlatten()(inputs)
    gate_inputs = basic.BatchApply(basic.Linear(num_gates), n_dims=1)(inputs)
    gate_inputs = tf.expand_dims(gate_inputs, axis=1)
    gate_memory = basic.BatchApply(basic.Linear(num_gates))(memory)
    gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
    input_gate, forget_gate = gates

    input_gate = tf.sigmoid(input_gate + self._input_bias)
    forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

    return input_gate, forget_gate
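# --- Illustrative sketch (not part of the original module) ---
# A minimal numpy illustration of the LSTM-like gating computed above, with
# toy shapes [batch=2, mem_slots=3, mem_size=4]. The bias values below are
# assumptions (typical defaults), not taken from this file. The gates blend a
# tanh-squashed candidate memory with the previous memory, as done in `_build`
# when `gate_style` is 'unit' or 'memory'.
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

prev_memory = np.random.randn(2, 3, 4)
candidate_memory = np.random.randn(2, 3, 4)      # output of attending over memory
input_gate_logits = np.random.randn(2, 3, 4)     # first half of the gate split
forget_gate_logits = np.random.randn(2, 3, 4)    # second half of the gate split

input_gate = _sigmoid(input_gate_logits + 0.0)   # + self._input_bias (assumed 0.0)
forget_gate = _sigmoid(forget_gate_logits + 1.0) # + self._forget_bias (assumed 1.0)
next_memory = input_gate * np.tanh(candidate_memory) + forget_gate * prev_memory
assert next_memory.shape == (2, 3, 4)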
def _build(self, inputs, memory, treat_input_as_matrix=False):
    """Adds relational memory to the TensorFlow graph.

    Args:
      inputs: Tensor input.
      memory: Memory output from the previous time step.
      treat_input_as_matrix: Optional, whether to treat `inputs` as a sequence
        of matrices. Defaults to False, in which case the input is flattened
        into a vector.

    Returns:
      output: This time step's output.
      next_memory: The next version of memory to use.
    """
    if treat_input_as_matrix:
        inputs = basic.BatchFlatten(preserve_dims=2)(inputs)
        inputs_reshape = basic.BatchApply(
            basic.Linear(self._mem_size), n_dims=2)(inputs)
    else:
        inputs = basic.BatchFlatten()(inputs)
        inputs = basic.Linear(self._mem_size)(inputs)
        inputs_reshape = tf.expand_dims(inputs, 1)

    memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)
    next_memory = self._attend_over_memory(memory_plus_input)

    n = inputs_reshape.get_shape().as_list()[1]
    next_memory = next_memory[:, :-n, :]

    if self._gate_style == 'unit' or self._gate_style == 'memory':
        self._input_gate, self._forget_gate = self._create_gates(
            inputs_reshape, memory)
        next_memory = self._input_gate * tf.tanh(next_memory)
        next_memory += self._forget_gate * memory

    output = basic.BatchFlatten()(next_memory)
    return output, next_memory
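# --- Illustrative sketch (not part of the original module) ---
# Shape bookkeeping for one step of `_build` above, in plain numpy, assuming
# toy sizes mem_slots=4, mem_size=6 and a single projected input row (n=1).
# The input rows are appended so attention can write into memory, then
# cropped off again so memory keeps a fixed number of slots.
import numpy as np

batch, mem_slots, mem_size, n = 2, 4, 6, 1
memory = np.zeros((batch, mem_slots, mem_size))
inputs_reshape = np.zeros((batch, n, mem_size))   # projected input row(s)
memory_plus_input = np.concatenate([memory, inputs_reshape], axis=1)
assert memory_plus_input.shape == (batch, mem_slots + n, mem_size)
# After attending over the concatenated matrix, drop the appended rows:
next_memory = memory_plus_input[:, :-n, :]
assert next_memory.shape == (batch, mem_slots, mem_size)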
def _build(self, x, prev_state):
    """Connects the core to the graph.

    Args:
      x: Input `Tensor` of shape `(batch_size, input_size)`.
      prev_state: Previous state. This could be a `Tensor`, or a tuple of
        `Tensor`s.

    Returns:
      The tuple `(output, state)` for this core.

    Raises:
      ValueError: if the `Tensor` `x` does not have rank 2.
    """
    x.get_shape().with_rank(2)
    self._batch_size = x.get_shape().as_list()[0]
    self._dtype = x.dtype

    x_zeros = tf.concat(
        [x, tf.zeros(shape=(self._batch_size, 1), dtype=self._dtype)], 1)
    x_ones = tf.concat(
        [x, tf.ones(shape=(self._batch_size, 1), dtype=self._dtype)], 1)

    # Weights for the halting signal.
    halting_linear = basic.Linear(name="halting_linear", output_size=1)

    body = functools.partial(self._body,
                             halting_linear=halting_linear,
                             x_ones=x_ones)

    cumul_halting_init = tf.zeros(shape=(self._batch_size, 1),
                                  dtype=self._dtype)
    iteration_init = tf.zeros(shape=(self._batch_size, 1), dtype=self._dtype)
    core_output_size = [x.value for x in self._core.output_size]
    out_init = tf.zeros(shape=(self._batch_size,) + tuple(core_output_size),
                        dtype=self._dtype)
    cumul_state_init = _nested_zeros_like(prev_state)
    remainder_init = tf.zeros(shape=(self._batch_size, 1), dtype=self._dtype)

    (unused_final_x, final_out, unused_final_state, final_cumul_state,
     unused_final_halting, final_iteration, final_remainder) = tf.while_loop(
         self._cond, body, [
             x_zeros, out_init, prev_state, cumul_state_init,
             cumul_halting_init, iteration_init, remainder_init
         ])

    act_output = basic.Linear(name="act_output_linear",
                              output_size=self._output_size)(final_out)

    # Modification: the controller state and read vectors use the pondering
    # cumulative-weighting scheme; the memory matrix does not, so that the
    # memory matrix keeps temporal continuity across the unrolled steps.
    controller_state, access_state, read_vectors = final_cumul_state
    final_cumul_state = (controller_state, unused_final_state[1], read_vectors)

    return (act_output, (final_iteration, final_remainder)), final_cumul_state
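# --- Illustrative sketch (not part of the original module) ---
# The while loop above appears to follow Graves' Adaptive Computation Time
# (ACT) scheme: per example, pondering stops once the cumulative halting
# probability exceeds 1 - epsilon, and outputs/states are combined with the
# halting weights, the last step receiving the remainder. A minimal numpy
# sketch for a single example (all numbers here are made up):
import numpy as np

halting_probs = np.array([0.3, 0.4, 0.5])   # sigmoid(halting_linear(...)) per step
eps = 0.01
cumul, weights = 0.0, []
for p in halting_probs:
    if cumul + p >= 1.0 - eps:              # final step: use the remainder
        weights.append(1.0 - cumul)
        break
    weights.append(p)
    cumul += p

step_outputs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
pondered_output = sum(w * o for w, o in zip(weights, step_outputs))
remainder = weights[-1]                     # analogous to final_remainder above
assert abs(sum(weights) - 1.0) < 1e-6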
def _instantiate_layers(self):
    """Instantiates all the linear modules used in the network.

    Layers are instantiated in the constructor, as opposed to the build
    function, because MLP implements the Transposable interface, and the
    transpose function can be called before the module is actually connected
    to the graph and build is called.

    Notice that this is safe since layers in the transposed module are
    instantiated using a lambda returning input_size of the mlp layers, and
    this doesn't have to return sensible values until the original module is
    connected to the graph.
    """
    # Here we are entering the module's variable scope to name our submodules
    # correctly (not to create variables). As such it's safe to not check
    # whether we're in the same graph. This is important if we're constructing
    # the module in one graph and connecting it in another (e.g. with `defun`
    # the module is created in some default graph, and connected to a capturing
    # graph in order to turn it into a graph function).
    with self._enter_variable_scope(check_same_graph=False):
        self._layers = [
            basic.Linear(self._output_sizes[i],
                         name="linear_{}".format(i),
                         initializers=self._initializers,
                         partitioners=self._partitioners,
                         regularizers=self._regularizers,
                         use_bias=self.use_bias)
            for i in xrange(self._num_layers)
        ]
def _instantiate_layers(self):
    """Instantiates all the linear modules used in the network.

    Layers are instantiated in the constructor, as opposed to the build
    function, because MLP implements the Transposable interface, and the
    transpose function can be called before the module is actually connected
    to the graph and build is called.

    Notice that this is safe since layers in the transposed module are
    instantiated using a lambda returning input_size of the mlp layers, and
    this doesn't have to return sensible values until the original module is
    connected to the graph.
    """
    with self._enter_variable_scope():
        self._layers = [
            basic.Linear(self._output_sizes[i],
                         name="linear_{}".format(i),
                         initializers=self._initializers,
                         partitioners=self._partitioners,
                         regularizers=self._regularizers,
                         use_bias=self.use_bias)
            for i in xrange(self._num_layers)
        ]
def _build(self, input_, prev_state):
    """Connects the VanillaRNN module into the graph.

    If this is not the first time the module has been connected to the graph,
    the Tensors provided as input_ and state must have the same final
    dimension, in order for the existing variables to be the correct size for
    their corresponding multiplications. The batch size may differ for each
    connection.

    Args:
      input_: a 2D Tensor of size [batch_size, input_size].
      prev_state: a 2D Tensor of size [batch_size, hidden_size].

    Returns:
      output: a 2D Tensor of size [batch_size, hidden_size].
      next_state: a Tensor of size [batch_size, hidden_size].

    Raises:
      ValueError: if connecting the module into the graph any time after the
        first time, and the inferred size of the inputs does not match previous
        invocations.
    """
    self._in_to_hidden_linear = basic.Linear(
        self._hidden_size,
        name="in_to_hidden",
        initializers=self._initializers.get("in_to_hidden"),
        partitioners=self._partitioners.get("in_to_hidden"),
        regularizers=self._regularizers.get("in_to_hidden"))

    self._hidden_to_hidden_linear = basic.Linear(
        self._hidden_size,
        name="hidden_to_hidden",
        initializers=self._initializers.get("hidden_to_hidden"),
        partitioners=self._partitioners.get("hidden_to_hidden"),
        regularizers=self._regularizers.get("hidden_to_hidden"))

    in_to_hidden = self._in_to_hidden_linear(input_)
    hidden_to_hidden = self._hidden_to_hidden_linear(prev_state)
    output = self._activation(in_to_hidden + hidden_to_hidden)

    # For VanillaRNN, the next state of the RNN is the same as the output.
    return output, output
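# --- Illustrative sketch (not part of the original module) ---
# The recurrence computed above, written out in plain numpy:
#   h_t = activation(W_in x_t + W_h h_{t-1} + b)
# assuming a tanh activation and toy sizes [batch=2, input=3, hidden=4]. The
# two Linear modules each carry their own bias; a single `b` stands in here.
import numpy as np

x_t = np.random.randn(2, 3)
h_prev = np.random.randn(2, 4)
w_in = np.random.randn(3, 4)
w_h = np.random.randn(4, 4)
b = np.zeros(4)
h_t = np.tanh(x_t @ w_in + h_prev @ w_h + b)
assert h_t.shape == (2, 4)   # output and next state are the same tensor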
def testModuleInfo_multiple_modules(self):
    # pylint: disable=not-callable
    tf.reset_default_graph()
    dumb = DumbModule(name="dumb")
    dumb_1 = DumbModule(name="dumb")
    linear = basic.Linear(10, name="linear")
    ph_0 = tf.placeholder(dtype=tf.float32, shape=(1, 10,))
    dumb(ph_0)
    with tf.name_scope("foo"):
        dumb_1(ph_0)
    linear(ph_0)

    def check():
        sonnet_collection = tf.get_default_graph().get_collection(
            base_info.SONNET_COLLECTION_NAME)
        self.assertEqual(len(sonnet_collection), 3)
        # item 0.
        self.assertEqual(sonnet_collection[0].module_name, "dumb")
        self.assertEqual(sonnet_collection[0].class_name,
                         "{}.DumbModule".format(THIS_MODULE))
        self.assertEqual(sonnet_collection[0].scope_name, "dumb")
        self.assertEqual(len(sonnet_collection[0].connected_subgraphs), 1)
        self.assertEqual(
            sonnet_collection[0].connected_subgraphs[0].name_scope, "dumb")
        # item 1.
        self.assertEqual(sonnet_collection[1].module_name, "dumb_1")
        self.assertEqual(sonnet_collection[1].class_name,
                         "{}.DumbModule".format(THIS_MODULE))
        self.assertEqual(sonnet_collection[1].scope_name, "dumb_1")
        self.assertEqual(len(sonnet_collection[1].connected_subgraphs), 1)
        self.assertEqual(
            sonnet_collection[1].connected_subgraphs[0].name_scope,
            "foo/dumb_1")
        # item 2.
        self.assertEqual(sonnet_collection[2].module_name, "linear")
        self.assertEqual(sonnet_collection[2].class_name,
                         "{}.Linear".format(LINEAR_MODULE))
        self.assertEqual(sonnet_collection[2].scope_name, "linear")
        self.assertEqual(len(sonnet_collection[2].connected_subgraphs), 1)
        self.assertEqual(
            sonnet_collection[2].connected_subgraphs[0].name_scope, "linear")

    check()
    _copy_default_graph()
    check()
def _multihead_attention(self, memory):
    """Perform multi-head attention from 'Attention is All You Need'.

    Implementation of the attention mechanism from
    https://arxiv.org/abs/1706.03762.

    Args:
      memory: Memory tensor to perform attention on.

    Returns:
      new_memory: New memory tensor.
    """
    key_size = self._key_size
    value_size = self._head_size
    qkv_size = 2 * key_size + value_size
    total_size = qkv_size * self._num_heads  # Denoted as F.

    qkv = basic.BatchApply(basic.Linear(total_size))(memory)
    qkv = basic.BatchApply(layer_norm.LayerNorm())(qkv)

    mem_slots = memory.get_shape().as_list()[1]  # Denoted as N.

    # [B, N, F] -> [B, N, H, F/H]
    qkv_reshape = basic.BatchReshape(
        [mem_slots, self._num_heads, qkv_size])(qkv)

    # [B, N, H, F/H] -> [B, H, N, F/H]
    qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
    q, k, v = tf.split(qkv_transpose, [key_size, key_size, value_size], -1)

    q *= qkv_size ** -0.5
    dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
    weights = tf.nn.softmax(dot_product)

    output = tf.matmul(weights, v)  # [B, H, N, V]

    # [B, H, N, V] -> [B, N, H, V]
    output_transpose = tf.transpose(output, [0, 2, 1, 3])

    # [B, N, H, V] -> [B, N, H * V]
    new_memory = basic.BatchFlatten(preserve_dims=2)(output_transpose)
    return new_memory
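# --- Illustrative sketch (not part of the original module) ---
# Shape walk-through of the attention above in plain numpy, assuming toy
# sizes B=2 (batch), N=4 (memory slots), H=3 (heads), K=V=5 (key/value size):
import numpy as np

B, N, H, K, V = 2, 4, 3, 5, 5
q = np.random.randn(B, H, N, K)
k = np.random.randn(B, H, N, K)
v = np.random.randn(B, H, N, V)

logits = q @ k.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)   # placeholder, see next line
logits = q @ k.transpose(0, 1, 3, 2)                         # [B, H, N, N]
weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)   # softmax
out = weights @ v                                            # [B, H, N, V]
new_memory = out.transpose(0, 2, 1, 3).reshape(B, N, H * V)  # [B, N, H*V]
assert new_memory.shape == (B, N, H * V)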
def _build(self,
           inputs,
           keep_prob=None,
           is_training=True,
           test_local_stats=True):
    """Connects the AlexNet module into the graph.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
      is_training: Boolean to indicate to `snt.BatchNorm` if we are currently
        training. By default `True`.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
    """
    input_shape = inputs.get_shape().as_list()

    if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
        raise base.IncompatibleShapeError(
            "Image shape too small: ({:d}, {:d}) < {:d}".format(
                input_shape[1], input_shape[2], self._min_size))

    net = inputs

    for i, params in enumerate(self._conv_layers):
        output_channels, conv_params, max_pooling = params
        kernel_size, stride = conv_params

        conv_mod = conv.Conv2D(name="conv_{}".format(i),
                               output_channels=output_channels,
                               kernel_shape=kernel_size,
                               stride=stride,
                               padding=conv.VALID,
                               initializers=self._initializers,
                               partitioners=self._partitioners,
                               regularizers=self._regularizers)

        if not self.is_connected:
            self._conv_modules.append(conv_mod)

        net = conv_mod(net)

        if self._use_batch_norm:
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            net = bn(net, is_training, test_local_stats)

        net = tf.nn.relu(net)

        if max_pooling is not None:
            pooling_kernel_size, pooling_stride = max_pooling
            net = tf.nn.max_pool(
                net,
                ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                strides=[1, pooling_stride, pooling_stride, 1],
                padding=conv.VALID)

    net = basic.BatchFlatten(name="flatten")(net)

    for i, output_size in enumerate(self._fc_layers):
        linear_mod = basic.Linear(name="fc_{}".format(i),
                                  output_size=output_size,
                                  partitioners=self._partitioners)

        if not self.is_connected:
            self._linear_modules.append(linear_mod)

        net = linear_mod(net)

        if self._use_batch_norm:
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            net = bn(net, is_training, test_local_stats)

        net = tf.nn.relu(net)

        if keep_prob is not None:
            net = tf.nn.dropout(net, keep_prob=keep_prob)

    return net
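# --- Illustrative sketch (not part of the original module) ---
# The `_min_size` check exists because every conv and max-pool layer above
# uses VALID padding, which shrinks the spatial size at each layer. A small
# helper showing the VALID-padding output-size formula (the 224/11/4 numbers
# are only an example, matching the classic AlexNet first conv layer):
def valid_output_size(input_size, kernel_size, stride):
    # floor((input - kernel) / stride) + 1
    return (input_size - kernel_size) // stride + 1

assert valid_output_size(224, 11, 4) == 54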
def _build(self,
           inputs,
           keep_prob=None,
           is_training=None,
           test_local_stats=True):
    """Connects the AlexNet module into the graph.

    The `is_training` flag only controls the batch norm settings; if `False`,
    it does not disable dropout by overriding a supplied `keep_prob`. To avoid
    any confusion this may cause, if `is_training=False` and `keep_prob` would
    cause dropout to be applied, an error is thrown.

    Args:
      inputs: A Tensor of size [batch_size, input_height, input_width,
        input_channels], representing a batch of input images.
      keep_prob: A scalar Tensor representing the dropout keep probability.
        When `is_training=False` this must be None or 1 to give no dropout.
      is_training: Boolean to indicate if we are currently training. Must be
        specified if batch normalization or dropout is used.
      test_local_stats: Boolean to indicate to `snt.BatchNorm` if batch
        normalization should use local batch statistics at test time.
        By default `True`.

    Returns:
      A Tensor of size [batch_size, output_size], where `output_size` depends
      on the mode the network was constructed in.

    Raises:
      base.IncompatibleShapeError: If any of the input image dimensions
        (input_height, input_width) are too small for the given network mode.
      ValueError: If `keep_prob` is not None or 1 when `is_training=False`.
      ValueError: If `is_training` is not explicitly specified when using
        batch normalization.
    """
    # Check input shape.
    if (self._use_batch_norm or keep_prob is not None) and is_training is None:
        raise ValueError("Boolean is_training flag must be explicitly specified "
                         "when using batch normalization or dropout.")

    input_shape = inputs.get_shape().as_list()
    if input_shape[1] < self._min_size or input_shape[2] < self._min_size:
        raise base.IncompatibleShapeError(
            "Image shape too small: ({:d}, {:d}) < {:d}".format(
                input_shape[1], input_shape[2], self._min_size))

    net = inputs

    # Check keep prob.
    if keep_prob is not None:
        valid_inputs = tf.logical_or(is_training, tf.equal(keep_prob, 1.))
        keep_prob_check = tf.assert_equal(
            valid_inputs, True,
            message="Input `keep_prob` must be None or 1 if `is_training=False`.")
        with tf.control_dependencies([keep_prob_check]):
            net = tf.identity(net)

    for i, params in enumerate(self._conv_layers):
        output_channels, conv_params, max_pooling = params
        kernel_size, stride = conv_params

        conv_mod = conv.Conv2D(name="conv_{}".format(i),
                               output_channels=output_channels,
                               kernel_shape=kernel_size,
                               stride=stride,
                               padding=conv.VALID,
                               initializers=self._initializers,
                               partitioners=self._partitioners,
                               regularizers=self._regularizers)

        if not self.is_connected:
            self._conv_modules.append(conv_mod)

        net = conv_mod(net)

        if self._use_batch_norm:
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            net = bn(net, is_training, test_local_stats)

        net = tf.nn.relu(net)

        if max_pooling is not None:
            pooling_kernel_size, pooling_stride = max_pooling
            net = tf.nn.max_pool(
                net,
                ksize=[1, pooling_kernel_size, pooling_kernel_size, 1],
                strides=[1, pooling_stride, pooling_stride, 1],
                padding=conv.VALID)

    net = basic.BatchFlatten(name="flatten")(net)

    for i, output_size in enumerate(self._fc_layers):
        linear_mod = basic.Linear(name="fc_{}".format(i),
                                  output_size=output_size,
                                  initializers=self._initializers,
                                  partitioners=self._partitioners)

        if not self.is_connected:
            self._linear_modules.append(linear_mod)

        net = linear_mod(net)

        if self._use_batch_norm and self._bn_on_fc_layers:
            bn = batch_norm.BatchNorm(**self._batch_norm_config)
            net = bn(net, is_training, test_local_stats)

        net = tf.nn.relu(net)

        if keep_prob is not None:
            net = tf.nn.dropout(net, keep_prob=keep_prob)

    return net
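# --- Illustrative sketch (not part of the original module) ---
# The contract enforced by the keep_prob check above: dropout may only be
# active while training. A plain-Python restatement of the same rule:
def check_keep_prob(keep_prob, is_training):
    if keep_prob is not None and not is_training and keep_prob != 1.0:
        raise ValueError(
            "Input `keep_prob` must be None or 1 if `is_training=False`.")

check_keep_prob(None, False)    # ok: no dropout requested at eval time
check_keep_prob(0.5, True)      # ok: dropout while training
# check_keep_prob(0.5, False)   # would raise: dropout requested at eval time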
def _build(self,
           inputs,
           state=None,
           condition=None,
           is_training=True,
           final_layer_key_value_inputs=None):
    """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional list of length num_layers of tensors of shape
        [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on, of shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.
      final_layer_key_value_inputs: optional Tensor of shape
        [batch_size, num_steps, dim_size] to be used as the key and value for
        the final multi-head attention layer. Useful when the tower is a
        Seq2Seq decoder and it can attend to encoder outputs.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
    # inputs: [B, N, F]
    if final_layer_key_value_inputs is not None and state is not None and len(
        state) == (self._num_layers - 1):
        raise ValueError('When final_layer_key_value_inputs is set, exclude '
                         'the state of the last layer.')

    if condition is not None:
        condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                 [1, tf.shape(inputs)[1], 1])
        inputs = tf.concat([inputs, condition_tile], -1)

    # Map inputs to be of `embedding_size` dimension.
    if inputs.get_shape().as_list()[-1] != self._embedding_size:
        inputs = default_mlp([self._embedding_size], activate_final=True)(
            inputs,
            is_training=is_training,
            dropout_keep_prob=1 - self._dropout_rate)

    if state is None:
        memory_sizes = [0]
    elif isinstance(state[0], CompressedMemoryState):
        cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
        em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
        memory_sizes = [cm_mem_size, em_mem_size]
    else:
        memory_sizes = [max([_memory_size(s) for s in state])]
    chunk_size = inputs.get_shape().as_list()[1]

    self._positional_encodings = []
    # Creates positional encodings for different memory types.
    for i, memory_size in enumerate(memory_sizes):
        seq_len = chunk_size + memory_size
        key_positions = get_position_encodings(
            sequence_length=seq_len,
            hidden_size=inputs.get_shape().as_list()[2],
            clamp_value=self._clamp_time_range)
        if is_training:
            key_positions = tf.nn.dropout(key_positions,
                                          rate=self._dropout_rate)
        key_positions = tf.cast(key_positions, dtype=inputs.dtype)
        query_positions = key_positions[:, -chunk_size:, :]
        self._positional_encodings.append((key_positions, query_positions))

    if self._causal:
        self._mask = create_mask(inputs, state, self._same_attention_length)

    layer_i_inputs = inputs
    attention_states = []
    key_value_inputs = None

    for i in range(self._num_layers):
        with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
            multihead_attention, object_mlp = self.get_sublayers(is_training)

            # Multihead attention with residuals.
            state_i = None if state is None else state[i]
            if i == (self._num_layers - 1) and (
                    final_layer_key_value_inputs is not None):
                # When final_layer_key_value_inputs is set, the final layer of
                # attention uses it as the key & value, so no state is needed.
                key_value_inputs = final_layer_key_value_inputs
                state_i = None

            attention_outputs, attention_state = multihead_attention(
                layer_i_inputs,
                state=state_i,
                is_training=is_training,
                dropout_keep_prob=1. - self._dropout_rate,
                key_value_inputs=key_value_inputs)
            attention_states.append(attention_state)

            # Feed-forward with residuals.
            output = object_mlp(attention_outputs,
                                is_training=is_training,
                                dropout_keep_prob=1 - self._dropout_rate)
            layer_i_inputs = output

    if self._output_size is not None:
        output = basic.BatchApply(
            basic.Linear(self._output_size, use_bias=False))(output)

    return output, attention_states
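# --- Illustrative sketch (not part of the original module) ---
# Positional bookkeeping used above: the keys span memory plus the current
# chunk, while the queries cover only the current chunk, so the query
# encodings are the trailing slice of the key encodings. Toy numbers:
chunk_size_example, memory_size_example = 4, 6
seq_len_example = chunk_size_example + memory_size_example    # positions covered by keys
key_position_ids = list(range(seq_len_example))                # stand-in for get_position_encodings
query_position_ids = key_position_ids[-chunk_size_example:]
assert query_position_ids == [6, 7, 8, 9]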
def _build(self,
           inputs,
           query_inputs=None,
           state=None,
           is_training=False,
           dropout_keep_prob=0.5,
           key_value_inputs=None):
    """Calculates multi-layer self attention.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, output_dim_size]. Inputs
        used as the query, key, and value to the attention layer.
      query_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. Query inputs to the attention layer. Set when
        query_inputs is different from the inputs argument.
      state: optional CompressedMemoryState or a Tensor of shape [batch_size,
        memory_size, dim_size] concatenated to the inputs. Set when attending
        to the memory from previous steps.
      is_training: if currently training.
      dropout_keep_prob: dropout keep probability applied to the attention
        weights.
      key_value_inputs: optional Tensor of shape [batch_size, num_steps,
        output_dim_size]. It is used as the key and value of the multihead
        attention. Set when the key and value are different from the inputs
        argument.

    Returns:
      output: the result Tensor of shape
        [batch_size, num_steps, output_dim_size].
      attention_state: named tuple of AttentionState.
    """
    if key_value_inputs is not None and state is not None:
        raise ValueError(
            'Only one of key_value_inputs and state may be specified.')

    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if key_value_inputs is not None:
        k_inputs = key_value_inputs
        v_inputs = k_inputs
    elif state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [state.compressed_memory,
                                 state.episodic_memory]
        else:
            state_memory_list = [state]

        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B.
    batch_size = tf.shape(inputs)[0]
    # Chunk size denoted by N.
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M.
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        if len(self._positional_encodings) != 1:
            raise ValueError(
                'Absolute positional encodings only supported for 1 memory. '
                'Found %i.' % len(self._positional_encodings))
        key_positions, query_positions = self._positional_encodings[0]
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product.
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable(
            'r_w_bias', [1, self._num_heads, 1, self._key_size],
            dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(self._positional_encodings):
            key_positions, query_positions = positional_encodings
            if key_positions.get_shape().as_list()[-1] != att_size:
                # Crop to layer mem size.
                key_positions = key_positions[:, -att_size:]
            is_final = i == len(self._positional_encodings) - 1
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(
                key_positions, name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix, [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(
                q + r_r_bias, relative_keys, transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:
                # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)

    # [B, L, H, V], where V is value_size.
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape(
        [query_size, embedding_size])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(queries=q,
                                     keys=k,
                                     values=v,
                                     weights=weights,
                                     logits=content_logits,
                                     embeddings=inputs,
                                     read_words=output)
    return output, attention_state
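# --- Illustrative sketch (not part of the original module) ---
# `rel_shift` is not defined in this excerpt. A common Transformer-XL style
# implementation (shown here as an assumption, for reference) realigns the
# relative-position logits with a pad / reshape / slice trick. It assumes
# `tf` (TensorFlow 1.x) is imported as in the rest of this file.
def rel_shift_sketch(position_logits):
    # position_logits: [B, H, T1, T2], last axis indexed by relative distance.
    shape = tf.shape(position_logits)
    # Prepend a zero column, fold it away with a reshape, then crop it off.
    padded = tf.pad(position_logits, [[0, 0], [0, 0], [0, 0], [1, 0]])
    reshaped = tf.reshape(padded, [shape[0], shape[1], shape[3] + 1, shape[2]])
    cropped = tf.slice(reshaped, [0, 0, 1, 0], [-1, -1, -1, -1])
    return tf.reshape(cropped, shape)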
def _build(self, inputs, state=None, condition=None, is_training=True):
    """Calculates multi-layer self attention and mlp transformation.

    Args:
      inputs: Tensor of shape [batch_size, num_steps, dim_size].
      state: optional tensor of shape [batch_size, memory_size, dim_size].
      condition: optional tensor to condition on, of shape
        [batch_size, dim_size].
      is_training: If true, dropout is applied.

    Returns:
      output: tensor of shape [batch_size, num_steps, output_dim_size].
      state: list of length `num_layers` containing AttentionState tuples.
    """
    # inputs: [B, N, F]
    if condition is not None:
        condition_tile = tf.tile(tf.expand_dims(condition, 1),
                                 [1, tf.shape(inputs)[1], 1])
        inputs = tf.concat([inputs, condition_tile], -1)

    if state is None:
        memory_sizes = [0]
    elif isinstance(state[0], CompressedMemoryState):
        cm_mem_size = max(_memory_size(s.compressed_memory) for s in state)
        em_mem_size = max(_memory_size(s.episodic_memory) for s in state)
        memory_sizes = [cm_mem_size, em_mem_size]
    else:
        memory_sizes = [max([_memory_size(s) for s in state])]
    chunk_size = inputs.get_shape().as_list()[1]

    self._positional_encodings = []
    # Creates positional encodings for different memory types.
    for i, memory_size in enumerate(memory_sizes):
        seq_len = chunk_size + memory_size
        key_positions = get_position_encodings(
            sequence_length=seq_len,
            hidden_size=inputs.get_shape().as_list()[2],
            clamp_value=self._clamp_time_range)
        if is_training:
            key_positions = tf.nn.dropout(key_positions,
                                          rate=self._dropout_rate)
        key_positions = tf.cast(key_positions, dtype=inputs.dtype)
        query_positions = key_positions[:, -chunk_size:, :]
        self._positional_encodings.append((key_positions, query_positions))

    if self._causal:
        self._mask = create_mask(inputs, state, self._same_attention_length)

    layer_i_inputs = inputs
    attention_states = []

    for i in range(self._num_layers):
        with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
            multihead_attention, object_mlp = self.get_sublayers(is_training)

            # Multihead attention with residuals.
            state_i = None if state is None else state[i]
            attention_outputs, attention_state = multihead_attention(
                layer_i_inputs,
                state=state_i,
                is_training=is_training,
                dropout_keep_prob=1. - self._dropout_rate)
            attention_states.append(attention_state)

            # Feed-forward with residuals.
            output = object_mlp(attention_outputs,
                                is_training=is_training,
                                dropout_keep_prob=1 - self._dropout_rate)
            layer_i_inputs = output

    if self._output_size is not None:
        output = basic.BatchApply(
            basic.Linear(self._output_size, use_bias=False))(output)

    return output, attention_states
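# --- Illustrative sketch (not part of the original module) ---
# `create_mask` is not defined in this excerpt. For purely causal attention
# with no memory, one common construction (an assumption; the real helper also
# accounts for the memory length) is an additive mask with a large negative
# value above the diagonal, so the softmax zeroes out attention to future
# positions. Assumes `tf` (TensorFlow 1.x) is imported as in this file.
def causal_additive_mask_sketch(chunk_size, dtype=tf.float32):
    # [N, N] lower-triangular matrix of ones (allowed positions).
    visible = tf.matrix_band_part(
        tf.ones([chunk_size, chunk_size], dtype), -1, 0)
    # 0 where attention is allowed, -1e9 where it must be suppressed.
    mask = (1.0 - visible) * -1e9
    # Broadcastable against [B, H, N, N] logits.
    return mask[tf.newaxis, tf.newaxis, :, :]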
def _build(self,
           inputs,
           query_inputs=None,
           state=None,
           is_training=False,
           dropout_keep_prob=0.5):
    """Calculates multi-head attention; see the documented `_build` above for
    argument details."""
    embedding_size = self._value_size * self._num_heads

    q_inputs = inputs if query_inputs is None else query_inputs
    # Denoted by L. If query_inputs is None, L = N.
    _, query_size = q_inputs.get_shape().as_list()[:2]

    if state is not None:
        if isinstance(state, CompressedMemoryState):
            state_memory_list = [state.compressed_memory,
                                 state.episodic_memory]
        else:
            state_memory_list = [state]

        k_inputs = tf.concat(state_memory_list + [inputs], 1)
        v_inputs = k_inputs
    else:
        k_inputs = inputs
        v_inputs = inputs

    # Batch size denoted by B.
    batch_size = tf.shape(inputs)[0]
    # Chunk size denoted by N.
    chunk_size = inputs.get_shape().as_list()[1]
    # Denoted by N + M.
    att_size = k_inputs.get_shape().as_list()[1]

    if self._positional_encodings and not self._use_relative_positions:
        # Use the single (key, query) positional-encoding pair.
        key_positions, query_positions = self._positional_encodings[0]
        k_inputs += key_positions
        q_inputs += query_positions

    # [B, H, L, K]
    q = self.multihead_linear(q_inputs, 'query')
    # [B, H, N + M, K]
    k = self.multihead_linear(k_inputs, 'key')
    # [B, H, N + M, V]
    v = self.multihead_linear(v_inputs, 'value')

    # Scaling the dot-product.
    if self._scaling:
        q *= self._key_size**-0.5

    # [B, H, L, N + M]
    if self._use_relative_positions:
        r_w_bias = tf.get_variable(
            'r_w_bias', [1, self._num_heads, 1, self._key_size],
            dtype=inputs.dtype)
        content_logits = tf.matmul(q + r_w_bias, k, transpose_b=True)
        all_relative_logits = []
        # Loop over multiple positional encodings, for the case of multiple
        # memory types.
        for i, positional_encodings in enumerate(self._positional_encodings):
            key_positions, query_positions = positional_encodings
            if key_positions.get_shape().as_list()[-1] != att_size:
                # Crop to layer mem size.
                key_positions = key_positions[:, -att_size:]
            is_final = i == len(self._positional_encodings) - 1
            suffix = '' if is_final else '_%d' % i
            relative_keys = self.multihead_linear(
                key_positions, name='relative_keys' + suffix)
            # [B, H, N, D]
            r_r_bias = tf.get_variable(
                'r_r_bias' + suffix, [1, self._num_heads, 1, self._key_size],
                dtype=inputs.dtype)
            relative_keys = tf.tile(relative_keys, [batch_size, 1, 1, 1])
            relative_logits = tf.matmul(
                q + r_r_bias, relative_keys, transpose_b=True)
            relative_logits = rel_shift(relative_logits)
            if not is_final:
                # Include relative positions for input sequence.
                relative_logits = relative_logits[:, :, :, :-chunk_size]
            all_relative_logits.append(relative_logits)
        all_relative_logits = tf.concat(all_relative_logits, 3)
        logits = content_logits + all_relative_logits
    else:
        # [B, H, N, N + M]
        logits = tf.matmul(q, k, transpose_b=True)
        content_logits = logits

    if self._mask is not None:
        if self._mask.get_shape().as_list()[-1] != att_size:
            mask = self._mask[:, :, :, -att_size:]
        else:
            mask = self._mask
        logits += mask

    weights = tf.nn.softmax(logits)
    if is_training:
        weights = tf.nn.dropout(weights, dropout_keep_prob)

    # [B, L, H, V], where V is value_size.
    output_transpose = tf.einsum('bhij,bhjk->bihk', weights, v)

    # [B, L, H, V] -> [B, L, HV]
    attended_inputs = basic.BatchReshape(
        [query_size, embedding_size])(output_transpose)
    # Apply final mlp to mix information between heads.
    output = basic.BatchApply(basic.Linear(embedding_size))(attended_inputs)

    attention_state = AttentionState(queries=q,
                                     keys=k,
                                     values=v,
                                     weights=weights,
                                     logits=content_logits,
                                     embeddings=inputs,
                                     read_words=output)
    return output, attention_state