def __call__(self, query, state):
     """
     Score the query based on the keys and values.
     Args:
         query: RNN hidden state. Tensor of shape `[batch_size, num_units]`.
         state: IGNORED. Previous alignment values, kept only for
             compatibility with the base `AttentionMechanism` interface.
     Returns:
         alignments: Tensor of dtype matching `self.values` and shape
             `[batch_size, num_heads, mem_size]`.
     """
     del state
     with tf.variable_scope(None, 'MultiHeadDot', [query]):
         # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
         proj_query = tf.expand_dims(self.query_layer(query),
                                     1)  # (batch_size, 1, num_units)
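         # Element-wise multiply then sum over the channel axis (below)
         # computes a per-head dot product between the keys and the query.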
         score = tf.multiply(self._keys, proj_query)
         score = split_heads(
             score, self._num_heads
         )  # (batch_size, num_heads, mem_size, num_units / num_heads)
         score = tf.reduce_sum(score,
                               axis=3)  # (batch_size, num_heads, mem_size)
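         # Scale by sqrt(d_head), as in scaled dot-product attention,
         # to keep the score magnitudes reasonable before the softmax.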
         score /= tf.sqrt(self._num_units / self._num_heads)
     alignments = self._probability_fn(score, None)
     next_state = alignments
     _dprint('{}: Alignments shape: {}'.format(self.__class__.__name__,
                                               _shape(alignments)))
     return alignments, next_state
 def _maybe_mask_score_multi(self, score, memory_sequence_length,
                             score_mask_value):
     if memory_sequence_length is None:
         return score
      message = 'All values in memory_sequence_length must be greater than zero.'
     with tf.control_dependencies(
         [tf.assert_positive(memory_sequence_length, message=message)]):
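         # Positions beyond each memory sequence's length are filled with
         # `score_mask_value` (-inf) so the softmax gives them ~zero weight.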
         score_mask = tf.sequence_mask(memory_sequence_length,
                                       maxlen=tf.shape(score)[2])
         score_mask_values = score_mask_value * tf.ones_like(score)
         masked_score = tf.where(score_mask, score, score_mask_values)
         _dprint('{}: score shape: {}'.format(self.__class__.__name__,
                                              _shape(score)))
         _dprint('{}: masked_score shape: {}'.format(
             self.__class__.__name__, _shape(masked_score)))
         return masked_score
    def initial_alignments(self, batch_size, dtype):
        """Creates the initial alignment values for the `AttentionWrapper` class.

        This is important for AttentionMechanisms that use the previous alignment
        to calculate the alignment at the next time step (e.g. monotonic attention).

        The default behavior is to return a tensor of all zeros.

        Args:
            batch_size: `int32` scalar, the batch_size.
            dtype: The `dtype`.

        Returns:
            A `dtype` tensor shaped `[batch_size, alignments_size]`, where
            `alignments_size` here is `num_heads * mem_size` (the flattened
            multi-head alignment size).
        """
        del batch_size
        s = _shape(self.values_split)[:-1]
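        # `values_split` is (batch_size, num_heads, mem_size, depth); the
        # head and memory dimensions are flattened into one alignment axis.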
        init = tf.zeros(shape=[s[0], s[1] * s[2]], dtype=dtype)
        _dprint('{}: Initial alignments shape: {}'.format(
            self.__class__.__name__, _shape(init)))
        return init
 @property
 def state_size(self):
     state = super(MultiHeadAttentionWrapperV3, self).state_size
     _attn_mech = self._attention_mechanisms[0]
     s = _shape(_attn_mech._values_split)[1:3]
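      # Alignments are reported flattened to num_heads * mem_size, matching
      # `initial_alignments` and the reshape applied in `call()`.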
     state = state._replace(alignments=s[0] * s[1],
                            alignment_history=s[0] *
                            s[1] if self._alignment_history else (),
                            attention_state=s[0] * s[1])
     if _attn_mech._fm_projection is None and self._context_layer is False:
         state = state.clone(attention=_attn_mech._feature_map_shape[-1])
     else:
         state = state.clone(attention=_attn_mech._num_units)
     _dprint('{}: state_size: {}'.format(self.__class__.__name__, state))
     return state
def _layer_norm_tanh(tensor):
    # if version.parse(tf.__version__) >= version.parse('1.9'):
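    # `begin_norm_axis` may not be supported by older versions of
    # `layer_norm_activate`; if a TypeError is raised, fall back to
    # reshaping the tensor to 2-D before normalising over the last axis.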
    try:
        tensor = layer_norm_activate('LN_tanh',
                                     tensor,
                                     tf.nn.tanh,
                                     begin_norm_axis=-1)
    except TypeError:
        tensor_s = _shape(tensor)
        tensor = layer_norm_activate('LN_tanh',
                                     tf.reshape(tensor, [-1, tensor_s[-1]]),
                                     tf.nn.tanh)
        tensor = tf.reshape(tensor, tensor_s)
    return tensor
def split_heads(x, num_heads):
    """Split channels (dimension 3) into multiple heads (becomes dimension 1).

    Args:
        x: a Tensor with shape [batch, length, channels]
        num_heads: an integer

    Returns:
        a Tensor with shape [batch, num_heads, length, channels / num_heads]
    """
    old_shape = _shape(x)
    last = old_shape[-1]
    new_shape = old_shape[:-1] + [num_heads, last // num_heads if last else -1]
    # new_shape = tf.concat([old_shape[:-1], [num_heads, last // num_heads]], 0)
    return tf.transpose(tf.reshape(x, new_shape, 'split_head'), [0, 2, 1, 3])
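# Shape example: with num_heads=8, a [batch, 196, 512] tensor becomes
# [batch, 8, 196, 64] after `split_heads`.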
    def __call__(self, query, state):
        """
        Score the query based on the keys and values.
        Args:
            query: RNN hidden state. Tensor of shape `[batch_size, num_units]`.
            state: IGNORED. Previous alignment values, kept only for
                compatibility with the base `AttentionMechanism` interface.
        Returns:
            alignments: Tensor of dtype matching `self.values` and shape
                `[batch_size, num_heads, mem_size]`.
        """
        del state
        with tf.variable_scope(None, 'multi_add_attention', [query]):
            # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
            proj_query = tf.expand_dims(self.query_layer(query), 1)
            v = tf.get_variable('attention_v', [self._num_units],
                                dtype=proj_query.dtype)
            if len(self._mask_params) > 0:
                v, _ = masked_layer.generate_masks(kernel=v,
                                                   bias=None,
                                                   dtype=proj_query.dtype,
                                                   **self._mask_params)
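            # Bahdanau-style additive scoring: add the projected query to the
            # projected keys, apply layer-normalised tanh, then weight by `v`.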
            score = self._keys + proj_query
            score = _layer_norm_tanh(score)
            score = tf.multiply(score, v)
            score = split_heads(
                score, self._num_heads
            )  # (batch_size, num_heads, mem_size, num_units / num_heads)
            score = tf.reduce_sum(score,
                                  axis=3)  # (batch_size, num_heads, mem_size)

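        # Optionally divide the scores by a learned softmax temperature
        # (initialised to 5.0) to control the sharpness of the attention.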
        if self._scale:
            softmax_temperature = tf.get_variable(
                'softmax_temperature',
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(5.0),
                collections=[
                    tf.GraphKeys.GLOBAL_VARIABLES, 'softmax_temperatures'
                ])
            score = tf.truediv(score, softmax_temperature)
        alignments = self._probability_fn(score, None)
        next_state = alignments
        _dprint('{}: Alignments shape: {}'.format(self.__class__.__name__,
                                                  _shape(alignments)))
        return alignments, next_state
def combine_heads(x):
    """Inverse of split_heads.

    Args:
        x: a Tensor with shape [batch, num_heads, length, channels / num_heads]

    Returns:
        a Tensor with shape [batch, length, channels]
    """
    x = tf.transpose(x, [0, 2, 1, 3])
    old_shape = _shape(x)
    a, b = old_shape[-2:]
    new_shape = old_shape[:-2] + [a * b if a and b else -1]
    # l = old_shape[2]
    # c = old_shape[3]
    # new_shape = tf.concat([old_shape[:-2] + [l * c]], 0)
    return tf.reshape(x, new_shape, 'combine_head')
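# Note: combine_heads(split_heads(x, n)) recovers the original
# [batch, length, channels] shape when channels is divisible by n.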
    def call(self, inputs, prev_state):
        """
        Perform a step of attention-wrapped RNN.

        This method assumes `inputs` is the word embedding vector.

        This method overrides the original `call()` method.
        """
        _attn_mech = self._attention_mechanisms[0]
        attn_size = _attn_mech._num_units
        batch_size = _attn_mech.batch_size
        dtype = inputs.dtype

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        # `_cell_input_fn` defaults to
        # `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`
        _dprint('{}: prev_state received by call(): {}'.format(
            self.__class__.__name__, prev_state))
        cell_inputs = self._cell_input_fn(inputs, prev_state.attention)
        prev_cell_state = prev_state.cell_state
        cell_output, curr_cell_state = self._cell(cell_inputs, prev_cell_state)

        cell_batch_size = (cell_output.shape[0].value
                           or tf.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory (encoder output) "
            "and the query (decoder output). Are you using the "
            "BeamSearchDecoder? You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies([
                tf.assert_equal(cell_batch_size,
                                _attn_mech.batch_size,
                                message=error_message)
        ]):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        dtype = cell_output.dtype
        assert len(self._attention_mechanisms) == 1
        _attn_mech = self._attention_mechanisms[0]
        alignments, attention_state = _attn_mech(cell_output, state=None)
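        # The alignment weights may be dropped out below; note that
        # is_training=True, so the dropout is always active when enabled.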

        if self._alignments_keep_prob < 1.:
            alignments = tf.contrib.layers.dropout(
                inputs=alignments,
                keep_prob=self._alignments_keep_prob,
                noise_shape=None,
                is_training=True)

        if len(_shape(alignments)) == 3:
            # Multi-head attention
            # Expand from [batch_size, num_heads, memory_time] to [batch_size, num_heads, 1, memory_time]
            expanded_alignments = tf.expand_dims(alignments, 2)
            # attention_mechanism.values shape is
            #     [batch_size, num_heads, memory_time, num_units / num_heads]
            # the batched matmul is over memory_time, so the output shape is
            #     [batch_size, num_heads, 1, num_units / num_heads].
            # we then combine the heads
            #     [batch_size, 1, attention_mechanism.num_units]
            attention_mechanism_values = _attn_mech.values_split
            context = tf.matmul(expanded_alignments,
                                attention_mechanism_values)
            attention = tf.squeeze(combine_heads(context), [1])
        else:
            # Expand from [batch_size, memory_time] to [batch_size, 1, memory_time]
            expanded_alignments = tf.expand_dims(alignments, 1)
            # Context is the inner product of alignments and values along the
            # memory time dimension.
            # alignments shape is
            #     [batch_size, 1, memory_time]
            # attention_mechanism.values shape is
            #     [batch_size, memory_time, attention_mechanism.num_units]
            # the batched matmul is over memory_time, so the output shape is
            #     [batch_size, 1, attention_mechanism.num_units].
            # we then squeeze out the singleton dim.
            attention_mechanism_values = _attn_mech.values
            context = tf.matmul(expanded_alignments,
                                attention_mechanism_values)
            attention = tf.squeeze(context, [1])

        # Context projection
        if self._context_layer:
            # noinspection PyCallingNonCallable
            attention = self._dense_layer(name='a_layer',
                                          units=_attn_mech._num_units,
                                          use_bias=False,
                                          activation=None,
                                          dtype=dtype,
                                          **self._mask_params)(attention)

        if self._alignment_history:
            alignments = tf.reshape(alignments, [cell_batch_size, -1])
            alignment_history = prev_state.alignment_history.write(
                prev_state.time, alignments)
        else:
            alignment_history = ()

        curr_state = attention_wrapper.AttentionWrapperState(
            time=prev_state.time + 1,
            cell_state=curr_cell_state,
            attention=attention,
            attention_state=alignments,
            alignments=alignments,
            alignment_history=alignment_history)
        return cell_output, curr_state
    def __init__(self,
                 num_units,
                 feature_map,
                 fm_projection,
                 num_heads=None,
                 scale=True,
                 memory_sequence_length=None,
                 probability_fn=tf.nn.softmax,
                 mask_type=None,
                 mask_init_value=0,
                 mask_bern_sample=False,
                 name='MultiHeadAttV3'):
        """
        Construct the multi-head attention mechanism.
        Args:
            num_units: The depth of the attention mechanism.
            feature_map: The feature map / memory to query. This tensor
                should be shaped `[batch_size, height * width, channels]`.
            fm_projection: Feature map projection mode, one of `None`,
                `'independent'` or `'tied'`.
            num_heads: Int, number of attention heads. (optional)
            scale: Python boolean.  Whether to scale the energy term.
            memory_sequence_length: Tensor indicating sequence length.
            probability_fn: (optional) A `callable`.  Converts the score
                to probabilities.  The default is `tf.nn.softmax`.
            mask_type: (optional) Weight-mask type; if not `None`,
                `masked_layer.MaskedDense` is used instead of `Dense` for
                the query, memory and value projection layers.
            mask_init_value: (optional) Initial mask value, forwarded to
                `masked_layer.MaskedDense`.
            mask_bern_sample: (optional) Python boolean, forwarded to
                `masked_layer.MaskedDense`.
            name: Name to use when creating ops.
        """
        logger.debug('Using MultiHeadAttV3.')
        assert fm_projection in [None, 'independent', 'tied']
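        # `fm_projection` selects how the memory values are formed: None
        # (split the raw feature map), 'tied' (reuse the projected keys as
        # values), or 'independent' (a separate value projection layer).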

        # if memory_sequence_length is not None:
        #     assert len(_shape(memory_sequence_length)) == 2, \
        #         '`memory_sequence_length` must be a rank-2 tensor, ' \
        #         'shaped [batch_size, num_heads].'

        if mask_type is None:
            self._dense_layer = Dense
            self._mask_params = {}
        else:
            self._dense_layer = masked_layer.MaskedDense
            self._mask_params = dict(mask_type=mask_type,
                                     mask_init_value=mask_init_value,
                                     mask_bern_sample=mask_bern_sample)

        super(MultiHeadAttV3, self).__init__(
            query_layer=self._dense_layer(units=num_units,
                                          name='query_layer',
                                          use_bias=False,
                                          **self._mask_params),
            # query is projected hidden state
            memory_layer=self._dense_layer(units=num_units,
                                           name='memory_layer',
                                           use_bias=False,
                                           **self._mask_params),
            # self._keys is projected feature_map
            memory=feature_map,  # self._values is feature_map
            probability_fn=lambda score, _: probability_fn(score),
            memory_sequence_length=None,
            score_mask_value=float('-inf'),
            name=name)
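        # Masking is not delegated to the base class (`memory_sequence_length`
        # is passed as None above); instead, the probability_fn below applies
        # the multi-head-aware mask via `_maybe_mask_score_multi`.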

        self._probability_fn = lambda score, _: (probability_fn(
            self._maybe_mask_score_multi(score, memory_sequence_length,
                                         float('-inf'))))
        self._fm_projection = fm_projection
        self._num_units = num_units
        self._num_heads = num_heads
        self._scale = scale
        self._feature_map_shape = _shape(feature_map)
        self._name = name

        if fm_projection == 'tied':
            assert num_units % num_heads == 0, \
                'For `tied` projection, attention size/depth must be ' \
                'divisible by the number of attention heads.'
            self._values_split = split_heads(self._keys, self._num_heads)
        elif fm_projection == 'independent':
            assert num_units % num_heads == 0, \
                'For `independent` projection, attention size/depth must be ' \
                'divisible by the number of attention heads.'
            # Project and split memory
            v_layer = self._dense_layer(units=num_units,
                                        name='value_layer',
                                        use_bias=False,
                                        **self._mask_params)
            # (batch_size, num_heads, mem_size, num_units / num_heads)
            self._values_split = split_heads(v_layer(self._values),
                                             self._num_heads)
        else:
            assert _shape(self._values)[-1] % num_heads == 0, \
                'When `fm_projection` is None, the feature map channel dim ' \
                'size must be divisible by the number of attention heads.'
            self._values_split = split_heads(self._values, self._num_heads)

        _dprint('{}: FM projection type: {}'.format(self.__class__.__name__,
                                                    fm_projection))
        _dprint('{}: Split values shape: {}'.format(
            self.__class__.__name__, _shape(self._values_split)))
        _dprint('{}: Values shape: {}'.format(self.__class__.__name__,
                                              _shape(self._values)))
        _dprint('{}: Keys shape: {}'.format(self.__class__.__name__,
                                            _shape(self._keys)))
        _dprint('{}: Feature map shape: {}'.format(self.__class__.__name__,
                                                   _shape(feature_map)))
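
        # Usage sketch (hypothetical tensor name: `cnn_fm` is assumed to be a
        # CNN feature map already reshaped to [batch_size, h * w, channels]):
        #     attn = MultiHeadAttV3(num_units=512,
        #                           feature_map=cnn_fm,
        #                           fm_projection='independent',
        #                           num_heads=8)
        # The mechanism is then wrapped by `MultiHeadAttentionWrapperV3`,
        # whose `call()` above consumes its split values and alignments.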