    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.multiply(keys, query)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        # attn_logits has shape=[time_steps_q, batch_size, time_steps_k, num_features]
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)

        if attn_mask is not None:
            transposed_mask = \
                tf.transpose(tf.tile(attn_mask, [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1, 1]),
                             [2, 0, 3, 1])
            attn_logits += transposed_mask

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-2, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)

        # Obtain context vectors
        expanded_values = tf.expand_dims(values, axis=1)
        weighted_memories = \
            tf.reduce_sum(tf.multiply(tf.transpose(attn_weights, [1, 0, 2, 3]), expanded_values), axis=2)
        return weighted_memories
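Every snippet on this page relies on a get_shape_list helper (sometimes imported as tf_utils.get_shape_list) that is not shown here. The following is only a minimal sketch of what such a utility typically does, returning static dimensions where known and dynamic tf.shape() entries otherwise; the expected_rank and name arguments are assumptions modelled on how the snippets below call it, not the exact implementation of these projects.

import tensorflow as tf

def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns the tensor's shape as a list, preferring static sizes and
    falling back to dynamic tf.shape() entries for unknown dimensions."""
    if expected_rank is not None:
        # The real helpers raise a descriptive error here; a bare assert keeps the sketch short.
        allowed = expected_rank if isinstance(expected_rank, (list, tuple)) else [expected_rank]
        assert tensor.shape.rank in allowed, name
    static_shape = tensor.shape.as_list()
    dynamic_shape = tf.shape(tensor)
    return [dim if dim is not None else dynamic_shape[index]
            for index, dim in enumerate(static_shape)]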
Example #2
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Create 3D attention mask from a 2D tensor mask.
  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].
  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                      dtype=from_tensor.dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                             dtype=from_tensor.dtype)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask
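A small, self-contained illustration (not part of the original source) of the broadcasting that builds the 3D mask; here the second sequence is padded after its first two tokens.

import tensorflow as tf

to_mask = tf.constant([[1., 1., 1., 1.],
                       [1., 1., 0., 0.]])      # [batch_size=2, to_seq_length=4]
to_mask_3d = tf.reshape(to_mask, [2, 1, 4])
broadcast_ones = tf.ones([2, 3, 1])            # from_seq_length=3
mask = broadcast_ones * to_mask_3d             # [2, 3, 4]
# Every from-position of the second example may attend only to its first two to-positions.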
Example #3
    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        # Use multiplicative attention
        transposed_keys = tf.transpose(keys, [0, 2, 1])
        attn_logits = tf.matmul(queries, transposed_keys)
        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1), [
                get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1,
                1
            ])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights,
                                             rate=self.dropout_attn,
                                             training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
Example #4
    def _additive_attn(self, queries, keys, values, attn_mask):
        """ Uses additive attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.reduce_sum(self.attn_weight * tf.nn.tanh(keys + query), axis=-1)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)
        attn_logits = tf.transpose(attn_logits, [1, 0, 2])

        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1),
                                   [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
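A toy, broadcast-based restatement of the two scoring schemes above (an illustration with assumed shapes, not the class methods themselves): multiplicative scores come from a query-key matmul, additive scores from a tanh over broadcast sums, and both yield [batch, time_q, time_k] logits that are softmaxed over the key axis.

import tensorflow as tf

queries = tf.random.normal([2, 3, 8])    # [batch, time_q, depth]
keys = tf.random.normal([2, 5, 8])       # [batch, time_k, depth]
values = tf.random.normal([2, 5, 8])
attn_weight = tf.random.normal([8])      # plays the role of self.attn_weight

mult_logits = tf.matmul(queries, keys, transpose_b=True)                 # [2, 3, 5]
add_logits = tf.reduce_sum(
    attn_weight * tf.tanh(keys[:, None, :, :] + queries[:, :, None, :]),
    axis=-1)                                                             # [2, 3, 5]
context = tf.matmul(tf.nn.softmax(mult_logits, axis=-1), values)         # [2, 3, 8]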
Example #5
def matmul_nd(nd_tensor, matrix):
    """ Performs matrix multiplication for n-dimensional inputs. """
    tensor_shape = tf_utils.get_shape_list(nd_tensor)
    matrix_shape = tf_utils.get_shape_list(matrix)

    initial_tensor_dims = tensor_shape[:-1]
    flat_first_dim = tf.reduce_prod(input_tensor=initial_tensor_dims)

    tensor_2d = tf.reshape(nd_tensor, [flat_first_dim, tensor_shape[-1]])
    result_2d = tf.matmul(tensor_2d, matrix)
    result_3d = tf.reshape(result_2d, initial_tensor_dims + [matrix_shape[-1]])
    return result_3d
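A hypothetical use of the same flatten-multiply-reshape pattern: projecting a [batch, seq, hidden] tensor with a [hidden, out] weight matrix.

import tensorflow as tf

x = tf.random.normal([2, 5, 16])         # [batch, seq, hidden]
w = tf.random.normal([16, 32])           # [hidden, out]
flat = tf.reshape(x, [2 * 5, 16])
y = tf.reshape(tf.matmul(flat, w), [2, 5, 32])
# Equivalent one-liners: tf.einsum('bsh,ho->bso', x, w), or tf.matmul(x, w),
# which broadcasts the matrix over the leading batch dimension.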
Example #6
    def call(self, inputs):
        """Implements call() for the layer."""
        input_shape = tf_utils.get_shape_list(inputs)
        flat_input = tf.reshape(inputs, [-1])
        output = tf.gather(self.embeddings, flat_input)
        output = tf.reshape(output, input_shape + [self.embedding_size])
        return output
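The same lookup written out on stand-alone tensors (sizes are assumptions); tf.nn.embedding_lookup produces the identical result.

import tensorflow as tf

embeddings = tf.random.normal([100, 64])     # [vocab_size, embedding_size]
ids = tf.constant([[3, 7, 9], [1, 0, 2]])    # [batch, seq_length]
flat_ids = tf.reshape(ids, [-1])
output = tf.reshape(tf.gather(embeddings, flat_ids), [2, 3, 64])
# tf.nn.embedding_lookup(embeddings, ids) yields the same [2, 3, 64] tensor.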
Example #7
    def _prepare_source():
        """ Pre-processes inputs to the encoder and generates the corresponding attention masks."""
        # Embed
        source_embeddings = self._embed(source_ids)
        # Obtain length and depth of the input tensors
        _, time_steps, depth = tf_utils.get_shape_list(source_embeddings)
        # Transform input mask into attention mask
        inverse_mask = tf.cast(tf.equal(source_mask, 0.0), dtype=FLOAT_DTYPE)
        attn_mask = inverse_mask * -1e9
        # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
        attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
        # Differentiate between self-attention and cross-attention masks for further, optional modifications
        self_attn_mask = attn_mask
        cross_attn_mask = attn_mask
        # Add positional encodings
        positional_signal = get_positional_signal(time_steps, depth, FLOAT_DTYPE)
        source_embeddings += positional_signal
        # Apply dropout
        if self.config.transformer_dropout_embeddings > 0:
            source_embeddings = tf.layers.dropout(
                source_embeddings,
                rate=self.config.transformer_dropout_embeddings,
                training=self.training)
        return source_embeddings, self_attn_mask, cross_attn_mask
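Why the mask is built as (source_mask == 0) * -1e9: adding a large negative bias to padded positions drives their softmax weight to effectively zero. A minimal sketch with made-up logits:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5, 0.1]])          # raw attention logits
source_mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])     # last two positions are padding
attn_bias = tf.cast(tf.equal(source_mask, 0.0), tf.float32) * -1e9
weights = tf.nn.softmax(logits + attn_bias, axis=-1)  # ~[[0.73, 0.27, 0.00, 0.00]]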
Example #8
def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions.
  Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
        hidden units of `BertModel` layer.
      positions: Position ids of the tokens to mask for pretraining, with
        dimension (batch_size, max_predictions_per_seq), where
        `max_predictions_per_seq` is the maximum number of tokens to mask out
        and predict per sequence.
  Returns:
      Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
      num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.keras.backend.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.keras.backend.reshape(
      sequence_tensor, [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

  return output_tensor
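The flatten-offset-gather steps on toy tensors (shapes are assumptions), showing which rows come out:

import tensorflow as tf

sequence_output = tf.reshape(tf.range(2 * 6 * 3, dtype=tf.float32), [2, 6, 3])  # [batch, seq, hidden]
positions = tf.constant([[1, 4], [0, 5]])                                       # [batch, max_predictions]
flat_offsets = tf.reshape(tf.range(0, 2, dtype=tf.int32) * 6, [-1, 1])          # [[0], [6]]
flat_positions = tf.reshape(positions + flat_offsets, [-1])                     # [1, 4, 6, 11]
flat_sequence = tf.reshape(sequence_output, [2 * 6, 3])
gathered = tf.gather(flat_sequence, flat_positions)   # [4, 3]; row 0 == sequence_output[0, 1]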
Example #9
        def _prepare_source():
            """ Pre-processes inputs to the encoder and generates the corresponding attention masks."""
            DICT_SIZE, ENG_DICT_FILE, OUTPUT_TRANSLATE_FILE, _, _, DEBIASED_EMBEDDING, _ = get_debias_files_from_config(
                self.consts_config_str)
            if self.USE_DEBIASED:
                print("using debiased embeddings")
                self.embedding_layer.embedding_table = self.embedding_matrix
            else:
                print("using non debiased embeddings")
            source_embeddings = self._embed(source_ids)
            if self.COLLECT_EMBEDDING_TABLE:
                ## print the embedding table
                # ########################################### PRINT #########################################################
                printops = []
                printops.append(
                    tf.compat.v1.Print(
                        [], [tf.shape(self.embedding_layer.embedding_table)],
                        "embedding_table shape ",
                        summarize=10000))
                for i in list(range(DICT_SIZE)):
                    printops.append(
                        tf.compat.v1.Print(
                            [], [self.embedding_layer.embedding_table[i, :]],
                            "enc_inputs for word " + str(i),
                            summarize=10000))
                    printops.append(
                        tf.compat.v1.Print(
                            [], [],
                            "**************************************",
                            summarize=10000))
                    tf.io.write_file(
                        "output_translate.txt",
                        str(self.embedding_layer.embedding_table[i, :]))
                with tf.control_dependencies(printops):
                    source_embeddings = source_embeddings * 1
                # ###########################################################################################################

            # Embed
            ### comment: first embedding without positional signal
            # Obtain length and depth of the input tensors
            _, time_steps, depth = tf_utils.get_shape_list(source_embeddings)
            # Transform input mask into attention mask
            inverse_mask = tf.cast(tf.equal(source_mask, 0.0),
                                   dtype=FLOAT_DTYPE)
            attn_mask = inverse_mask * -1e9
            # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
            attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
            # Differentiate between self-attention and cross-attention masks for further, optional modifications
            self_attn_mask = attn_mask
            cross_attn_mask = attn_mask
            # Add positional encodings
            positional_signal = get_positional_signal(time_steps, depth,
                                                      FLOAT_DTYPE)
            source_embeddings += positional_signal  ### comment: first embedding with positional signal

            # Apply dropout
            if self.dropout_embedding is not None:
                source_embeddings = self.dropout_embedding(
                    source_embeddings, training=self.training)
            return source_embeddings, self_attn_mask, cross_attn_mask
Example #10
    def _dot_product_attn(self, queries, keys, values, attn_mask, scaling_on):
        """ Defines the dot-product attention function; see Vasvani et al.(2017), Eq.(1). """
        # query/ key/ value have shape = [batch_size, time_steps, num_heads, num_features]
        # Tile keys and values tensors to match the number of decoding beams; ignored if already done by fusion module
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.cond(pred=tf.greater(num_beams, 1),
                       true_fn=lambda: tf.tile(keys, [num_beams, 1, 1, 1]),
                       false_fn=lambda: keys)
        values = tf.cond(pred=tf.greater(num_beams, 1),
                         true_fn=lambda: tf.tile(values, [num_beams, 1, 1, 1]),
                         false_fn=lambda: values)

        # Transpose split inputs
        queries = tf.transpose(a=queries, perm=[0, 2, 1, 3])
        values = tf.transpose(a=values, perm=[0, 2, 1, 3])
        attn_logits = tf.matmul(queries, tf.transpose(a=keys,
                                                      perm=[0, 2, 3, 1]))

        # Scale attention_logits by key dimensions to prevent softmax saturation, if specified
        if scaling_on:
            key_dims = get_shape_list(keys)[-1]
            normalizer = tf.sqrt(tf.cast(key_dims, self.float_dtype))
            attn_logits /= normalizer

        # Optionally mask out positions which should not be attended to
        # attention mask should have shape=[batch, num_heads, query_length, key_length]
        # attn_logits has shape=[batch, num_heads, query_length, key_length]
        if attn_mask is not None:
            attn_mask = tf.cond(
                pred=tf.greater(num_beams, 1),
                true_fn=lambda: tf.tile(attn_mask, [num_beams, 1, 1, 1]),
                false_fn=lambda: attn_mask)
            attn_logits += attn_mask

        # Calculate attention weights
        attn_weights = tf.nn.softmax(attn_logits)
        # Optionally apply dropout:
        if self.dropout_attn is not None:
            attn_weights = self.dropout_attn(attn_weights,
                                             training=self.training)
        # Optionally apply DropHead:
        if self.drophead is not None:
            attn_weights = self.drophead(attn_weights, training=self.training)
        # Weigh attention values
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
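The core of Eq. (1) from Vaswani et al. (2017) written out on head-major toy tensors (an illustration with assumed sizes, not the method above):

import tensorflow as tf

q = tf.random.normal([2, 4, 5, 8])    # [batch, num_heads, time_q, head_depth]
k = tf.random.normal([2, 4, 7, 8])    # [batch, num_heads, time_k, head_depth]
v = tf.random.normal([2, 4, 7, 8])
logits = tf.matmul(q, k, transpose_b=True) / tf.sqrt(8.0)    # scale by sqrt(head_depth)
context = tf.matmul(tf.nn.softmax(logits, axis=-1), v)       # [2, 4, 5, 8]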
Example #11
    def _merge_from_heads(self, split_inputs):
        """ Inverts the _split_among_heads operation. """
        # Transpose split_inputs to perform the merge along the last two dimensions of the split input
        split_inputs = tf.transpose(split_inputs, [0, 2, 1, 3])
        # Retrieve the depth of the tensor to be merged
        split_inputs_dims = get_shape_list(split_inputs)
        split_inputs_depth = split_inputs_dims[-1]
        # Merge the depth and num_heads dimensions of split_inputs
        merged_inputs = tf.reshape(split_inputs, split_inputs_dims[:-2] + [self.num_heads * split_inputs_depth])
        return merged_inputs
Example #12
    def gather_attn(attn):
        # TODO Specify second and third?
        shapes = {attn: ('batch_size', None, None)}
        tf_utils.assert_shapes(shapes)
        attn_dims = tf_utils.get_shape_list(attn)
        new_shape = [beam_size, batch_size_x] + attn_dims[1:]
        tmp = tf.reshape(attn, new_shape)
        flat_tensor = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
        tmp = tf.gather_nd(flat_tensor, gather_coordinates)
        tmp = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
        gathered_values = tf.reshape(tmp, attn_dims)
        return gathered_values
Example #13
    def _split_among_heads(self, inputs):
        """ Splits the attention inputs among multiple heads. """
        # Retrieve the depth of the input tensor to be split (input is 3d)
        inputs_dims = get_shape_list(inputs)
        inputs_depth = inputs_dims[-1]

        # Assert the depth is compatible with the specified number of attention heads
        if isinstance(inputs_depth, int) and isinstance(self.num_heads, int):
            assert inputs_depth % self.num_heads == 0, \
                ('Attention inputs depth {:d} is not evenly divisible by the specified number of attention heads {:d}'
                 .format(inputs_depth, self.num_heads))
        split_inputs = tf.reshape(inputs, inputs_dims[:-1] + [self.num_heads, inputs_depth // self.num_heads])
        return split_inputs
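A round-trip sketch (assumed sizes) showing that _split_among_heads, the head-major transpose, and _merge_from_heads together recover the original tensor:

import tensorflow as tf

x = tf.random.normal([2, 5, 16])                            # [batch, time, depth]
num_heads = 4
split = tf.reshape(x, [2, 5, num_heads, 16 // num_heads])   # what _split_among_heads computes
head_major = tf.transpose(split, [0, 2, 1, 3])              # layout used inside the attention function
merged = tf.reshape(tf.transpose(head_major, [0, 2, 1, 3]), [2, 5, 16])  # what _merge_from_heads computes
# tf.reduce_all(tf.equal(x, merged)) -> True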
Example #14
    def get_memory_invariants(self, memories):
        """Generate shape invariants for memories.

        Args:
            memories: dictionary (see top-level class description)

        Returns:
            Dictionary of shape invariants with same structure as memories.
        """
        with tf.compat.v1.name_scope(self._scope):
            invariants = dict()
            for layer_id in memories.keys():
                layer_mems = memories[layer_id]
                invariants[layer_id] = {
                    key: tf.TensorShape(
                        [None] * len(tf_utils.get_shape_list(layer_mems[key])))
                    for key in layer_mems.keys()
                }
            return invariants
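What each invariant amounts to, shown on a hypothetical memory dictionary: a TensorShape of the same rank with every dimension left unknown, which is what tf.while_loop's shape_invariants needs when the cached time axis grows.

import tensorflow as tf

layer_mems = {'keys': tf.zeros([2, 0, 64]), 'values': tf.zeros([2, 0, 64])}
invariants = {key: tf.TensorShape([None] * len(value.shape))
              for key, value in layer_mems.items()}
# -> {'keys': TensorShape([None, None, None]), 'values': TensorShape([None, None, None])}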
Example #15
    def get_memory_invariants(self, memories):
        """Generate shape invariants for memories.

        Args:
            memories: dictionary (see top-level class description)

        Returns:
            Dictionary of shape invariants with same structure as memories.
        """
        with tf.name_scope(self._scope):
            d = self._model.decoder

            high_depth = 0 if d.high_gru_stack is None \
                           else len(d.high_gru_stack.grus)

            num_dims = len(tf_utils.get_shape_list(memories['base_states']))
            # TODO Specify shape in full?
            partial_shape = tf.TensorShape([None] * num_dims)

            invariants = {}
            invariants['base_states'] = partial_shape
            invariants['high_states'] = [partial_shape] * high_depth
            return invariants
Example #16
    def call(self, inputs, **kwargs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        token_type_ids = unpacked_inputs[1]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        if self.use_type_embeddings:
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            one_hot_ids = tf.one_hot(flat_token_type_ids,
                                     depth=self.token_type_vocab_size,
                                     dtype=self.dtype)
            token_type_embeddings = tf.matmul(one_hot_ids,
                                              self.type_embeddings)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
            output += token_type_embeddings

        if self.use_position_embeddings:
            position_embeddings = tf.expand_dims(tf.slice(
                self.position_embeddings, [0, 0], [seq_length, width]),
                                                 axis=0)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(output,
                                     training=kwargs.get('training', False))

        projected_output = self.projection(output)

        return projected_output
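The one-hot matmul used for the token-type embeddings is an embedding lookup written as a matrix product; a toy check with assumed sizes:

import tensorflow as tf

type_embeddings = tf.random.normal([2, 32])        # [token_type_vocab_size, width]
token_type_ids = tf.constant([[0, 0, 1, 1]])       # [batch=1, seq_length=4]
one_hot_ids = tf.one_hot(tf.reshape(token_type_ids, [-1]), depth=2)
via_matmul = tf.reshape(tf.matmul(one_hot_ids, type_embeddings), [1, 4, 32])
via_gather = tf.gather(type_embeddings, token_type_ids)    # same [1, 4, 32] values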