def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
        query: RNN hidden state. Tensor of shape `[batch_size, num_units]`.
        state: IGNORED. Previous alignment values.

    Returns:
        alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).
    """
    del state
    with tf.variable_scope(None, 'MultiHeadDot', [query]):
        # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
        proj_query = tf.expand_dims(self.query_layer(query), 1)  # (batch_size, 1, num_units)
        score = tf.multiply(self._keys, proj_query)
        # (batch_size, num_heads, mem_size, num_units / num_heads)
        score = split_heads(score, self._num_heads)
        # (batch_size, num_heads, mem_size)
        score = tf.reduce_sum(score, axis=3)
        # Scale by the square root of the per-head depth.
        score /= tf.sqrt(self._num_units / self._num_heads)
    alignments = self._probability_fn(score, None)
    next_state = alignments
    _dprint('{}: Alignments shape: {}'.format(self.__class__.__name__,
                                              _shape(alignments)))
    return alignments, next_state
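# Shape walk-through for the dot-product scoring above (illustrative values
# only; num_units=512, num_heads=8 and mem_size=196 are assumptions, not
# defaults of this module):
#
#     query                : (batch, 512)
#     proj_query           : (batch, 1, 512)
#     self._keys           : (batch, 196, 512)
#     keys * proj_query    : (batch, 196, 512)    # broadcast multiply
#     split_heads(...)     : (batch, 8, 196, 64)
#     reduce_sum(axis=3)   : (batch, 8, 196)      # one score per head per location
#     score / sqrt(64)                            # scaled by per-head depth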
def _maybe_mask_score_multi(self, score, memory_sequence_length, score_mask_value):
    if memory_sequence_length is None:
        return score
    message = 'All values in memory_sequence_length must be greater than zero.'
    with tf.control_dependencies(
            [tf.assert_positive(memory_sequence_length, message=message)]):
        score_mask = tf.sequence_mask(memory_sequence_length,
                                      maxlen=tf.shape(score)[2])
        score_mask_values = score_mask_value * tf.ones_like(score)
        masked_score = tf.where(score_mask, score, score_mask_values)
        _dprint('{}: score shape: {}'.format(self.__class__.__name__,
                                             _shape(score)))
        _dprint('{}: masked_score shape: {}'.format(
            self.__class__.__name__, _shape(masked_score)))
        return masked_score
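# Masking sketch (assumes `memory_sequence_length` is rank-2, shaped
# [batch_size, num_heads], matching the commented-out assertion in __init__):
#
#     score      : (batch, num_heads, mem_size)
#     score_mask : tf.sequence_mask(memory_sequence_length, maxlen=mem_size)
#                  -> (batch, num_heads, mem_size), True inside each sequence
#
# tf.where keeps the original score where the mask is True and writes
# `score_mask_value` (typically -inf) elsewhere, so softmax ignores padding.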
def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step
    (e.g. monotonic attention).

    The default behavior is to return a tensor of all zeros.

    Args:
        batch_size: `int32` scalar, the batch_size.
        dtype: The `dtype`.

    Returns:
        A `dtype` tensor shaped `[batch_size, alignments_size]`
        (`alignments_size` is the values' `max_time`).
    """
    del batch_size
    # `values_split` is (batch_size, num_heads, mem_size, depth); the
    # alignments are flattened across (num_heads, mem_size).
    s = _shape(self.values_split)[:-1]
    init = tf.zeros(shape=[s[0], s[1] * s[2]], dtype=dtype)
    _dprint('{}: Initial alignments shape: {}'.format(
        self.__class__.__name__, _shape(init)))
    return init
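# Example (illustrative only): with num_heads=8 and mem_size=196, the split
# values have shape (batch, 8, 196, depth), so the initial alignments are a
# zero tensor of shape (batch, 8 * 196) = (batch, 1568).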
def state_size(self):
    state = super(MultiHeadAttentionWrapperV3, self).state_size
    _attn_mech = self._attention_mechanisms[0]
    s = _shape(_attn_mech._values_split)[1:3]
    state = state._replace(
        alignments=s[0] * s[1],
        alignment_history=s[0] * s[1] if self._alignment_history else (),
        attention_state=s[0] * s[1])
    if _attn_mech._fm_projection is None and self._context_layer is False:
        state = state.clone(attention=_attn_mech._feature_map_shape[-1])
    else:
        state = state.clone(attention=_attn_mech._num_units)
    _dprint('{}: state_size: {}'.format(self.__class__.__name__, state))
    return state
def _layer_norm_tanh(tensor):
    # `begin_norm_axis` needs a newer TF (>= 1.9); on older versions the call
    # raises TypeError and we fall back to reshaping to rank 2 first.
    # if version.parse(tf.__version__) >= version.parse('1.9'):
    try:
        tensor = layer_norm_activate('LN_tanh',
                                     tensor,
                                     tf.nn.tanh,
                                     begin_norm_axis=-1)
    except TypeError:
        tensor_s = _shape(tensor)
        tensor = layer_norm_activate('LN_tanh',
                                     tf.reshape(tensor, [-1, tensor_s[-1]]),
                                     tf.nn.tanh)
        tensor = tf.reshape(tensor, tensor_s)
    return tensor
def split_heads(x, num_heads):
    """Split channels (dimension 2) into multiple heads (becomes dimension 1).

    Args:
        x: a Tensor with shape [batch, length, channels]
        num_heads: an integer

    Returns:
        a Tensor with shape [batch, num_heads, length, channels / num_heads]
    """
    old_shape = _shape(x)
    last = old_shape[-1]
    new_shape = old_shape[:-1] + [num_heads] + [last // num_heads if last else -1]
    # new_shape = tf.concat([old_shape[:-1], [num_heads, last // num_heads]], 0)
    return tf.transpose(tf.reshape(x, new_shape, 'split_head'), [0, 2, 1, 3])
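# Usage sketch (illustrative; built inside a graph/session as usual for TF 1.x):
#
#     x = tf.zeros([2, 5, 8])
#     y = split_heads(x, num_heads=4)   # -> shape (2, 4, 5, 2)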
def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
        query: RNN hidden state. Tensor of shape `[batch_size, num_units]`.
        state: IGNORED. Previous alignment values.

    Returns:
        alignments: Tensor of dtype matching `self.values` and shape
            `[batch_size, alignments_size]`
            (`alignments_size` is memory's `max_time`).
    """
    del state
    with tf.variable_scope(None, 'multi_add_attention', [query]):
        # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
        proj_query = tf.expand_dims(self.query_layer(query), 1)
        v = tf.get_variable('attention_v', [self._num_units],
                            dtype=proj_query.dtype)
        if len(self._mask_params) > 0:
            v, _ = masked_layer.generate_masks(kernel=v,
                                               bias=None,
                                               dtype=proj_query.dtype,
                                               **self._mask_params)
        score = self._keys + proj_query
        score = _layer_norm_tanh(score)
        score = tf.multiply(score, v)
        # (batch_size, num_heads, mem_size, num_units / num_heads)
        score = split_heads(score, self._num_heads)
        # (batch_size, num_heads, mem_size)
        score = tf.reduce_sum(score, axis=3)
        if self._scale:
            softmax_temperature = tf.get_variable(
                'softmax_temperature',
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(5.0),
                collections=[
                    tf.GraphKeys.GLOBAL_VARIABLES, 'softmax_temperatures'
                ])
            score = tf.truediv(score, softmax_temperature)
    alignments = self._probability_fn(score, None)
    next_state = alignments
    _dprint('{}: Alignments shape: {}'.format(self.__class__.__name__,
                                              _shape(alignments)))
    return alignments, next_state
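# The additive (Bahdanau-style) score computed above is effectively
#     score = sum_over_depth( v * tanh(LN(keys + proj_query)) )
# split across `num_heads`, and optionally divided by a learned softmax
# temperature (initialised to 5.0) when `scale=True`.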
def combine_heads(x):
    """Inverse of split_heads.

    Args:
        x: a Tensor with shape [batch, num_heads, length, channels / num_heads]

    Returns:
        a Tensor with shape [batch, length, channels]
    """
    x = tf.transpose(x, [0, 2, 1, 3])
    old_shape = _shape(x)
    a, b = old_shape[-2:]
    new_shape = old_shape[:-2] + [a * b if a and b else -1]
    # l = old_shape[2]
    # c = old_shape[3]
    # new_shape = tf.concat([old_shape[:-2] + [l * c]], 0)
    return tf.reshape(x, new_shape, 'combine_head')
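# Round-trip sketch (illustrative): combine_heads undoes split_heads.
#
#     x = tf.zeros([2, 5, 8])
#     y = split_heads(x, num_heads=4)   # (2, 4, 5, 2)
#     z = combine_heads(y)              # (2, 5, 8) again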
def call(self, inputs, prev_state):
    """
    Perform a step of attention-wrapped RNN.

    This method assumes `inputs` is the word embedding vector.

    This method overrides the original `call()` method.
    """
    _attn_mech = self._attention_mechanisms[0]
    attn_size = _attn_mech._num_units
    batch_size = _attn_mech.batch_size
    dtype = inputs.dtype

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    # `_cell_input_fn` defaults to
    # `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`
    _dprint('{}: prev_state received by call(): {}'.format(
        self.__class__.__name__, prev_state))
    cell_inputs = self._cell_input_fn(inputs, prev_state.attention)
    prev_cell_state = prev_state.cell_state
    cell_output, curr_cell_state = self._cell(cell_inputs, prev_cell_state)

    cell_batch_size = (cell_output.shape[0].value or tf.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory (encoder output) "
        "and the query (decoder output). Are you using the "
        "BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with tf.control_dependencies([
            tf.assert_equal(cell_batch_size,
                            _attn_mech.batch_size,
                            message=error_message)
    ]):
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    dtype = cell_output.dtype
    assert len(self._attention_mechanisms) == 1
    _attn_mech = self._attention_mechanisms[0]
    alignments, attention_state = _attn_mech(cell_output, state=None)
    if self._alignments_keep_prob < 1.:
        alignments = tf.contrib.layers.dropout(
            inputs=alignments,
            keep_prob=self._alignments_keep_prob,
            noise_shape=None,
            is_training=True)

    if len(_shape(alignments)) == 3:
        # Multi-head attention.
        # Expand from [batch_size, num_heads, memory_time] to
        # [batch_size, num_heads, 1, memory_time].
        expanded_alignments = tf.expand_dims(alignments, 2)
        # attention_mechanism.values shape is
        #     [batch_size, num_heads, memory_time, num_units / num_heads]
        # The batched matmul is over memory_time, so the output shape is
        #     [batch_size, num_heads, 1, num_units / num_heads].
        # We then combine the heads into
        #     [batch_size, 1, attention_mechanism.num_units].
        attention_mechanism_values = _attn_mech.values_split
        context = tf.matmul(expanded_alignments, attention_mechanism_values)
        attention = tf.squeeze(combine_heads(context), [1])
    else:
        # Expand from [batch_size, memory_time] to [batch_size, 1, memory_time].
        expanded_alignments = tf.expand_dims(alignments, 1)
        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # alignments shape is
        #     [batch_size, 1, memory_time]
        # attention_mechanism.values shape is
        #     [batch_size, memory_time, attention_mechanism.num_units]
        # The batched matmul is over memory_time, so the output shape is
        #     [batch_size, 1, attention_mechanism.num_units].
        # We then squeeze out the singleton dim.
        attention_mechanism_values = _attn_mech.values
        context = tf.matmul(expanded_alignments, attention_mechanism_values)
        attention = tf.squeeze(context, [1])

    # Context projection.
    if self._context_layer:
        # noinspection PyCallingNonCallable
        attention = self._dense_layer(name='a_layer',
                                      units=_attn_mech._num_units,
                                      use_bias=False,
                                      activation=None,
                                      dtype=dtype,
                                      **self._mask_params)(attention)

    if self._alignment_history:
        alignments = tf.reshape(alignments, [cell_batch_size, -1])
        alignment_history = prev_state.alignment_history.write(
            prev_state.time, alignments)
    else:
        alignment_history = ()

    curr_state = attention_wrapper.AttentionWrapperState(
        time=prev_state.time + 1,
        cell_state=curr_cell_state,
        attention=attention,
        attention_state=alignments,
        alignments=alignments,
        alignment_history=alignment_history)
    return cell_output, curr_state
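# Context computation sketch for the multi-head branch (illustrative shapes,
# assuming num_heads=8, memory_time=196, num_units=512):
#
#     alignments             : (batch, 8, 196)
#     expanded_alignments    : (batch, 8, 1, 196)
#     values_split           : (batch, 8, 196, 64)
#     matmul                 : (batch, 8, 1, 64)
#     combine_heads + squeeze -> attention: (batch, 512)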
def __init__(self,
             num_units,
             feature_map,
             fm_projection,
             num_heads=None,
             scale=True,
             memory_sequence_length=None,
             probability_fn=tf.nn.softmax,
             mask_type=None,
             mask_init_value=0,
             mask_bern_sample=False,
             name='MultiHeadAttV3'):
    """Construct the AttentionMechanism mechanism.

    Args:
        num_units: The depth of the attention mechanism.
        feature_map: The feature map / memory to query. This tensor should
            be shaped `[batch_size, height * width, channels]`.
        fm_projection: Feature map projection mode. One of `None`,
            `'independent'` or `'tied'`.
        num_heads: Int, number of attention heads. (optional)
        scale: Python boolean. Whether to scale the energy term.
        memory_sequence_length: Tensor indicating sequence length.
        probability_fn: (optional) A `callable`. Converts the score to
            probabilities. The default is `tf.nn.softmax`.
        mask_type: (optional) Mask type passed to `masked_layer.MaskedDense`.
            If `None`, regular `Dense` layers are used.
        mask_init_value: (optional) Initial value for the layer masks.
        mask_bern_sample: (optional) Mask Bernoulli-sampling flag passed to
            `masked_layer.MaskedDense`.
        name: Name to use when creating ops.
    """
    logger.debug('Using MultiHeadAttV3.')
    assert fm_projection in [None, 'independent', 'tied']

    # if memory_sequence_length is not None:
    #     assert len(_shape(memory_sequence_length)) == 2, \
    #         '`memory_sequence_length` must be a rank-2 tensor, ' \
    #         'shaped [batch_size, num_heads].'

    if mask_type is None:
        self._dense_layer = Dense
        self._mask_params = {}
    else:
        self._dense_layer = masked_layer.MaskedDense
        self._mask_params = dict(mask_type=mask_type,
                                 mask_init_value=mask_init_value,
                                 mask_bern_sample=mask_bern_sample)

    super(MultiHeadAttV3, self).__init__(
        query_layer=self._dense_layer(
            units=num_units, name='query_layer', use_bias=False,
            **self._mask_params),  # query is projected hidden state
        memory_layer=self._dense_layer(
            units=num_units, name='memory_layer', use_bias=False,
            **self._mask_params),  # self._keys is projected feature_map
        memory=feature_map,  # self._values is feature_map
        probability_fn=lambda score, _: probability_fn(score),
        memory_sequence_length=None,
        score_mask_value=float('-inf'),
        name=name)
    self._probability_fn = lambda score, _: (probability_fn(
        self._maybe_mask_score_multi(score, memory_sequence_length,
                                     float('-inf'))))
    self._fm_projection = fm_projection
    self._num_units = num_units
    self._num_heads = num_heads
    self._scale = scale
    self._feature_map_shape = _shape(feature_map)
    self._name = name

    if fm_projection == 'tied':
        assert num_units % num_heads == 0, \
            'For `tied` projection, attention size/depth must be ' \
            'divisible by the number of attention heads.'
        self._values_split = split_heads(self._keys, self._num_heads)
    elif fm_projection == 'independent':
        assert num_units % num_heads == 0, \
            'For `independent` projection, attention size/depth must be ' \
            'divisible by the number of attention heads.'
        # Project and split memory.
        v_layer = self._dense_layer(units=num_units,
                                    name='value_layer',
                                    use_bias=False,
                                    **self._mask_params)
        # (batch_size, num_heads, mem_size, num_units / num_heads)
        self._values_split = split_heads(v_layer(self._values),
                                         self._num_heads)
    else:
        assert _shape(self._values)[-1] % num_heads == 0, \
            'For `None` projection, feature map channel dim size must ' \
            'be divisible by the number of attention heads.'
        self._values_split = split_heads(self._values, self._num_heads)

    _dprint('{}: FM projection type: {}'.format(self.__class__.__name__,
                                                fm_projection))
    _dprint('{}: Splitted values shape: {}'.format(
        self.__class__.__name__, _shape(self._values_split)))
    _dprint('{}: Values shape: {}'.format(self.__class__.__name__,
                                          _shape(self._values)))
    _dprint('{}: Keys shape: {}'.format(self.__class__.__name__,
                                        _shape(self._keys)))
    _dprint('{}: Feature map shape: {}'.format(self.__class__.__name__,
                                               _shape(feature_map)))
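# Construction sketch (assumed usage; `cnn_features` is a hypothetical
# [batch_size, height * width, channels] tensor, e.g. flattened CNN features):
#
#     attn = MultiHeadAttV3(num_units=512,
#                           feature_map=cnn_features,
#                           fm_projection='tied',
#                           num_heads=8,
#                           scale=True)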