@property
def state_size(self):
    state = super(MultiHeadAttentionWrapperV3, self).state_size
    _attn_mech = self._attention_mechanisms[0]
    # Per-head alignments are flattened to [num_heads * memory_time] so that
    # every alignment-related state field stays rank-2, as the base
    # `AttentionWrapperState` contract expects.
    s = _shape(_attn_mech._values_split)[1:3]
    state = state._replace(alignments=s[0] * s[1],
                           alignment_history=s[0] * s[1],
                           attention_state=s[0] * s[1])
    if _attn_mech._fm_projection is None and not self._context_layer:
        state = state.clone(attention=_attn_mech._feature_map_shape[-1])
    else:
        state = state.clone(attention=_attn_mech._num_units)
    return state
def _layer_norm_tanh(tensor):
    # `begin_norm_axis` is only supported by the newer `layer_norm_activate`
    # signature (roughly TF >= 1.9); fall back to the rank-2 reshape path below
    # when it is unavailable.
    try:
        tensor = layer_norm_activate(
            'LN_tanh',
            tensor,
            tf.nn.tanh,
            begin_norm_axis=-1)
    except TypeError:
        # Older helper has no `begin_norm_axis` argument: flatten to rank 2,
        # normalise over the last axis, then restore the original shape.
        tensor_s = _shape(tensor)
        tensor = layer_norm_activate(
            'LN_tanh',
            tf.reshape(tensor, [-1, tensor_s[-1]]),
            tf.nn.tanh)
        tensor = tf.reshape(tensor, tensor_s)
    return tensor
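# `layer_norm_activate` is a helper defined elsewhere in this project and is
# not shown in this excerpt. A minimal sketch of an equivalent helper (an
# assumption, not the project's implementation), using TF 1.x
# `tf.contrib.layers.layer_norm` and the `tf` import used throughout this
# module:
def _layer_norm_activate_sketch(scope, tensor, activation_fn, begin_norm_axis=1):
    # Layer-normalise starting at `begin_norm_axis`, then apply the activation.
    with tf.variable_scope(scope):
        normed = tf.contrib.layers.layer_norm(
            inputs=tensor,
            begin_norm_axis=begin_norm_axis,
            begin_params_axis=-1)
        return activation_fn(normed)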
def split_heads(x, num_heads):
    """Split channels (dimension 3) into multiple heads (becomes dimension 1).

    Args:
        x: a Tensor with shape [batch, length, channels]
        num_heads: an integer

    Returns:
        a Tensor with shape [batch, num_heads, length, channels / num_heads]
    """
    old_shape = _shape(x)
    last = old_shape[-1]
    new_shape = old_shape[:-1] + [num_heads] \
                + [last // num_heads if last else -1]
    #new_shape = tf.concat([old_shape[:-1], [num_heads, last // num_heads]], 0)
    return tf.transpose(tf.reshape(x, new_shape, 'split_head'), [0, 2, 1, 3])
def combine_heads(x):
    """Inverse of split_heads.

    Args:
        x: a Tensor with shape [batch, num_heads, length, channels / num_heads]

    Returns:
        a Tensor with shape [batch, length, channels]
    """
    x = tf.transpose(x, [0, 2, 1, 3])
    old_shape = _shape(x)
    a, b = old_shape[-2:]
    new_shape = old_shape[:-2] + [a * b if a and b else -1]
    #l = old_shape[2]
    #c = old_shape[3]
    #new_shape = tf.concat([old_shape[:-2] + [l * c]], 0)
    return tf.reshape(x, new_shape, 'combine_head')
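# Shape round-trip through `split_heads` / `combine_heads` (a usage sketch; the
# concrete sizes are illustrative only):
#     x: [batch=2, length=5, channels=8], num_heads=4
#     split_heads(x, 4)   -> [2, 4, 5, 2]
#     combine_heads(...)  -> [2, 5, 8]
def _split_combine_heads_example():
    x = tf.zeros([2, 5, 8])
    heads = split_heads(x, num_heads=4)   # [2, 4, 5, 2]
    merged = combine_heads(heads)         # [2, 5, 8]
    return heads, merged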
    def initial_alignments(self, batch_size, dtype):
        """Creates the initial alignment values for the `AttentionWrapper` class.
        
        This is important for AttentionMechanisms that use the previous alignment
        to calculate the alignment at the next time step (e.g. monotonic attention).
        
        The default behavior is to return a tensor of all zeros.
        
        Args:
            batch_size: `int32` scalar, the batch_size.
            dtype: The `dtype`.

        Returns:
            A `dtype` tensor shaped `[batch_size, num_heads * max_time]`: the
            per-head alignments flattened into a single axis, so that the
            returned tensor stays rank-2 as the `AttentionWrapper` state
            contract expects.
        """
        s = _shape(self.values_split)[:-1]  # [batch_size, num_heads, max_time]
        init = tf.zeros(shape=[s[0], s[1] * s[2]])
        return init
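    # Shape illustration (a hedged example; the concrete numbers are
    # illustrative only): with `values_split` shaped
    # [batch_size=2, num_heads=4, max_time=49, depth], the initial alignments
    # are tf.zeros([2, 4 * 49]), i.e. shape [2, 196]: the per-head alignments
    # flattened so the state stays rank-2 for the `AttentionWrapper` machinery.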
 def call(self, inputs, prev_state):
     """
     Perform a step of attention-wrapped RNN.
     
     This method assumes `inputs` is the word embedding vector.
     
     This method overrides the original `call()` method.
     """
     _attn_mech = self._attention_mechanisms[0]
     # Step 1: Calculate the true inputs to the cell based on the
     # previous attention value.
     # `_cell_input_fn` defaults to
     # `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`
     cell_inputs = self._cell_input_fn(inputs, prev_state.attention)
     prev_cell_state = prev_state.cell_state
     cell_output, curr_cell_state = self._cell(cell_inputs, prev_cell_state)
     
     cell_batch_size = (
             cell_output.shape[0].value or tf.shape(cell_output)[0])
     error_message = (
             "When applying AttentionWrapper %s: " % self.name +
             "Non-matching batch sizes between the memory (encoder output) "
             "and the query (decoder output). Are you using the "
             "BeamSearchDecoder? You may need to tile your memory input via "
             "the tf.contrib.seq2seq.tile_batch function with argument "
             "multiple=beam_width.")
     with tf.control_dependencies(
                     [tf.assert_equal(cell_batch_size,
                                      _attn_mech.batch_size,
                                      message=error_message)]):
         cell_output = tf.identity(cell_output, name="checked_cell_output")
     
     alignments, attention_state = _attn_mech(
                             #cell_output, state=prev_state.attention_state)
                             cell_output, state=None)
     
     if self._alignments_keep_prob < 1.:
         alignments = tf.contrib.layers.dropout(
                                     inputs=alignments,
                                     keep_prob=self._alignments_keep_prob,
                                     noise_shape=None,
                                     is_training=True)
     
     if len(_shape(alignments)) == 3:
         # Multi-head attention
         expanded_alignments = tf.expand_dims(alignments, 2)
         # alignments shape is
         #     [batch_size, num_heads, 1, memory_time]
         # attention_mechanism.values shape is
         #     [batch_size, num_heads, memory_time, num_units / num_heads]
         # the batched matmul is over memory_time, so the output shape is
         #     [batch_size, num_heads, 1, num_units / num_heads].
         # we then combine the heads
         #     [batch_size, 1, attention_mechanism.num_units]
         attention_mechanism_values = _attn_mech.values_split
         context = tf.matmul(expanded_alignments, attention_mechanism_values)
         attention = tf.squeeze(combine_heads(context), [1])
     else:
         # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
         expanded_alignments = tf.expand_dims(alignments, 1)
         # Context is the inner product of alignments and values along the
         # memory time dimension.
         # alignments shape is
         #     [batch_size, 1, memory_time]
         # attention_mechanism.values shape is
         #     [batch_size, memory_time, attention_mechanism.num_units]
         # the batched matmul is over memory_time, so the output shape is
         #     [batch_size, 1, attention_mechanism.num_units].
         # we then squeeze out the singleton dim.
         attention_mechanism_values = _attn_mech.values
         context = tf.matmul(expanded_alignments, attention_mechanism_values)
         attention = tf.squeeze(context, [1])
     
     # Context projection
     if self._context_layer:
         attention = Dense(name='a_layer',
                           units=_attn_mech._num_units,
                           use_bias=False,
                           activation=None,
                           dtype=_attn_mech.dtype)(attention)
     
     if self._alignment_history:
         alignments = tf.reshape(alignments, [cell_batch_size, -1])
         alignment_history = prev_state.alignment_history.write(
                                                 prev_state.time, alignments)
     else:
         alignment_history = ()
     
     curr_state = attention_wrapper.AttentionWrapperState(
                         time=prev_state.time + 1,
                         cell_state=curr_cell_state,
                         attention=attention,
                         attention_state=alignments,
                         alignments=alignments,
                         alignment_history=alignment_history)
     return cell_output, curr_state
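# Usage sketch: wiring the mechanism into the wrapper for a TF 1.x decoder.
# The `MultiHeadAttentionWrapperV3` constructor is not shown in this excerpt,
# so the keyword arguments below (other than the attention mechanism itself)
# are assumed to mirror `tf.contrib.seq2seq.AttentionWrapper`; the helper name
# and sizes are illustrative only.
def _build_attention_cell_sketch(feature_map, num_units=512, num_heads=8):
    # feature_map: CNN features flattened to [batch, height * width, channels].
    attn = MultiHeadAttV3(num_units=num_units,
                          feature_map=feature_map,
                          fm_projection='independent',
                          num_heads=num_heads)
    cell = tf.contrib.rnn.LSTMCell(num_units)
    return MultiHeadAttentionWrapperV3(cell=cell,
                                       attention_mechanism=attn,
                                       alignment_history=True)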
 def __init__(self,
              num_units,
              feature_map,
              fm_projection,
              num_heads=None,
              scale=True,
              memory_sequence_length=None,
              probability_fn=tf.nn.softmax,
              name='MultiHeadAttV3'):
     """
     Construct the AttentionMechanism mechanism.
     Args:
         num_units: The depth of the attention mechanism.
         feature_map: The feature map / memory to query. This tensor
             should be shaped `[batch_size, height * width, channels]`.
         attention_type: String from 'single', 'multi_add', 'multi_dot'.
         reuse_keys_as_values: Boolean, whether to use keys as values.
         fm_projection: Feature map projection mode.
         num_heads: Int, number of attention heads. (optional)
         scale: Python boolean.  Whether to scale the energy term.
         probability_fn: (optional) A `callable`.  Converts the score
             to probabilities.  The default is `tf.nn.softmax`.
         name: Name to use when creating ops.
     """
     print('INFO: Using MultiHeadAttV3.')
     assert fm_projection in [None, 'independent', 'tied']
     
     if memory_sequence_length is not None:
         assert len(_shape(memory_sequence_length)) == 2, \
             '`memory_sequence_length` must be a rank-2 tensor, ' \
             'shaped [batch_size, num_heads].'
     
     super(MultiHeadAttV3, self).__init__(
         query_layer=Dense(num_units, name='query_layer', use_bias=False),       # query is projected hidden state
         memory_layer=Dense(num_units, name='memory_layer', use_bias=False),     # self._keys is projected feature_map
         memory=feature_map,                                                     # self._values is feature_map
         probability_fn=lambda score, _: probability_fn(score),
          memory_sequence_length=None,  # masking is applied via `self._probability_fn` below
         score_mask_value=float('-inf'),
         name=name)
     
     self._probability_fn = lambda score, _: (
         probability_fn(
             self._maybe_mask_score_multi(
                 score, memory_sequence_length, float('-inf'))))
     self._fm_projection = fm_projection
     self._num_units = num_units
     self._num_heads = num_heads
     self._scale = scale
     self._feature_map_shape = _shape(feature_map)
     self._name = name
     
     if fm_projection == 'tied':
         assert num_units % num_heads == 0, \
             'For `tied` projection, attention size/depth must be ' \
             'divisible by the number of attention heads.'
         self._values_split = split_heads(self._keys, self._num_heads)
     elif fm_projection == 'independent':
          assert num_units % num_heads == 0, \
              'For `independent` projection, attention size/depth must be ' \
              'divisible by the number of attention heads.'
         # Project and split memory
         v_layer = Dense(num_units, name='value_layer', use_bias=False)
         # (batch_size, num_heads, mem_size, num_units / num_heads)
         self._values_split = split_heads(v_layer(self._values), self._num_heads)
     else:
          assert _shape(self._values)[-1] % num_heads == 0, \
              'When no feature map projection is used, the channel dim of ' \
              'the feature map must be divisible by the number of heads.'
         self._values_split = split_heads(self._values, self._num_heads)
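      # `_values_split` shapes by projection mode (a sketch; the concrete sizes
      # are illustrative, assuming feature_map is [batch=16, h*w=196, channels=1024],
      # num_units=512 and num_heads=8):
      #   'tied'        -> split projected keys:        [16, 8, 196, 512 // 8]
      #   'independent' -> split projected feature map: [16, 8, 196, 512 // 8]
      #   None          -> split raw feature map:       [16, 8, 196, 1024 // 8]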
 def __init__(self,
              name,
              num_units,
              memory,
              memory_projection='independent',
              memory_sequence_length=None,
              score_scale=True,
              probability_fn=None,
              dtype=None):
     """
     Construct the Attention mechanism.
     
     Args:
         name: Name to use when creating ops and variables.
         num_units: The depth of the query mechanism.
         memory: The memory to query, shaped NHWC.
         fmap_projection: Either `tied` or `independent`. Determines the 
             projection mode used by the attention MLP.
         score_scale: Python boolean.  Whether to use softmax temperature.
         probability_fn: (optional) A `callable`.  Converts the score to
             probabilities.  The default is @{tf.nn.softmax}. Other options include
             @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
             Its signature should be: `probabilities = probability_fn(score)`.
         dtype: The data type for the query and memory layers of the attention
             mechanism.
     """
     print('INFO: Using {}.'.format(self.__class__.__name__))
     if probability_fn is None:
         probability_fn = tf.nn.softmax
     if dtype is None:
         dtype = tf.float32
     wrapped_probability_fn = lambda score, _: probability_fn(score)
     assert memory_projection in ['independent', 'tied']
     
     assert len(_shape(memory)) == 3, \
         'The CNN feature maps must be a rank-3 tensor of NTC.'
     
     proj_kwargs = dict(
                     units=num_units,
                     use_bias=True,
                     activation=None,
                     dtype=dtype)
     with tf.variable_scope(name):
         super(BahdanauAttentionV1, self).__init__(
             query_layer=Dense(name='query_layer', **proj_kwargs),
             memory_layer=Dense(name='memory_layer', **proj_kwargs),
             memory=memory,
             probability_fn=wrapped_probability_fn,
             memory_sequence_length=memory_sequence_length,
             score_mask_value=None,
             name=name)
         self._num_units = num_units
         self._memory_projection = memory_projection
         self._score_scale = score_scale
         self._name = name
         
         if self._memory_projection == 'tied':
             self._values = tf.identity(self._keys)
         elif self._memory_projection == 'independent':
             # Project memory
             self._values = Dense(
                         name='value_layer',
                         **proj_kwargs)(self._values)
         else:
              raise ValueError(
                  'Unknown `memory_projection` mode: {}'.format(
                      self._memory_projection))
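# Construction sketch (hedged): the memory is expected to be CNN features
# flattened to rank 3 (NTC); the names and sizes below are illustrative only.
def _build_bahdanau_v1_sketch(cnn_features, num_units=512):
    # cnn_features: [batch, height, width, channels] from an image encoder.
    s = _shape(cnn_features)
    memory = tf.reshape(cnn_features, [s[0], s[1] * s[2], s[3]])
    return BahdanauAttentionV1(name='bahdanau_att_v1',
                               num_units=num_units,
                               memory=memory,
                               memory_projection='independent',
                               score_scale=True)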