def state_size(self):
    state = super(MultiHeadAttentionWrapperV3, self).state_size
    _attn_mech = self._attention_mechanisms[0]
    # [num_heads, memory_time] of the split value tensor; alignments are
    # stored flattened across heads.
    s = _shape(_attn_mech._values_split)[1:3]
    state = state._replace(
        alignments=s[0] * s[1],
        alignment_history=s[0] * s[1],
        attention_state=s[0] * s[1])
    if _attn_mech._fm_projection is None and self._context_layer is False:
        # Attention context is the raw feature map (no projection).
        state = state.clone(attention=_attn_mech._feature_map_shape[-1])
    else:
        state = state.clone(attention=_attn_mech._num_units)
    return state
def _layer_norm_tanh(tensor):
    # The `begin_norm_axis` keyword is only accepted by newer
    # `layer_norm_activate` signatures (TF >= 1.9); fall back to
    # normalising a rank-2 reshape of the tensor otherwise.
    try:
        tensor = layer_norm_activate(
            'LN_tanh',
            tensor,
            tf.nn.tanh,
            begin_norm_axis=-1)
    except TypeError:
        tensor_s = _shape(tensor)
        tensor = layer_norm_activate(
            'LN_tanh',
            tf.reshape(tensor, [-1, tensor_s[-1]]),
            tf.nn.tanh)
        tensor = tf.reshape(tensor, tensor_s)
    return tensor
def split_heads(x, num_heads):
    """Split channels (dimension 2) into multiple heads (becomes dimension 1).

    Args:
        x: a Tensor with shape [batch, length, channels]
        num_heads: an integer

    Returns:
        a Tensor with shape [batch, num_heads, length, channels / num_heads]
    """
    old_shape = _shape(x)
    last = old_shape[-1]
    new_shape = old_shape[:-1] + [num_heads] \
        + [last // num_heads if last else -1]
    return tf.transpose(tf.reshape(x, new_shape, 'split_head'), [0, 2, 1, 3])
def combine_heads(x):
    """Inverse of split_heads.

    Args:
        x: a Tensor with shape [batch, num_heads, length, channels / num_heads]

    Returns:
        a Tensor with shape [batch, length, channels]
    """
    x = tf.transpose(x, [0, 2, 1, 3])
    old_shape = _shape(x)
    a, b = old_shape[-2:]
    new_shape = old_shape[:-2] + [a * b if a and b else -1]
    return tf.reshape(x, new_shape, 'combine_head')
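# A minimal shape sketch (illustrative only; the tensor sizes below are
# assumptions, not values used elsewhere in this repo): `split_heads`
# followed by `combine_heads` round-trips a [batch, length, channels]
# tensor, provided `channels` is divisible by `num_heads`.
#
#     x = tf.random_normal([2, 196, 512])   # [batch, length, channels]
#     h = split_heads(x, num_heads=8)       # [2, 8, 196, 64]
#     y = combine_heads(h)                  # [2, 196, 512], same values as `x`
#
# Both functions only reshape and transpose, so no values are changed.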
def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step
    (e.g. monotonic attention).

    The default behavior is to return a tensor of all zeros.

    Args:
        batch_size: `int32` scalar, the batch_size.
        dtype: The `dtype`.

    Returns:
        A `dtype` tensor shaped `[batch_size, num_heads * memory_time]`
        (alignments are flattened across the attention heads).
    """
    # `values_split` is [batch_size, num_heads, memory_time, depth];
    # flatten the head and time dimensions for the initial zeros.
    s = _shape(self.values_split)[:-1]
    init = tf.zeros(shape=[s[0], s[1] * s[2]], dtype=dtype)
    return init
def call(self, inputs, prev_state):
    """
    Perform a step of attention-wrapped RNN.

    This method assumes `inputs` is the word embedding vector.

    This method overrides the original `call()` method.
    """
    _attn_mech = self._attention_mechanisms[0]

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    # `_cell_input_fn` defaults to
    # `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`
    cell_inputs = self._cell_input_fn(inputs, prev_state.attention)
    prev_cell_state = prev_state.cell_state
    cell_output, curr_cell_state = self._cell(cell_inputs, prev_cell_state)

    cell_batch_size = (
        cell_output.shape[0].value or tf.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory (encoder output) "
        "and the query (decoder output). Are you using the "
        "BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with tf.control_dependencies(
            [tf.assert_equal(cell_batch_size,
                             _attn_mech.batch_size,
                             message=error_message)]):
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    alignments, attention_state = _attn_mech(cell_output, state=None)

    if self._alignments_keep_prob < 1.:
        alignments = tf.contrib.layers.dropout(
            inputs=alignments,
            keep_prob=self._alignments_keep_prob,
            noise_shape=None,
            is_training=True)

    if len(_shape(alignments)) == 3:
        # Multi-head attention
        expanded_alignments = tf.expand_dims(alignments, 2)
        # alignments shape is
        #     [batch_size, num_heads, 1, memory_time]
        # attention_mechanism.values shape is
        #     [batch_size, num_heads, memory_time, num_units / num_heads]
        # the batched matmul is over memory_time, so the output shape is
        #     [batch_size, num_heads, 1, num_units / num_heads].
        # we then combine the heads
        #     [batch_size, 1, attention_mechanism.num_units]
        attention_mechanism_values = _attn_mech.values_split
        context = tf.matmul(expanded_alignments, attention_mechanism_values)
        attention = tf.squeeze(combine_heads(context), [1])
    else:
        # Reshape from [batch_size, memory_time] to
        # [batch_size, 1, memory_time]
        expanded_alignments = tf.expand_dims(alignments, 1)
        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # alignments shape is
        #     [batch_size, 1, memory_time]
        # attention_mechanism.values shape is
        #     [batch_size, memory_time, attention_mechanism.num_units]
        # the batched matmul is over memory_time, so the output shape is
        #     [batch_size, 1, attention_mechanism.num_units].
        # we then squeeze out the singleton dim.
        attention_mechanism_values = _attn_mech.values
        context = tf.matmul(expanded_alignments, attention_mechanism_values)
        attention = tf.squeeze(context, [1])

    # Context projection
    if self._context_layer:
        attention = Dense(
            name='a_layer',
            units=_attn_mech._num_units,
            use_bias=False,
            activation=None,
            dtype=_attn_mech.dtype)(attention)

    if self._alignment_history:
        alignments = tf.reshape(alignments, [cell_batch_size, -1])
        alignment_history = prev_state.alignment_history.write(
            prev_state.time, alignments)
    else:
        alignment_history = ()

    curr_state = attention_wrapper.AttentionWrapperState(
        time=prev_state.time + 1,
        cell_state=curr_cell_state,
        attention=attention,
        attention_state=alignments,
        alignments=alignments,
        alignment_history=alignment_history)

    return cell_output, curr_state
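# Illustrative shape walk-through of the multi-head branch above (the sizes
# batch_size=2, num_heads=8, memory_time=196, num_units=512 are assumptions
# for the example only):
#
#     alignments                  -> [2, 8, 196]
#     expanded_alignments         -> [2, 8, 1, 196]
#     _attn_mech.values_split     -> [2, 8, 196, 64]
#     context (batched matmul)    -> [2, 8, 1, 64]
#     combine_heads(context)      -> [2, 1, 512]
#     attention (after squeeze)   -> [2, 512]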
def __init__(self,
             num_units,
             feature_map,
             fm_projection,
             num_heads=None,
             scale=True,
             memory_sequence_length=None,
             probability_fn=tf.nn.softmax,
             name='MultiHeadAttV3'):
    """
    Construct the AttentionMechanism mechanism.

    Args:
        num_units: The depth of the attention mechanism.
        feature_map: The feature map / memory to query. This tensor
            should be shaped `[batch_size, height * width, channels]`.
        fm_projection: Feature map projection mode, one of `None`,
            `'independent'` or `'tied'`.
        num_heads: Int, number of attention heads. (optional)
        scale: Python boolean. Whether to scale the energy term.
        memory_sequence_length: (optional) Sequence lengths for the batch
            entries in memory, shaped `[batch_size, num_heads]`.
        probability_fn: (optional) A `callable`. Converts the score to
            probabilities. The default is `tf.nn.softmax`.
        name: Name to use when creating ops.
    """
    print('INFO: Using MultiHeadAttV3.')
    assert fm_projection in [None, 'independent', 'tied']

    if memory_sequence_length is not None:
        assert len(_shape(memory_sequence_length)) == 2, \
            '`memory_sequence_length` must be a rank-2 tensor, ' \
            'shaped [batch_size, num_heads].'

    super(MultiHeadAttV3, self).__init__(
        query_layer=Dense(num_units, name='query_layer',
                          use_bias=False),  # query is projected hidden state
        memory_layer=Dense(num_units, name='memory_layer',
                           use_bias=False),  # self._keys is projected feature_map
        memory=feature_map,                  # self._values is feature_map
        probability_fn=lambda score, _: probability_fn(score),
        memory_sequence_length=None,
        score_mask_value=float('-inf'),
        name=name)
    self._probability_fn = lambda score, _: (
        probability_fn(
            self._maybe_mask_score_multi(
                score, memory_sequence_length, float('-inf'))))
    self._fm_projection = fm_projection
    self._num_units = num_units
    self._num_heads = num_heads
    self._scale = scale
    self._feature_map_shape = _shape(feature_map)
    self._name = name

    if fm_projection == 'tied':
        assert num_units % num_heads == 0, \
            'For `tied` projection, attention size/depth must be ' \
            'divisible by the number of attention heads.'
        self._values_split = split_heads(self._keys, self._num_heads)
    elif fm_projection == 'independent':
        assert num_units % num_heads == 0, \
            'For `independent` projection, attention size/depth must be ' \
            'divisible by the number of attention heads.'
        # Project and split memory
        v_layer = Dense(num_units, name='value_layer', use_bias=False)
        # (batch_size, num_heads, mem_size, num_units / num_heads)
        self._values_split = split_heads(v_layer(self._values),
                                         self._num_heads)
    else:
        assert _shape(self._values)[-1] % num_heads == 0, \
            'For `None` projection, feature map channel dim size must ' \
            'be divisible by the number of attention heads.'
        self._values_split = split_heads(self._values, self._num_heads)
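# A minimal construction sketch (illustrative only; `cnn_features` is a
# hypothetical [batch_size, height * width, channels] tensor from the CNN
# encoder, and the sizes are assumptions):
#
#     attn_mech = MultiHeadAttV3(
#         num_units=512,
#         feature_map=cnn_features,
#         fm_projection='tied',
#         num_heads=8,
#         scale=True)
#
# With `fm_projection='tied'`, the values reuse the projected keys
# (`self._keys`) split into 8 heads of depth 512 / 8 = 64 each.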
def __init__(self,
             name,
             num_units,
             memory,
             memory_projection='independent',
             memory_sequence_length=None,
             score_scale=True,
             probability_fn=None,
             dtype=None):
    """
    Construct the Attention mechanism.

    Args:
        name: Name to use when creating ops and variables.
        num_units: The depth of the query mechanism.
        memory: The memory to query, a rank-3 tensor shaped NTC.
        memory_projection: Either `tied` or `independent`. Determines the
            projection mode used by the attention MLP.
        memory_sequence_length: (optional) Sequence lengths for the batch
            entries in memory.
        score_scale: Python boolean. Whether to use softmax temperature.
        probability_fn: (optional) A `callable`. Converts the score to
            probabilities. The default is @{tf.nn.softmax}. Other options
            include @{tf.contrib.seq2seq.hardmax} and
            @{tf.contrib.sparsemax.sparsemax}. Its signature should be:
            `probabilities = probability_fn(score)`.
        dtype: The data type for the query and memory layers of the
            attention mechanism.
    """
    print('INFO: Using {}.'.format(self.__class__.__name__))
    if probability_fn is None:
        probability_fn = tf.nn.softmax
    if dtype is None:
        dtype = tf.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)

    assert memory_projection in ['independent', 'tied']
    assert len(_shape(memory)) == 3, \
        'The CNN feature maps must be a rank-3 tensor of NTC.'

    proj_kwargs = dict(
        units=num_units,
        use_bias=True,
        activation=None,
        dtype=dtype)
    with tf.variable_scope(name):
        super(BahdanauAttentionV1, self).__init__(
            query_layer=Dense(name='query_layer', **proj_kwargs),
            memory_layer=Dense(name='memory_layer', **proj_kwargs),
            memory=memory,
            probability_fn=wrapped_probability_fn,
            memory_sequence_length=memory_sequence_length,
            score_mask_value=None,
            name=name)
        self._num_units = num_units
        self._memory_projection = memory_projection
        self._score_scale = score_scale
        self._name = name

        if self._memory_projection == 'tied':
            self._values = tf.identity(self._keys)
        elif self._memory_projection == 'independent':
            # Project memory
            self._values = Dense(
                name='value_layer', **proj_kwargs)(self._values)
        else:
            raise ValueError('Undefined.')
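# A minimal construction sketch (illustrative only; `cnn_features` is a
# hypothetical rank-3 [batch_size, time, channels] tensor, and the sizes
# are assumptions):
#
#     attn_mech = BahdanauAttentionV1(
#         name='bahdanau_att',
#         num_units=512,
#         memory=cnn_features,
#         memory_projection='independent')
#
# With `memory_projection='independent'`, a separate `value_layer` Dense
# projection produces the values; with `'tied'`, the values reuse the
# projected keys.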