Code example #1
 def _prepare_source():
     """ Pre-processes inputs to the encoder and generates the corresponding attention masks."""
     # Embed
     pre_source_embeddings = self._embed(source_ids)
     with tf.variable_scope(self.name):
         source_embeddings = self.emb_ffn.forward(pre_source_embeddings)
     glove_embeddings = self.embedding_layer.get_glove_embed(source_pids)
     source_embeddings += glove_embeddings
     # Obtain length and depth of the input tensors
     _, time_steps, depth = get_shape_list(source_embeddings)
     # Transform input mask into attention mask
     # Recover source_mask
     shape_mask = get_shape_list(source_mask)
     source_mask1 = tf.slice(source_mask, [0, 0, 0], [shape_mask[0], shape_mask[1], 1])
     source_mask2 = tf.reshape(source_mask1, [shape_mask[0], shape_mask[1]])
     inverse_mask = tf.cast(tf.equal(source_mask2, 0.0), dtype=self.float_dtype)
     attn_mask = inverse_mask * -1e9
     # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
     attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
     # Differentiate between self-attention and cross-attention masks for further, optional modifications
     self_attn_mask = attn_mask
     cross_attn_mask = attn_mask
     # Add positional encodings
     positional_signal = get_positional_signal(time_steps, depth, self.float_dtype)
     source_embeddings += positional_signal
     # Apply dropout
     if self.config.transformer_dropout_embeddings > 0:
         source_embeddings = tf.layers.dropout(source_embeddings,
                                               rate=self.config.transformer_dropout_embeddings, training=self.training)
     return source_embeddings, self_attn_mask, cross_attn_mask
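
The mask arithmetic above is easy to check outside the graph. Below is a minimal NumPy sketch (toy values invented for illustration, not part of the nematus code) of how a 0/1 padding mask becomes an additive attention mask of shape [batch_size, 1, 1, time_steps]: padded positions receive -1e9 so they vanish after the softmax.

import numpy as np

# Toy padding mask: 1.0 for real tokens, 0.0 for padding (batch_size=2, time_steps=4).
source_mask = np.array([[1., 1., 1., 0.],
                        [1., 1., 0., 0.]], dtype=np.float32)

# Same arithmetic as the TF code: invert the mask and scale by -1e9.
inverse_mask = (source_mask == 0.0).astype(np.float32)
attn_mask = inverse_mask * -1e9
# Expand to [batch_size, 1, 1, time_steps] so it broadcasts over heads and query positions.
attn_mask = attn_mask[:, np.newaxis, np.newaxis, :]

# Adding the mask to some logits and applying softmax zeroes out the padded positions.
logits = np.zeros((2, 1, 1, 4), dtype=np.float32) + attn_mask
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(weights.round(3))  # padded columns get ~0 probability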
Code example #2
    def _additive_attn(self, queries, keys, values, attn_mask):
        """ Uses additive attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.reduce_sum(self.attn_weight * tf.nn.tanh(keys + query),
                                 axis=-1)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)
        attn_logits = tf.transpose(attn_logits, [1, 0, 2])

        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1), [
                get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1,
                1
            ])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights,
                                             rate=self.dropout_attn,
                                             training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
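
For reference, the score computed by this additive (Bahdanau-style) attention is v · tanh(k + q) per key position, followed by a softmax over keys. A minimal NumPy sketch with hypothetical shapes (names and sizes are assumptions, not taken from the code above):

import numpy as np

batch, t_k, d = 2, 5, 8
keys = np.random.randn(batch, t_k, d).astype(np.float32)     # projected encoder states
values = np.random.randn(batch, t_k, d).astype(np.float32)
query = np.random.randn(batch, d).astype(np.float32)         # a single decoder state
attn_weight = np.random.randn(d).astype(np.float32)          # the learned vector v

# Score each key position: v . tanh(k + q), as in _logits_fn above.
logits = np.sum(attn_weight * np.tanh(keys + query[:, None, :]), axis=-1)  # [batch, t_k]
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)      # softmax over keys
context = np.einsum('bk,bkd->bd', weights, values)                         # weighted sum of values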
Code example #3
    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.multiply(keys, query)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        # attn_logits has shape=[time_steps_q, batch_size, time_steps_k, num_features]
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)

        if attn_mask is not None:
            transposed_mask = \
                tf.transpose(tf.tile(attn_mask, [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1, 1]),
                             [2, 0, 3, 1])
            attn_logits += transposed_mask

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-2, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)

        # Obtain context vectors
        expanded_values = tf.expand_dims(values, axis=1)
        weighted_memories = \
            tf.reduce_sum(tf.multiply(tf.transpose(attn_weights, [1, 0, 2, 3]), expanded_values), axis=2)
        return weighted_memories
Code example #4
    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        # Use multiplicative attention
        transposed_keys = tf.transpose(keys, [0, 2, 1])
        attn_logits = tf.matmul(queries, transposed_keys)
        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1), [
                get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1,
                1
            ])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights,
                                             rate=self.dropout_attn,
                                             training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
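
This variant is plain multiplicative (Luong-style) attention: the logits are the dot products Q·Kᵀ, masked and softmax-normalized along the key axis. A small NumPy sketch with made-up shapes, independent of the class above:

import numpy as np

batch, t_q, t_k, d = 2, 3, 5, 8
queries = np.random.randn(batch, t_q, d).astype(np.float32)
keys = np.random.randn(batch, t_k, d).astype(np.float32)
values = np.random.randn(batch, t_k, d).astype(np.float32)
attn_mask = np.zeros((batch, t_q, t_k), dtype=np.float32)  # 0 where attendable, -1e9 where padded

logits = queries @ keys.transpose(0, 2, 1) + attn_mask             # [batch, t_q, t_k]
weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)   # softmax over keys
context = weights @ values                                         # [batch, t_q, d]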
Code example #5
def gather_memories(memory_dict, gather_coordinates):
    """ Gathers layer-wise memory tensors corresponding to top sequences from the provided memory dictionary
    during beam search. """
    # Initialize dicts
    gathered_memories = dict()
    # Get coordinate shapes
    coords_dims = get_shape_list(gather_coordinates)

    # Gather
    for layer_key in memory_dict.keys():
        layer_dict = memory_dict[layer_key]
        gathered_memories[layer_key] = dict()

        for attn_key in layer_dict.keys():
            attn_tensor = layer_dict[attn_key]
            attn_dims = get_shape_list(attn_tensor)
            # Not sure if this is faster than the 'memory-less' version
            flat_tensor = \
                tf.transpose(tf.reshape(attn_tensor, [-1, coords_dims[0]] + attn_dims[1:]), [1, 0, 2, 3])
            gathered_values = tf.reshape(
                tf.transpose(tf.gather_nd(flat_tensor, gather_coordinates),
                             [1, 0, 2, 3]),
                [tf.multiply(coords_dims[1], coords_dims[0])] + attn_dims[1:])
            gathered_memories[layer_key][attn_key] = gathered_values

    return gathered_memories
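
The reshape/transpose/gather dance above re-sorts cached attention tensors of shape [beam_size * batch_size, time, depth] so that each sentence keeps only the memories of its surviving hypotheses. A NumPy sketch of the same idea, assuming (as the TF code does) that the flat batch axis is ordered beam-major; sizes and the beam_parents layout are toy assumptions:

import numpy as np

batch, beam, t, d = 2, 3, 4, 5
attn_tensor = np.arange(batch * beam * t * d, dtype=np.float32).reshape(beam * batch, t, d)
# beam_parents[b, k]: which old beam the k-th surviving hypothesis of sentence b came from.
beam_parents = np.array([[2, 0, 1],
                         [1, 1, 0]])

# [beam*batch, t, d] -> [beam, batch, t, d] -> [batch, beam, t, d]
per_sentence = attn_tensor.reshape(beam, batch, t, d).transpose(1, 0, 2, 3)
# Select the parent beams for each sentence, then flatten back to [beam*batch, t, d].
gathered = np.stack([per_sentence[b, beam_parents[b]] for b in range(batch)], axis=0)
gathered = gathered.transpose(1, 0, 2, 3).reshape(beam * batch, t, d)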
Code example #6
    def _additive_attn(self, queries, keys, values, attn_mask):
        """ Uses additive attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.reduce_sum(self.attn_weight * tf.nn.tanh(keys + query), axis=-1)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)
        attn_logits = tf.transpose(attn_logits, [1, 0, 2])

        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1),
                                   [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
Code example #7
 def _embed(self, index_sequence):
     """ Embeds source-side indices to obtain the corresponding dense tensor representations. """
     # Important change
     #index_sequence: (batch_size, seq_len, u_len)
     u_emb = self.embedding_layer.embed(index_sequence)  #(batch_size, seq_len, u_len, embedding_size)
     shape = get_shape_list(u_emb)
     # Add positional encoding; specifically for md5: [1, u_len, embedding_size]
     if self.config.utf8_type == "md5":
         md5_positional_signal = get_positional_signal(shape[2], shape[3], self.float_dtype)
         u_emb += md5_positional_signal
     # Trim to 2048
     input_size = self.config.pre_source_embedding_size  # defaults to 2048
     cc = input_size - shape[2]*shape[3]
     if self.config.pre_source_embed_cross:  # seems to perform worse, and the test-time BLEU is abnormal
         embsize = tf.to_int32((input_size/shape[2]))
         accsize = input_size % shape[2]
         fix_merge_emb = tf.pad(u_emb, [[0, 0], [0, 0], [0, 0], [0, tf.reduce_max([embsize-shape[3], 0])]], constant_values=1.0)
         fix_merge_emb = tf.slice(fix_merge_emb, [0, 0, 0, 0], [-1, -1, -1, embsize])
         fix_merge_emb = tf.reshape(fix_merge_emb, [shape[0], shape[1], shape[2]*embsize])
         fix_merge_emb = tf.pad(fix_merge_emb, [[0, 0], [0, 0], [0, accsize]], constant_values=1.0)
     else:
         merge_emb = tf.reshape(u_emb, [shape[0], shape[1], shape[2]*shape[3]])  #(batch_size, seq_len, u_len*embedding_size)
         fix_merge_emb = tf.pad(merge_emb, [[0, 0], [0, 0], [0, tf.reduce_max([cc, 0])]], constant_values=0)
         fix_merge_emb = tf.slice(fix_merge_emb, [0, 0, 0], [-1, -1, input_size])
     
     return fix_merge_emb
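
The non-cross branch above simply flattens the per-unit embeddings and forces the last dimension to a fixed width by padding then slicing. A hedged NumPy sketch of that reshape/pad/slice step (toy sizes, variable names invented for illustration):

import numpy as np

batch, seq_len, u_len, emb = 2, 3, 4, 6
input_size = 32                               # stands in for config.pre_source_embedding_size
u_emb = np.random.randn(batch, seq_len, u_len, emb).astype(np.float32)

merge_emb = u_emb.reshape(batch, seq_len, u_len * emb)          # [batch, seq_len, u_len*emb]
pad = max(input_size - u_len * emb, 0)
fix_merge_emb = np.pad(merge_emb, [(0, 0), (0, 0), (0, pad)])   # pad with zeros on the right
fix_merge_emb = fix_merge_emb[:, :, :input_size]                # then trim to exactly input_size
assert fix_merge_emb.shape == (batch, seq_len, input_size)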
Code example #8
 def _prepare_source():
     """ Pre-processes inputs to the encoder and generates the corresponding attention masks."""
     # Embed
     source_embeddings = self._embed(source_ids)
     # Obtain length and depth of the input tensors
     _, time_steps, depth = get_shape_list(source_embeddings)
     # Transform input mask into attention mask
     inverse_mask = tf.cast(tf.equal(source_mask, 0.0),
                            dtype=FLOAT_DTYPE)
     attn_mask = inverse_mask * -1e9
     # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
     attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
     # Differentiate between self-attention and cross-attention masks for further, optional modifications
     self_attn_mask = attn_mask
     cross_attn_mask = attn_mask
     # Add positional encodings
     positional_signal = get_positional_signal(time_steps, depth,
                                               FLOAT_DTYPE)
     source_embeddings += positional_signal
     # Apply dropout
     if self.config.transformer_dropout_embeddings > 0:
         source_embeddings = tf.layers.dropout(
             source_embeddings,
             rate=self.config.transformer_dropout_embeddings,
             training=self.training)
     return source_embeddings, self_attn_mask, cross_attn_mask
Code example #9
    def _dot_product_attn(self, queries, keys, values, attn_mask, scaling_on):
        """ Defines the dot-product attention function; see Vasvani et al.(2017), Eq.(1). """
        # query/ key/ value have shape = [batch_size, time_steps, num_heads, num_features]
        # Tile keys and values tensors to match the number of decoding beams; ignored if already done by fusion module
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.cond(tf.greater(num_beams, 1),
                       lambda: tf.tile(keys, [num_beams, 1, 1, 1]),
                       lambda: keys)
        values = tf.cond(tf.greater(num_beams, 1),
                         lambda: tf.tile(values, [num_beams, 1, 1, 1]),
                         lambda: values)

        # Transpose split inputs
        queries = tf.transpose(queries, [0, 2, 1, 3])
        values = tf.transpose(values, [0, 2, 1, 3])
        attn_logits = tf.matmul(queries, tf.transpose(keys, [0, 2, 3, 1]))

        # Scale attention_logits by key dimensions to prevent softmax saturation, if specified
        if scaling_on:
            key_dims = get_shape_list(keys)[-1]
            normalizer = tf.sqrt(tf.cast(key_dims, self.float_dtype))
            attn_logits /= normalizer

        # Optionally mask out positions which should not be attended to
        # attention mask should have shape=[batch, num_heads, query_length, key_length]
        # attn_logits has shape=[batch, num_heads, query_length, key_length]
        if attn_mask is not None:
            attn_mask = tf.cond(
                tf.greater(num_beams, 1),
                lambda: tf.tile(attn_mask, [num_beams, 1, 1, 1]),
                lambda: attn_mask)
            attn_logits += attn_mask

        # Calculate attention weights
        attn_weights = tf.nn.softmax(attn_logits)
        # Optionally apply dropout:
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights,
                                             rate=self.dropout_attn,
                                             training=self.training)
        # Weigh attention values
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
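
Ignoring the beam tiling, the core of this function is the scaled dot-product attention softmax(Q·Kᵀ / sqrt(d) + mask)·V from Vaswani et al. (2017). A minimal NumPy sketch with assumed toy shapes, already in head-major layout:

import numpy as np

batch, heads, t_q, t_k, d = 2, 4, 3, 5, 16
q = np.random.randn(batch, heads, t_q, d).astype(np.float32)
k = np.random.randn(batch, heads, t_k, d).astype(np.float32)
v = np.random.randn(batch, heads, t_k, d).astype(np.float32)
mask = np.zeros((batch, 1, 1, t_k), dtype=np.float32)   # -1e9 at padded key positions

logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d) + mask           # [batch, heads, t_q, t_k]
weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)   # softmax over keys
context = weights @ v                                              # [batch, heads, t_q, d]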
Code example #10
 def _merge_from_heads(self, split_inputs):
     """ Inverts the _split_among_heads operation. """
     # Transpose split_inputs to perform the merge along the last two dimensions of the split input
     split_inputs = tf.transpose(split_inputs, [0, 2, 1, 3])
     # Retrieve the depth of the tensor to be merged
     split_inputs_dims = get_shape_list(split_inputs)
     split_inputs_depth = split_inputs_dims[-1]
     # Merge the depth and num_heads dimensions of split_inputs
     merged_inputs = tf.reshape(split_inputs, split_inputs_dims[:-2] + [self.num_heads * split_inputs_depth])
     return merged_inputs
Code example #11
 def _merge_from_heads(self, split_inputs):
     """ Inverts the _split_among_heads operation. """
     # Transpose split_inputs to perform the merge along the last two dimensions of the split input
     split_inputs = tf.transpose(split_inputs, [0, 2, 1, 3])
     # Retrieve the depth of the tensor to be merged
     split_inputs_dims = get_shape_list(split_inputs)
     split_inputs_depth = split_inputs_dims[-1]
     # Merge the depth and num_heads dimensions of split_inputs
     merged_inputs = tf.reshape(
         split_inputs,
         split_inputs_dims[:-2] + [self.num_heads * split_inputs_depth])
     return merged_inputs
Code example #12
def get_memory_invariants(memories):
    """ Calculates the invariant shapes for the model memories (i.e. states of th RNN ar layer-wise attentions of the
    transformer). """
    memory_type = type(memories)
    if memory_type == dict:
        memory_invariants = dict()
        for layer_id in memories.keys():
            memory_invariants[layer_id] = {key: tf.TensorShape([None] * len(get_shape_list(memories[layer_id][key])))
                                           for key in memories[layer_id].keys()}
    else:
        raise ValueError('Memory type not supported, must be a dictionary.')
    return memory_invariants
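
The invariants are TensorShapes with every dimension left as None, so that the cached keys/values can grow along the time axis inside tf.while_loop during decoding. A small sketch of building such a dict by hand (the toy memory layout is assumed, not the real cache):

import numpy as np
import tensorflow as tf

memories = {'layer_1': {'keys': np.zeros((2, 0, 8), np.float32),
                        'values': np.zeros((2, 0, 8), np.float32)}}

invariants = {layer_id: {key: tf.TensorShape([None] * tensor.ndim)
                         for key, tensor in layer_mems.items()}
              for layer_id, layer_mems in memories.items()}
# Every dimension is unconstrained, so appending a new decoding step to the cache stays legal.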
Code example #13
 def gather_attn(attn):
     # TODO Specify second and third?
     shapes = {attn: ('batch_size', None, None)}
     tf_utils.assert_shapes(shapes)
     attn_dims = get_shape_list(attn)
     new_shape = [beam_size, batch_size_x] + attn_dims[1:]
     tmp = tf.reshape(attn, new_shape)
     flat_tensor = tf.transpose(tmp, [1, 0, 2, 3])
     tmp = tf.gather_nd(flat_tensor, gather_coordinates)
     tmp = tf.transpose(tmp, [1, 0, 2, 3])
     gathered_values = tf.reshape(tmp, attn_dims)
     return gathered_values
Code example #14
 def _pre_embed(self, index_sequence):
     u_emb = self.embedding_layer.embed(index_sequence) #(batch_size, u_len, embedding_size)
     shape = get_shape_list(u_emb)
     if self.config.utf8_type == "md5":
         md5_positional_signal = get_positional_signal(shape[1], shape[2], self.float_dtype)
         u_emb += md5_positional_signal
     input_size = self.config.pre_source_embedding_size
     cc = input_size - shape[1]*shape[2]
     merge_emb = tf.reshape(u_emb, [shape[0], shape[1]*shape[2]])
     #merge_emb: (batch_size, u_len*embedding_size)
     fix_merge_emb = tf.pad(merge_emb, [[0, 0], [0, tf.reduce_max([cc, 0])]], constant_values=1.0)
     fix_merge_emb = tf.slice(fix_merge_emb, [0, 0], [-1, input_size])
     return fix_merge_emb
Code example #15
    def _split_among_heads(self, inputs):
        """ Splits the attention inputs among multiple heads. """
        # Retrieve the depth of the input tensor to be split (input is 3d)
        inputs_dims = get_shape_list(inputs)
        inputs_depth = inputs_dims[-1]

        # Assert the depth is compatible with the specified number of attention heads
        if isinstance(inputs_depth, int) and isinstance(self.num_heads, int):
            assert inputs_depth % self.num_heads == 0, \
                ('Attention inputs depth {:d} is not evenly divisible by the specified number of attention heads {:d}'
                 .format(inputs_depth, self.num_heads))
        split_inputs = tf.reshape(inputs, inputs_dims[:-1] + [self.num_heads, inputs_depth // self.num_heads])
        return split_inputs
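
_split_among_heads and its inverse _merge_from_heads (code examples #10/#11 above) amount to reshaping the depth axis into [num_heads, depth // num_heads] and back. A NumPy round-trip sketch with hypothetical sizes:

import numpy as np

batch, time_steps, depth, num_heads = 2, 5, 16, 4
inputs = np.random.randn(batch, time_steps, depth).astype(np.float32)

# Split: [batch, time, depth] -> [batch, time, heads, depth // heads]
split = inputs.reshape(batch, time_steps, num_heads, depth // num_heads)
# The attention functions then work in head-major layout: [batch, heads, time, depth // heads]
head_major = split.transpose(0, 2, 1, 3)

# Merge (as in _merge_from_heads): undo the transpose, then fold the heads back into depth.
merged = head_major.transpose(0, 2, 1, 3).reshape(batch, time_steps, depth)
assert np.array_equal(inputs, merged)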
Code example #16
    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        # Use multiplicative attention
        transposed_keys = tf.transpose(keys, [0, 2, 1])
        attn_logits = tf.matmul(queries, transposed_keys)
        if attn_mask is not None:
            # Transpose and tile the mask
            attn_logits += tf.tile(tf.squeeze(attn_mask, 1),
                                   [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1])

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-1, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
        # Obtain context vectors
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
Code example #17
def decode_greedy(model, do_sample=False, beam_size=0,
                  normalization_alpha=None):
    # Determine size of current batch
    batch_size, _, _ = get_shape_list(model.source_ids)
    # Encode source sequences
    with tf.name_scope('{:s}_encode'.format(model.name)):
        enc_output, cross_attn_mask = model.enc.encode(model.source_pids,
                                                       model.source_ids,
                                                       model.source_mask)
    # Decode into target sequences
    with tf.name_scope('{:s}_decode'.format(model.name)):
        dec_output, scores = decode_at_test(model, model.dec, enc_output,
            cross_attn_mask, batch_size, beam_size, do_sample, normalization_alpha)
    return dec_output, scores
Code example #18
    def _multiplicative_attn(self, queries, keys, values, attn_mask):
        """ Uses multiplicative attention to compute contextually enriched source-side representations. """
        # Account for beam-search
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.tile(keys, [num_beams, 1, 1])
        values = tf.tile(values, [num_beams, 1, 1])

        def _logits_fn(query):
            """ Computes time-step-wise attention scores. """
            query = tf.expand_dims(query, 1)
            return tf.multiply(keys, query)

        # Obtain attention scores
        transposed_queries = tf.transpose(queries, [1, 0, 2])  # time-major
        # attn_logits has shape=[time_steps_q, batch_size, time_steps_k, num_features]
        attn_logits = tf.map_fn(_logits_fn, transposed_queries)

        if attn_mask is not None:
            transposed_mask = \
                tf.transpose(tf.tile(attn_mask, [get_shape_list(queries)[0] // get_shape_list(attn_mask)[0], 1, 1, 1]),
                             [2, 0, 3, 1])
            attn_logits += transposed_mask

        # Compute the attention weights
        attn_weights = tf.nn.softmax(attn_logits, axis=-2, name='attn_weights')
        # Optionally apply dropout
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights,
                                             rate=self.dropout_attn,
                                             training=self.training)

        # Obtain context vectors
        expanded_values = tf.expand_dims(values, axis=1)
        weighted_memories = \
            tf.reduce_sum(tf.multiply(tf.transpose(attn_weights, [1, 0, 2, 3]), expanded_values), axis=2)
        return weighted_memories
Code example #19
    def _dot_product_attn(self, queries, keys, values, attn_mask, scaling_on):
        """ Defines the dot-product attention function; see Vasvani et al.(2017), Eq.(1). """
        # query/ key/ value have shape = [batch_size, time_steps, num_heads, num_features]
        # Tile keys and values tensors to match the number of decoding beams; ignored if already done by fusion module
        num_beams = get_shape_list(queries)[0] // get_shape_list(keys)[0]
        keys = tf.cond(tf.greater(num_beams, 1), lambda: tf.tile(keys, [num_beams, 1, 1, 1]), lambda: keys)
        values = tf.cond(tf.greater(num_beams, 1), lambda: tf.tile(values, [num_beams, 1, 1, 1]), lambda: values)

        # Transpose split inputs
        queries = tf.transpose(queries, [0, 2, 1, 3])
        values = tf.transpose(values, [0, 2, 1, 3])
        attn_logits = tf.matmul(queries, tf.transpose(keys, [0, 2, 3, 1]))

        # Scale attention_logits by key dimensions to prevent softmax saturation, if specified
        if scaling_on:
            key_dims = get_shape_list(keys)[-1]
            normalizer = tf.sqrt(tf.cast(key_dims, self.float_dtype))
            attn_logits /= normalizer

        # Optionally mask out positions which should not be attended to
        # attention mask should have shape=[batch, num_heads, query_length, key_length]
        # attn_logits has shape=[batch, num_heads, query_length, key_length]
        if attn_mask is not None:
            attn_mask = tf.cond(tf.greater(num_beams, 1),
                                lambda: tf.tile(attn_mask, [num_beams, 1, 1, 1]),
                                lambda: attn_mask)
            attn_logits += attn_mask

        # Calculate attention weights
        attn_weights = tf.nn.softmax(attn_logits)
        # Optionally apply dropout:
        if self.dropout_attn > 0.0:
            attn_weights = tf.layers.dropout(attn_weights, rate=self.dropout_attn, training=self.training)
        # Weigh attention values
        weighted_memories = tf.matmul(attn_weights, values)
        return weighted_memories
Code example #20
    def _split_among_heads(self, inputs):
        """ Splits the attention inputs among multiple heads. """
        # Retrieve the depth of the input tensor to be split (input is 3d)
        inputs_dims = get_shape_list(inputs)
        inputs_depth = inputs_dims[-1]

        # Assert the depth is compatible with the specified number of attention heads
        if isinstance(inputs_depth, int) and isinstance(self.num_heads, int):
            assert inputs_depth % self.num_heads == 0, \
                ('Attention inputs depth {:d} is not evenly divisible by the specified number of attention heads {:d}'
                 .format(inputs_depth, self.num_heads))
        split_inputs = tf.reshape(
            inputs, inputs_dims[:-1] +
            [self.num_heads, inputs_depth // self.num_heads])
        return split_inputs
Code example #21
    def get_memory_invariants(self, memories):
        """Generate shape invariants for memories.

        Args:
            memories: dictionary (see top-level class description)

        Returns:
            Dictionary of shape invariants with same structure as memories.
        """
        with tf.name_scope(self._scope):
            invariants = dict()
            for layer_id in memories.keys():
                layer_mems = memories[layer_id]
                invariants[layer_id] = {
                    key: tf.TensorShape([None] *
                                        len(get_shape_list(layer_mems[key])))
                    for key in layer_mems.keys()
                }
            return invariants
Code example #22
File: transformer.py  Project: rsennrich/nematus
 def _prepare_source():
     """ Pre-processes inputs to the encoder and generates the corresponding attention masks."""
     # Embed
     source_embeddings = self._embed(source_ids)
     # Obtain length and depth of the input tensors
     _, time_steps, depth = get_shape_list(source_embeddings)
     # Transform input mask into attention mask
     inverse_mask = tf.cast(tf.equal(source_mask, 0.0), dtype=self.float_dtype)
     attn_mask = inverse_mask * -1e9
     # Expansion to shape [batch_size, 1, 1, time_steps] is needed for compatibility with attention logits
     attn_mask = tf.expand_dims(tf.expand_dims(attn_mask, 1), 1)
     # Differentiate between self-attention and cross-attention masks for further, optional modifications
     self_attn_mask = attn_mask
     cross_attn_mask = attn_mask
     # Add positional encodings
     positional_signal = get_positional_signal(time_steps, depth, self.float_dtype)
     source_embeddings += positional_signal
     # Apply dropout
     if self.config.transformer_dropout_embeddings > 0:
         source_embeddings = tf.layers.dropout(source_embeddings,
                                               rate=self.config.transformer_dropout_embeddings, training=self.training)
     return source_embeddings, self_attn_mask, cross_attn_mask
Code example #23
def get_memory_invariants(model_memories):
    """Calculates the invariant shapes for a model's memories.

    'Memories' are states of the RNN or layer-wise attentions of the
    transformer.

    Args:
        model_memories: a dict of dicts.

    Returns:
        An invariant dictionary for the model.
    """
    if type(model_memories) != dict:
        raise ValueError('Memory type not supported, must be a dictionary.')
    invariants = dict()
    for layer_id in model_memories.keys():
        layer_mems = model_memories[layer_id]
        invariants[layer_id] = {
            key: tf.TensorShape([None]*len(get_shape_list(layer_mems[key])))
            for key in model_memories[layer_id].keys()
        }
    return invariants
Code example #24
    def get_memory_invariants(self, memories):
        """Generate shape invariants for memories.

        Args:
            memories: dictionary (see top-level class description)

        Returns:
            Dictionary of shape invariants with same structure as memories.
        """
        with tf.name_scope(self._scope):
            d = self._model.decoder

            high_depth = 0 if d.high_gru_stack is None \
                           else len(d.high_gru_stack.grus)

            num_dims = len(get_shape_list(memories['base_states']))
            # TODO Specify shape in full?
            partial_shape = tf.TensorShape([None] * num_dims)

            invariants = {}
            invariants['base_states'] = partial_shape
            invariants['high_states'] = [partial_shape] * high_depth
            return invariants
Code example #25
def decode_greedy(models, do_sample=False, beam_size=0,
                  normalization_alpha=None):
    """Decodes a source sequence using beam search or sampling.

    Args:
        models: a list of Transformer objects.
        do_sample: randomly sample instead of argmax for greedy search
        beam_size: integer specifying the beam width.
        normalization_alpha: length normalization hyperparameter.

    Returns:
        A tuple (ids, scores), where ids is a Tensor with shape (batch_size, k,
        max_seq_len) containing k translations for each input sentence in
        model.inputs.x and scores is a Tensor with shape (batch_size, k)
    """

    # Get some parameter values. For ensembling, some settings are required to
    # be consistent across all models but others are not.  In the former case,
    # we assume that consistency has already been checked.  For the parameters
    # that are allowed to vary across models, the first model's settings take
    # precedence.
    batch_size, _ = get_shape_list(models[0].source_ids)
    model_name = models[0].name
    decoder_name = models[0].dec.name
    from_rnn = models[0].dec.from_rnn
    config = models[0].dec.config
    float_dtype = models[0].dec.float_dtype
    int_dtype = models[0].dec.int_dtype
    vocab_size = models[0].dec.embedding_layer.get_vocab_size()

    # Generate a positional signal for the longest possible output.
    with tf.name_scope('{:s}_decode'.format(model_name)):
        with tf.variable_scope(decoder_name):
            positional_signal = get_positional_signal(
                config.translation_maxlen,
                config.embedding_size,
                float_dtype)

    # Generate a decoding function for each model.
    decoding_functions = []
    for model in models:
        assert model.name == model_name

        # Encode source sequences.
        with tf.name_scope('{:s}_encode'.format(model.name)):
            enc_output, cross_attn_mask = model.enc.encode(model.source_ids,
                                                           model.source_mask)

        # Generate a model-specific decoding function.
        with tf.name_scope('{:s}_decode'.format(model.name)):
            func = generate_decoding_function(enc_output, cross_attn_mask,
                                              model.dec, positional_signal)
            decoding_functions.append(func)

    # Decode into target sequences
    with tf.name_scope('{:s}_decode'.format(model_name)):
        with tf.variable_scope(decoder_name):

            if beam_size > 0:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size], 1), dtype=int_dtype)
                initial_memories = [
                    model.dec._get_initial_memories(batch_size,
                                                    beam_size=beam_size)
                    for model in models]
                output_sequences, scores = _beam_search(
                    decoding_functions,
                    initial_ids,
                    initial_memories,
                    int_dtype,
                    float_dtype,
                    config.translation_maxlen,
                    batch_size,
                    beam_size,
                    vocab_size,
                    0,
                    normalization_alpha)

            else:
                # Initialize target IDs with <GO>
                initial_ids = tf.cast(tf.fill([batch_size, 1], 1),
                                      dtype=int_dtype)
                initial_memories = [
                    model.dec._get_initial_memories(batch_size, beam_size=1)
                    for model in models]
                output_sequences, scores = greedy_search(
                    models[0],
                    decoding_functions[0],
                    initial_ids,
                    initial_memories[0],
                    int_dtype,
                    float_dtype,
                    config.translation_maxlen,
                    batch_size,
                    0,
                    do_sample,
                    time_major=False)

    return output_sequences, scores