Example #1
    def __init__(self, embedding=None, vocab_size=None, hparams=None):
        ModuleBase.__init__(self, hparams)
        self._vocab_size = vocab_size
        self._embedding = None
        self.sampling_method = self._hparams.sampling_method
        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer( \
                    layers.get_initializer(self._hparams.initializer))
            if self._hparams.position_embedder.name == 'sinusoids':
                self.position_embedder = \
                    position_embedders.SinusoidsSegmentalPositionEmbedder( \
                    self._hparams.position_embedder.hparams)

        if self._hparams.use_embedding:
            if embedding is None and vocab_size is None:
                raise ValueError("""If 'embedding' is not provided,
                    'vocab_size' must be specified.""")
            if isinstance(embedding, (tf.Tensor, tf.Variable)):
                self._embedding = embedding
            else:
                self._embedding = embedder_utils.get_embedding(
                    self._hparams.embedding,
                    embedding,
                    vocab_size,
                    variable_scope=self.variable_scope)
                self._embed_dim = shape_list(self._embedding)[-1]
                if self._hparams.zero_pad:
                    self._embedding = tf.concat( \
                        (tf.zeros(shape=[1, self._embed_dim]),\
                        self._embedding[1:, :]), 0)
            if self._vocab_size is None:
                self._vocab_size = self._embedding.get_shape().as_list()[0]
        self.output_layer = \
            self.build_output_layer(shape_list(self._embedding)[-1])
Example #2
    def dynamic_decode(self, template_input_pack,
                       encoder_decoder_attention_bias, segment_ids, offsets,
                       bos_id, eos_id):
        """
            This function is called in test (inference) mode, without the target input.
        """
        with tf.variable_scope(self.variable_scope, reuse=True):
            template = template_input_pack['templates']
            template_word_embeds = tf.nn.embedding_lookup(
                self._embedding, template)
            batch_size = tf.shape(template)[0]
            template_length = shape_list(template)[1]
            channels = shape_list(template_word_embeds)[2]
            template_pos_embeds = self.position_embedder(
                template_length, channels, template_input_pack['segment_ids'],
                template_input_pack['offsets'])
            template_inputs = template_word_embeds + template_pos_embeds

            # batch_size = tf.shape(template_inputs)[0]
            beam_width = self._hparams.beam_width
            maximum_decode_length = self.hparams.maximum_decode_length
            start_tokens = tf.cast(tf.fill([batch_size], bos_id),
                                   dtype=tf.int32)
            if beam_width <= 1:
                sampled_ids, log_probs = self.greedy_decode(
                    self.prepare_tokens_to_embeds,
                    start_tokens,
                    eos_id, #self._hparams.eos_idx,
                    decode_length=maximum_decode_length,
                    memory=template_inputs,
                    encoder_decoder_attention_bias=\
                        encoder_decoder_attention_bias,
                    segment_ids=segment_ids,
                    offsets=offsets,
                )
            else:
                sampled_ids, log_probs = self.beam_decode(
                    self.prepare_tokens_to_embeds,
                    start_tokens,
                    eos_id, #self._hparams.eos_idx,
                    beam_width=beam_width,
                    decode_length=maximum_decode_length,
                    memory=template_inputs,
                    encoder_decoder_attention_bias=\
                        encoder_decoder_attention_bias,
                    segment_ids=segment_ids,
                    offsets=offsets
                )
            predictions = {'sampled_ids': sampled_ids, 'log_probs': log_probs}
        return predictions
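
A minimal usage sketch of the call above. Everything here is illustrative: the `decoder` instance, the placeholder shapes, and the token ids are assumptions rather than part of the original module.

import tensorflow as tf

# `decoder` is an already-constructed instance of the module shown above.
# The input pack is a plain dict of aligned int Tensors.
template_input_pack = {
    'templates': tf.placeholder(tf.int64, [None, None]),    # template token ids
    'segment_ids': tf.placeholder(tf.int64, [None, None]),  # segment of each position
    'offsets': tf.placeholder(tf.int64, [None, None]),      # offset within each segment
}
# Attention bias over the template; the [batch, 1, 1, template_len] shape is assumed.
enc_dec_attention_bias = tf.placeholder(tf.float32, [None, 1, 1, None])

predictions = decoder.dynamic_decode(
    template_input_pack,
    enc_dec_attention_bias,
    template_input_pack['segment_ids'],
    template_input_pack['offsets'],
    bos_id=1, eos_id=2)                  # hypothetical special-token ids
# predictions['sampled_ids'] and predictions['log_probs'] hold the decoded result.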
Example #3
    def _build(self, decoder_input_pack, template_input_pack,
               encoder_decoder_attention_bias, args):
        """
            This function is generally called during training.
            Args:
                targets: [batch_size, target_length], generally begins with a [bos] token
                template_input: [batch_size, source_length, channels]
                segment_ids: [batch_size, source_length], the segment each word belongs to
            outputs:
                logits: [batch_size, target_length, vocab_size]
                preds: [batch_size, target_length]
        """
        input = decoder_input_pack['text_ids'][:, :-1]
        decoder_self_attention_bias = (
            attentions.attention_bias_lower_triangle(shape_list(input)[1]))
        input_word_embeds = tf.nn.embedding_lookup(self._embedding, input)
        if self._hparams.multiply_embedding_mode == 'sqrt_depth':
            input_word_embeds = input_word_embeds * \
                (self._embedding.shape.as_list()[-1]**0.5)
        length = shape_list(input_word_embeds)[1]
        channels = shape_list(input_word_embeds)[2]
        input_pos_embeds = self.position_embedder(
            length, channels, decoder_input_pack['segment_ids'][:, :-1],
            decoder_input_pack['offsets'][:, :-1])
        inputs = input_word_embeds + input_pos_embeds

        template = template_input_pack['templates']
        template_word_embeds = tf.nn.embedding_lookup(self._embedding,
                                                      template)
        template_length = shape_list(template)[1]
        template_pos_embeds = self.position_embedder(
            template_length, channels, template_input_pack['segment_ids'],
            template_input_pack['offsets'])
        template_inputs = template_word_embeds + template_pos_embeds
        self.decoder_output = self._self_attention_stack(
            inputs,
            template_inputs,
            decoder_self_attention_bias=decoder_self_attention_bias,
        )

        logits = self.output_layer(self.decoder_output)
        preds = tf.to_int32(tf.argmax(logits, axis=-1))

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return logits, preds
Example #4
 def step(self, time, inputs, state, name=None):
     cell_outputs, cell_state = self._cell(inputs, state)
     logits = self._output_layer(
         cell_outputs)  # turn cell outputs into logits over the vocabulary
     sample_ids = self._helper.sample(  # turn logits into ids
         time=time,
         outputs=logits,
         state=cell_state)
     (finished, next_inputs_word_embeds,
      next_state) = self._helper.next_inputs(
          time=time,
          outputs=logits,
          state=cell_state,
          sample_ids=sample_ids)  # look up in embedding -> next_inputs
     batch_size, channels = shape_list(next_inputs_word_embeds)
     next_input_pos_embeds = self.position_embedder(
         length=1,
         channels=channels,
         segment_ids=tf.cast(tf.fill([batch_size, 1],
                                     self.current_segment_id),
                             dtype=tf.int64),
         offsets=tf.cast(tf.fill([batch_size, 1], time), dtype=tf.int64))
     next_input_pos_embeds = tf.reshape(next_input_pos_embeds,
                                        [batch_size, channels])
     next_inputs = next_inputs_word_embeds + next_input_pos_embeds
     outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
     return (outputs, next_state, next_inputs, finished)
Example #5
    def __init__(self, inputs, sequence_length, time_major=False, name=None):
        """Initializer.

        Args:
          inputs: A (structure of) input tensors.
          sequence_length: An int32 vector tensor.
          time_major: Python bool.  Whether the tensors in `inputs` are time major.
            If `False` (default), they are assumed to be batch major.
          name: Name scope for any created operations.

        Raises:
          ValueError: if `sequence_length` is not a 1D tensor.
        """
        with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
            inputs = ops.convert_to_tensor(inputs, name="inputs")
            self._inputs = inputs
            if not time_major:
                inputs = nest.map_structure(_transpose_batch_time, inputs)

            self._input_tas = nest.map_structure(_unstack_ta, inputs)
            self._sequence_length = ops.convert_to_tensor(
                sequence_length, name="sequence_length")
            if self._sequence_length.get_shape().ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be a vector, but received shape: %s" %
                    self._sequence_length.get_shape())

            self._zero_inputs = nest.map_structure(
                lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

            self._batch_size = shape_list(sequence_length)[0]
    def _symbols_to_logits_fn(self, embedding_fn, max_length):
        """Returns a function that accepts the decoded tokens and related
        decoding status, and returns the logits of next token.
        """
        channels = shape_list(self._embedding)[-1]
        timing_signal = self.position_embedder(max_length, channels)

        def _impl(ids, step, cache):
            """The function is called in dynamic decoding.

            `ids` should be next_id of shape `[batch_size, decoded_lenth]`

            Returned logits is of shape `[batch_size, 1]`
            """
            ids = ids[:, -1:]
            inputs = embedding_fn(ids)
            # Multiply embedding by sqrt of its dimension
            inputs *= self._embedding.shape.as_list()[-1]**0.5
            inputs += timing_signal[:, step:step+1]
            outputs = self._self_attention_stack(
                inputs,
                memory=cache['memory'],
                cache=cache,
            )
            logits = self.output_layer(outputs)
            logits = tf.squeeze(logits, axis=[1])
            return logits, cache

        return _impl
Example #7
    def _symbols_to_logits_fn(self, embedding_fn, max_length, segment_ids,
                              offsets):
        channels = shape_list(self._embedding)[-1]
        timing_signal = self.position_embedder(max_length, channels,
                                               segment_ids, offsets)
        """ the function is normally called in dynamic decoding mode.
                the ids should be `next_id` with the shape [batch_size, 1]
            the returned logits is [batch_size, 1]
        """
        def _impl(ids, step, cache):
            ids = ids[:, -1:]
            decoder_self_attention_bias = (
                attentions.attention_bias_lower_triangle(shape_list(ids)[1]))
            inputs = embedding_fn(ids)
            if self._hparams.multiply_embedding_mode == 'sqrt_depth':
                inputs *= self._embedding.shape.as_list()[-1]**0.5
            else:
                raise NotImplementedError(
                    'Unsupported multiply_embedding_mode: {}'.format(
                        self._hparams.multiply_embedding_mode))
            inputs += timing_signal[:, step:step + 1]

            outputs = self._self_attention_stack(
                inputs,
                template_input=cache['memory'],
                cache=cache,
                decoder_self_attention_bias=decoder_self_attention_bias,
            )
            logits = self.output_layer(outputs)
            logits = tf.squeeze(logits, axis=[1])

            return logits, cache

        return _impl
Example #8
def _make_output_layer(output_layer, vocab_size, output_layer_bias,
                       variable_scope):
    """Makes a decoder output layer.
    """
    _vocab_size = vocab_size
    if is_callable(output_layer):
        _output_layer = output_layer
    elif tf.contrib.framework.is_tensor(output_layer):
        _vocab_size = shape_list(output_layer)[1]
        _output_layer = _make_output_layer_from_tensor(output_layer,
                                                       _vocab_size,
                                                       output_layer_bias,
                                                       variable_scope)
    elif output_layer is None:
        if _vocab_size is None:
            raise ValueError(
                "Either `output_layer` or `vocab_size` must be provided. "
                "Set `output_layer=tf.identity` if no output layer is "
                "wanted.")
        with tf.variable_scope(variable_scope):
            # pylint: disable=redefined-variable-type
            _output_layer = tf.layers.Dense(units=_vocab_size,
                                            use_bias=output_layer_bias)
    else:
        raise ValueError(
            "output_layer should be a callable layer, a tensor, or None. "
            "Unsupported type: {}".format(type(output_layer)))

    return _output_layer, _vocab_size
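
A hedged usage sketch of the factory above for the `output_layer=None` branch; the vocabulary size and scope are illustrative.

import tensorflow as tf

# With output_layer=None, a fresh Dense projection onto the vocabulary is built.
output_layer, vocab_size = _make_output_layer(
    output_layer=None,
    vocab_size=10000,                        # illustrative vocabulary size
    output_layer_bias=True,
    variable_scope=tf.get_variable_scope())
# output_layer is a tf.layers.Dense(units=10000) and vocab_size stays 10000.
# Passing a [dim, vocab_size] tensor instead ties the projection to that tensor,
# and a callable is used as-is.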
Example #9
 def _compute_embeddings(self, positions):
     inv_timescales = self.inv_timescales
     scaled_time = tf.reshape(tf.cast(positions, inv_timescales.dtype),
                              (-1, 1)) * tf.expand_dims(inv_timescales, 0)
     signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
     signal = tf.pad(signal, [[0, 0], [0, tf.mod(self._dim, 2)]])
     signal = tf.reshape(signal, shape_list(positions) + [self._dim])
     return signal
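
A standalone NumPy sketch of the same sinusoid construction, assuming `dim = 4` and the conventional 1e4 timescale range used to precompute `inv_timescales` elsewhere in the module.

import numpy as np

dim = 4
num_timescales = dim // 2
inv_timescales = 1.0 / np.power(
    1.0e4, np.arange(num_timescales) / max(num_timescales - 1, 1))
positions = np.arange(3, dtype=np.float32)                          # [0, 1, 2]
scaled_time = positions.reshape(-1, 1) * inv_timescales[None, :]    # [3, num_timescales]
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
print(signal.shape)  # (3, 4): one dim-sized vector per position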
Example #10
 def _outputs_to_logits(outputs):
     shape = shape_list(outputs)
     dim = shape[-1]
     outputs = tf.reshape(outputs, [-1, dim])
     logits = tf.matmul(outputs, self._embedding, transpose_b=True)
     if affine_bias is not None:
         logits += affine_bias
     logits = tf.reshape(logits, shape[:-1] + [self._vocab_size])
     return logits
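
A self-contained sketch of the weight-tied projection above (sizes and variable names are illustrative): decoder outputs are multiplied by the transposed embedding matrix, so the output layer shares its parameters with the input embedding.

import tensorflow as tf

embedding = tf.get_variable('embedding', shape=[10000, 512])  # [vocab_size, dim]
outputs = tf.placeholder(tf.float32, [None, None, 512])       # [batch, time, dim]
flat = tf.reshape(outputs, [-1, 512])
logits = tf.matmul(flat, embedding, transpose_b=True)         # [batch*time, vocab_size]
logits = tf.reshape(
    logits, tf.concat([tf.shape(outputs)[:-1], [10000]], axis=0))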
Example #11
    def __init__(self, embedding, hparams=None):
        ModuleBase.__init__(self, hparams)

        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            if self._hparams.position_embedder_type == 'sinusoids':
                self.position_embedder = SinusoidsPositionEmbedder(
                    self._hparams.position_embedder_hparams)
            else:
                self.position_embedder = PositionEmbedder(
                    position_size=self._hparams.position_size,
                    hparams=self._hparams.position_embedder_hparams)

            self._embedding = embedding
            self._vocab_size = self._embedding.get_shape().as_list()[0]

            self.output_layer = \
                self._build_output_layer(shape_list(self._embedding)[-1])

            self.multihead_attentions = {'self_att': [], 'encdec_att': []}
            self.poswise_networks = []
            for i in range(self._hparams.num_blocks):
                layer_name = 'layer_{}'.format(i)
                with tf.variable_scope(layer_name):
                    with tf.variable_scope("self_attention"):
                        multihead_attention = MultiheadAttentionEncoder(
                            self._hparams.multihead_attention)
                        self.multihead_attentions['self_att'].append(
                            multihead_attention)
                    # pylint: disable=protected-access
                    if self._hparams.dim != \
                        multihead_attention._hparams.output_dim:
                        raise ValueError('The output dimension of '
                                         'MultiheadEncoder should be equal '
                                         'to the dim of TransformerDecoder')

                    with tf.variable_scope('encdec_attention'):
                        multihead_attention = MultiheadAttentionEncoder(
                            self._hparams.multihead_attention)
                        self.multihead_attentions['encdec_att'].append(
                            multihead_attention)
                    if self._hparams.dim != \
                        multihead_attention._hparams.output_dim:
                        raise ValueError('The output dimension of '
                                         'MultiheadEncoder should be equal '
                                         'to the dim of TransformerDecoder')

                    poswise_network = FeedForwardNetwork(
                        hparams=self._hparams['poswise_feedforward'])
                    if self._hparams.dim != \
                        poswise_network._hparams.layers[-1]['kwargs']['units']:
                        raise ValueError('The output dimension of '
                                         'FeedForwardNetwork should be equal '
                                         'to the dim of TransformerDecoder')
                    self.poswise_networks.append(poswise_network)
Example #12
 def _expand_to_beam_width(self, tensor, beam_width):
     """
     :param tensor: [batch_size, max_len]
     :param beam_width:
     :return: [batch_size*beam_width, max_len]
     """
     batch_size = shape_list(tensor)[0]
     expanded = tf.tile(tf.expand_dims(tensor, axis=1), [1, beam_width, 1])
     return tf.reshape(expanded, [batch_size * beam_width, -1])
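
A hedged sketch of what the tile/reshape above does for a concrete input with `beam_width=2`: each row is repeated `beam_width` times, keeping copies of the same batch entry adjacent.

import tensorflow as tf

tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6]])                              # [batch_size=2, max_len=3]
expanded = tf.tile(tf.expand_dims(tensor, axis=1), [1, 2, 1])  # [2, 2, 3]
result = tf.reshape(expanded, [2 * 2, -1])
# result == [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]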
Example #13
 def _split_heads(self, x):
     """Split channels (dimension 2) into multiple heads,
         becomes dimension 1).
     Must ensure `x.shape[-1]` can be deviced by num_heads
     """
     depth = shape_list(x)[-1]
     splitted_x = tf.reshape(x, [tf.shape(x)[0], tf.shape(x)[1], \
         self._hparams.num_heads, depth // self._hparams.num_heads])
     return tf.transpose(splitted_x, [0, 2, 1, 3])
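
A hedged shape walk-through of the head split above, with illustrative sizes (`batch=16`, `seq_len=20`, `dim=512`, `num_heads=8`).

import tensorflow as tf

x = tf.zeros([16, 20, 512])
num_heads, depth = 8, 512
split = tf.reshape(x, [tf.shape(x)[0], tf.shape(x)[1],
                       num_heads, depth // num_heads])  # [16, 20, 8, 64]
split = tf.transpose(split, [0, 2, 1, 3])               # [16, 8, 20, 64]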
Example #14
 def _outputs_to_logits(outputs):
     shape = shape_list(outputs)
     dim = shape[-1]
     outputs = tf.reshape(outputs, [-1, dim])
     logits = tf.matmul(outputs, output_layer_tensor)
     if affine_bias is not None:
         logits += affine_bias
     logits = tf.reshape(logits, shape[:-1] + [vocab_size])
     return logits
Example #15
 def _combine_heads(self, x):
     """
     Args:
         x: A Tensor of shape `[batch, num_heads, seq_len, dim]`
     Returns:
         A Tensor of shape `[batch, seq_len, num_heads * dim]`
     """
     t = tf.transpose(x, [0, 2, 1, 3]) #[batch, seq_len, num_heads, dim]
     num_heads, dim = shape_list(t)[-2:]
     assert num_heads == self._hparams.num_heads
     return tf.reshape(t, [tf.shape(t)[0], tf.shape(t)[1], num_heads*dim])
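
A hedged round-trip sketch: `_combine_heads` undoes `_split_heads`, recovering the `[batch, seq_len, dim]` layout (sizes illustrative).

import tensorflow as tf

x = tf.zeros([16, 8, 20, 64])        # [batch, num_heads, seq_len, dim_per_head]
t = tf.transpose(x, [0, 2, 1, 3])    # [16, 20, 8, 64]
combined = tf.reshape(t, [tf.shape(t)[0], tf.shape(t)[1], 8 * 64])  # [16, 20, 512]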
Example #16
    def _build(self, decoder_input, encoder_output, \
        encoder_decoder_attention_bias, mode=None):
        """
            This function is generally called during training.
            Args:
                targets: [batch_size, target_length], generally begins with a [bos] token
                encoder_output: [batch_size, source_length, channels]
            outputs:
                logits: [batch_size, target_length, vocab_size]
                preds: [batch_size, target_length]
        """
        logits = None
        decoder_self_attention_bias = (
            attentions.attention_bias_lower_triangle(
                shape_list(decoder_input)[1]))
        target_inputs = tf.nn.embedding_lookup(self._embedding, decoder_input)
        if self._hparams.multiply_embedding_mode == 'sqrt_depth':
            target_inputs = target_inputs * \
                (self._embedding.shape.as_list()[-1]**0.5)
        lengths = shape_list(target_inputs)[1]
        channels = shape_list(target_inputs)[2]
        pos_embeds = self.position_embedder(lengths, channels)
        inputs = target_inputs + pos_embeds
        self.decoder_output = self._self_attention_stack(
            inputs,
            encoder_output,
            decoder_self_attention_bias=decoder_self_attention_bias,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            cache=None,
            mode=None,
        )

        logits = self.output_layer(self.decoder_output)
        preds = tf.to_int32(tf.argmax(logits, axis=-1))

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return logits, preds
def _merge_beam_dim(tensor):
    """Reshapes first two dimensions in to single dimension.

    Args:
        tensor: Tensor to reshape of shape [A, B, ...]

    Returns:
        Reshaped tensor of shape [A*B, ...]
    """
    shape = shape_list(tensor)
    shape[0] *= shape[1]  # batch -> batch * beam_size
    shape.pop(1)  # Remove beam dim
    return tf.reshape(tensor, shape)
def _unmerge_beam_dim(tensor, batch_size, beam_size):
    """Reshapes first dimension back to [batch_size, beam_size].

    Args:
        tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
        batch_size: Tensor, original batch size.
        beam_size: int, original beam size.

    Returns:
        Reshaped tensor of shape [batch_size, beam_size, ...]
    """
    shape = shape_list(tensor)
    new_shape = [batch_size] + [beam_size] + shape[1:]
    return tf.reshape(tensor, new_shape)
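
A short round-trip sketch of the two helpers above (assuming the `shape_list` helper they rely on is in scope): merging and then unmerging the beam dimension restores the original layout.

import tensorflow as tf

t = tf.zeros([2, 3, 7])                                          # [batch, beam, ...]
merged = _merge_beam_dim(t)                                      # [2*3, 7]
restored = _unmerge_beam_dim(merged, batch_size=2, beam_size=3)  # [2, 3, 7]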
Example #19
    def __init__(self, embedding, hparams=None):
        ModuleBase.__init__(self, hparams)

        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer( \
                    layers.get_initializer(self._hparams.initializer))

            self.position_embedder = \
                SinusoidsPositionEmbedder(
                    self._hparams.position_embedder_hparams)

            self._embedding = embedding
            self._vocab_size = self._embedding.get_shape().as_list()[0]

        self.output_layer = \
            self._build_output_layer(shape_list(self._embedding)[-1])
Example #20
    def __init__(self, embedding, start_tokens, end_token):
        """Initializer.

        Args:
          embedding: A callable or the `params` argument for `embedding_lookup`.
            If a callable, it can take a vector tensor of `ids` (argmax ids),
            or take two arguments (`ids`, `times`), where `ids` is a vector
            tensor of argmax ids, and `times` is a vector tensor of current
            time steps (i.e., position ids). The latter case can be used when
            attr:`embedding` is a combination of word embedding and position
            embedding.
            The returned tensor will be returned by :meth:`next_inputs`.
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks end of decoding.

        Raises:
          ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
            scalar.
        """
        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))

        self._start_tokens = ops.convert_to_tensor(
            start_tokens, dtype=dtypes.int32, name="start_tokens")
        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        if self._start_tokens.get_shape().ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._batch_size = shape_list(start_tokens)[0]
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._embedding_args_cnt = len(get_args(self._embedding_fn))
        if self._embedding_args_cnt == 1:
            self._start_inputs = self._embedding_fn(self._start_tokens)
        elif self._embedding_args_cnt == 2:
            # Position index is 0 in the beginning
            times = tf.zeros([self._batch_size], dtype=tf.int32)
            self._start_inputs = self._embedding_fn(self._start_tokens, times)
        else:
            raise ValueError('`embedding` should expect 1 or 2 arguments.')
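
A hedged sketch of the two-argument `embedding` case described above, assuming the `__init__` belongs to a GreedyEmbeddingHelper-style class; the embedding tables, ids, and sizes are illustrative.

import tensorflow as tf

word_embedding = tf.get_variable('word_emb', [10000, 512])   # illustrative table
position_embedding = tf.get_variable('pos_emb', [256, 512])  # illustrative table

def embedding_fn(ids, times):
    # Combine word and position embeddings for the next decoder inputs.
    return (tf.nn.embedding_lookup(word_embedding, ids) +
            tf.nn.embedding_lookup(position_embedding, times))

start_tokens = tf.fill([32], 1)   # hypothetical <BOS> id for a batch of 32
end_token = 2                     # hypothetical <EOS> id
helper = GreedyEmbeddingHelper(embedding_fn, start_tokens, end_token)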
Example #21
    def _input_ids_to_outputs(self, input_ids, step, cache):
        """The function is called in beam-search decoding.

        `input_ids` should be of shape `[batch_size]`.

        Returns outputs (i.e. logits) of shape `[batch_size, vocab_size]`
        and updated cache.
        """
        _batch_size = shape_list(input_ids)[0]
        times = tf.ones([_batch_size], dtype=tf.int32) * step
        inputs = self.embedding(input_ids, times)

        outputs = self._self_attention_stack(
            tf.expand_dims(inputs, axis=1),
            memory=cache.get('memory'),
            cache=cache,
        )
        outputs = self._output_layer(outputs)
        outputs = tf.squeeze(outputs, axis=[1])
        return outputs, cache
Example #22
        def _impl(ids, step, cache):
            ids = ids[:, -1:]
            decoder_self_attention_bias = (
                attentions.attention_bias_lower_triangle(shape_list(ids)[1]))
            inputs = embedding_fn(ids)
            if self._hparams.multiply_embedding_mode == 'sqrt_depth':
                inputs *= self._embedding.shape.as_list()[-1]**0.5
            else:
                raise NotImplementedError(
                    'Unsupported multiply_embedding_mode: {}'.format(
                        self._hparams.multiply_embedding_mode))
            inputs += timing_signal[:, step:step + 1]

            outputs = self._self_attention_stack(
                inputs,
                template_input=cache['memory'],
                cache=cache,
                decoder_self_attention_bias=decoder_self_attention_bias,
            )
            logits = self.output_layer(outputs)
            logits = tf.squeeze(logits, axis=[1])

            return logits, cache
Example #23
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        inputs = inputs * self._hparams.dim**0.5

        inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding

        pos_embeds = self.position_embedder(lengths, self._hparams.dim)
        input_embedding = inputs + pos_embeds

        x = tf.layers.dropout(input_embedding,
                              rate=self._hparams.embedding_dropout,
                              training=is_train_mode(mode))
        pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                with tf.variable_scope('self_attention'):
                    selfatt_output = attn.multihead_attention(
                        queries=layers.layer_normalize(x),
                        memory=None,
                        memory_attention_bias=encoder_self_attention_bias,
                        num_heads=self._hparams.num_heads,
                        dropout_rate=self._hparams.attention_dropout,
                        num_units=self._hparams.dim,
                        scope='multihead_attention')
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )

                poswise_network = FeedForwardNetwork(
                    hparams=self._hparams['poswise_feedforward'])
                with tf.variable_scope(poswise_network.variable_scope):
                    y = layers.layer_normalize(x)
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    y = tf.expand_dims(pad_remover.remove(y), axis=0)
                    # [1, batch_size*seq_length, hidden_dim]
                    sub_output = tf.layers.dropout(
                        poswise_network(y),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    sub_output = tf.reshape(pad_remover.restore(tf.squeeze(\
                        sub_output, axis=0)), original_shape \
                    )
                    x = x + sub_output

        encoder_output = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return encoder_output
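
A brief, hedged usage sketch of the encoder above; the `encoder` instance and the dimension 512 are illustrative. The inputs are word embeddings whose last dimension must equal the "dim" hyperparameter, plus the sequence lengths.

import tensorflow as tf

# `encoder` is an already-constructed instance of the module above,
# with its "dim" hparam set to 512 (illustrative).
word_embeds = tf.placeholder(tf.float32, [None, None, 512])  # [batch, max_time, dim]
seq_lengths = tf.placeholder(tf.int32, [None])               # [batch]
encoder_output = encoder(word_embeds, seq_lengths)           # [batch, max_time, dim]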
Example #24
    def _build(self,    # pylint: disable=arguments-differ
               memory,
               memory_sequence_length=None,
               memory_attention_bias=None,
               inputs=None,
               sequence_length=None,
               decoding_strategy='train_greedy',
               beam_width=1,
               alpha=0,
               start_tokens=None,
               end_token=None,
               max_decoding_length=None,
               mode=None):
        """Performs decoding.

        The decoder supports 4 decoding strategies. For the first 3 strategies,
        set :attr:`decoding_strategy` to the respective string.

        - **"train_greedy"**: decoding in teacher-forcing fashion \
          (i.e., feeding \
          ground truth to decode the next step), and for each step sample \
          is obtained by taking the `argmax` of logits. \
          Argument :attr:`inputs` is required for this strategy. \
          :attr:`sequence_length` is optional.
        - **"infer_greedy"**: decoding in inference fashion (i.e., feeding \
          `generated` sample to decode the next step), and for each
          step sample is obtained by taking the `argmax` of logits.\
          Arguments :attr:`(start_tokens, end_token)` are \
          required for this strategy, and argument \
          :attr:`max_decoding_length` is optional.
        - **"infer_sample"**: decoding in inference fashion, and for each step\
          sample is obtained by `random sampling` from the logits.
          Arguments :attr:`(start_tokens, end_token)` are \
          required for this strategy, and argument \
          :attr:`max_decoding_length` is optional.
        - **Beam Search**: set :attr:`beam_width` to > 1 to use beam search \
          decoding.\
          Arguments :attr:`(start_tokens, end_token)` are \
          required, and argument \
          :attr:`max_decoding_length` is optional.

        Args:
            memory: The memory to attend, e.g., the output of an RNN encoder.
                A Tensor of shape `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create an attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                `memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            sequence_length (optional): A Tensor of shape `[batch_size]`,
                containing the sequence length of :attr:`inputs`.
                Tokens beyond the respective sequence length are masked out.
                Used when :attr:`decoding_strategy` is set to
                "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` > 1.
            beam_width (int): Set to > 1 to use beam search.
            alpha (float): Length penalty coefficient.
                Refer to https://arxiv.org/abs/1609.08144
                for more details.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when `decoding_strategy` = "infer_greedy" or
                "infer_sample", or `beam_width` > 1.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when `decoding_strategy` = "infer_greedy" or
                "infer_sample", or `beam_width` > 1.
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam_search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the most probable \
                sample.
                - **"log_porb"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """
        if memory_attention_bias is None:
            if memory_sequence_length is None:
                raise ValueError(
                    "`memory_sequence_length` is required if "
                    "`memory_attention_bias` is not given.")

            #enc_padding = 1 - mask_sequences(tf.ones_like(memory),
            #                                 memory_sequence_length,
            #                                 tensor_rank=3)[:, :, 0]
            enc_padding = 1 - tf.sequence_mask(
                memory_sequence_length, tf.shape(memory)[1], dtype=tf.float32)
            memory_attention_bias = attn.attention_bias_ignore_padding(
                enc_padding)

        if beam_width <= 1 and decoding_strategy == 'train_greedy':
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

            decoder_self_attention_bias = (
                attn.attention_bias_lower_triangle(
                    shape_list(inputs)[1]))
            target_inputs = inputs * self._hparams.dim**0.5

            _, lengths, channels = shape_list(target_inputs)
            pos_embeds = self.position_embedder(lengths, channels)

            inputs = target_inputs + pos_embeds

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self.output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            output = TransformerDecoderOutput(
                logits=logits,
                sample_id=preds
            )
            rets = output

        else: # Inference decoding

            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length

            if beam_width <= 1:
                logits, preds, sequence_length = self._infer_decoding(
                    self._prepare_tokens_to_embeds,
                    start_tokens,
                    end_token,
                    decode_length=max_decoding_length,
                    memory=memory,
                    memory_attention_bias=memory_attention_bias,
                    decoding_strategy=decoding_strategy,
                )
                output = TransformerDecoderOutput(
                    logits=logits,
                    sample_id=preds)
                rets = output, sequence_length
            else:
                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    self._prepare_tokens_to_embeds,
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    alpha=alpha,
                    decode_length=max_decoding_length,
                    memory=memory,
                    memory_attention_bias=memory_attention_bias,
                )
                predictions = {
                    'sample_id':sample_id,
                    'log_prob': log_prob
                }
                rets = predictions

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
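
A hedged sketch of a beam-search call to the decoder above. All tensors, sizes, and ids are illustrative placeholders, and the `decoder` instance is assumed to be constructed already.

import tensorflow as tf

batch_size = 32
bos_id, eos_id = 1, 2                                            # hypothetical ids
encoder_output = tf.placeholder(tf.float32, [None, None, 512])   # memory to attend
src_lengths = tf.placeholder(tf.int32, [None])

predictions = decoder(
    memory=encoder_output,
    memory_sequence_length=src_lengths,
    beam_width=5,
    alpha=0.6,
    start_tokens=tf.fill([batch_size], bos_id),
    end_token=eos_id,
    max_decoding_length=100)
# predictions['sample_id']: [batch_size, max_time, beam_width]
# predictions['log_prob']:  [batch_size, beam_width]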
Example #25
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        if not self._hparams.use_bert_config:
            inputs = inputs * self._hparams.dim**0.5
            inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)
        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding)

        encoder_self_attention_bias = ignore_padding

        positions = tf.expand_dims(tf.range(lengths, dtype=tf.int32), 0)
        pos_embeds = self.position_embedder(positions)

        input_embedding = inputs + pos_embeds

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # Just to keep consistent with BERT, actually makes no difference
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                multihead_attention = self.multihead_attention_list[i]
                # trivial difference between BERT and original Transformer
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = multihead_attention(
                    queries=_queries_input,
                    memory=_queries_input,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)
                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode)
                    )
                    if pad_remover:
                        sub_output = tf.reshape(pad_remover.restore(tf.squeeze(\
                            sub_output, axis=0)), original_shape \
                        )
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return x
    def _build(
            self,  # pylint: disable=arguments-differ, too-many-statements
            decoding_strategy='train_greedy',
            inputs=None,
            adjs=None,
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            embedding=None,
            helper=None,
            mode=None):
        """Performs decoding.

        See `texar.modules.decoders.transformer_decoders.TransformerDecoder`
        for details.

        Args:
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of input sequences.
        """
        # Get adjacency masks from adjs
        self.adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   shape_list(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        self.embedding = embedding

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy':  # Teacher-forcing

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self._output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length
            self.max_decoding_length = max_decoding_length
            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is None:
                    if decoding_strategy == "infer_greedy":
                        helper = tx_helper.GreedyEmbeddingHelper(
                            embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tx_helper.SampleEmbeddingHelper(
                            embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    self.context = tf.pad(self.context, [[
                        0, 0
                    ], [0, max_decoding_length - shape_list(self.context)[1]]])

                outputs, _, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that
                    # of logits by 1, because there is an additional
                    # start_token in the returned sample_id.
                    # The start_id should be the first token of the
                    # given context.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat([
                            tf.expand_dims(start_tokens, 1), outputs.sample_id
                        ],
                                            axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; Assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = shape_list(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
    def _build(self,
               inputs,
               memory,
               sequence_length,
               memory_sequence_length,
               adjs,
               encoder_output,
               mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the embedding of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an aggregation
                of word embedding and position embedding.
            memory: A 3D Tensor of shape `[batch_size, memory_max_time, dim]`, 
                containing the embedding of memory sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an aggregation
                of word embedding and position embedding.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            memory_sequence_length: A 1D Tensor of shape `[batch_size]`. Memory tokens
                beyond respective sequence lengths are masked out
                automatically.
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of input sequences
            encoder_output: bool. If `True`, returns encoder-like embeddings;
                if `False`, returns a `CrossGraphTransformerFixedLengthDecoderOutput`.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors if :attr:`encoder_output` is `True`; otherwise a
            `CrossGraphTransformerFixedLengthDecoderOutput`.
        """
        # Get adjacency masks from adjs
        adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)

        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding

        input_embedding = inputs  # shape (batch_size, max_time, dim)

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # Just to keep consistent with BERT, actually makes no difference
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                graph_multihead_attention = self.graph_multihead_attention_list[
                    i]

                # trivial difference between BERT and original Transformer
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = graph_multihead_attention(
                    queries=_queries_input,
                    memory=memory,
                    adj_masks=adj_masks,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                # attention_output: weighted sum of V of memory with weights determined by querying keys of memory
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)

                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    if pad_remover:
                        sub_output = tf.reshape(pad_remover.restore(tf.squeeze(\
                            sub_output, axis=0)), original_shape \
                        )
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        if encoder_output:
            return x

        logits = self._output_layer(x)
        sample_ids = tf.to_int32(tf.argmax(logits, axis=-1))
        probs = ''
        # probs = GumbelSoftmax(self._tau, logits=logits).sample()
        # probs = tf.nn.softmax(logits / self._tau) # vanilla softmax

        rets = CrossGraphTransformerFixedLengthDecoderOutput(
            logits=logits, sample_id=sample_ids, probs=probs)

        return rets
Example #28
0
def _main(_):
    hparams = gan_hyperparams.load_hyperparams()
    train_dataset_hparams, valid_dataset_hparams, test_dataset_hparams, encoder_hparams, \
    decoder_hparams, classifier_hparams, opt_hparams, loss_hparams, d_opt_hparams, args = \
        hparams['train_dataset_hparams'], hparams['eval_dataset_hparams'], \
        hparams['test_dataset_hparams'], hparams['encoder_hparams'], hparams['decoder_hparams'], \
        hparams['classifier_hparams'], hparams['opt_hparams'], \
        hparams['loss_hparams'], hparams['d_opt'], hparams['args']

    # Data
    train_data = tx.data.MonoTextData(train_dataset_hparams)
    valid_data = tx.data.MonoTextData(valid_dataset_hparams)
    test_data = tx.data.MonoTextData(test_dataset_hparams)
    iterator = tx.data.FeedableDataIterator(
        {'train_g': train_data, 'train_d': train_data,
         'val': valid_data, 'test': test_data})

    data_batch = iterator.get_next()
    mask_id = train_data.vocab.token_to_id_map_py['<m>']
    boa_id = train_data.vocab.token_to_id_map_py['<BOA>']
    eoa_id = train_data.vocab.token_to_id_map_py['<EOA>']
    eos_id = train_data.vocab.token_to_id_map_py[SpecialTokens.EOS]
    pad_id = train_data.vocab.token_to_id_map_py['<PAD>']
    template_pack, answer_packs = \
        tx.utils.prepare_template(data_batch, args, mask_id, boa_id, eoa_id, pad_id)

    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
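    # (gamma is the Gumbel-softmax temperature, annealed during training;
    #  lambda_g weights the generator's adversarial classification loss.)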

    # Model architecture
    embedder = tx.modules.WordEmbedder(vocab_size=train_data.vocab.size,
                                       hparams=args.word_embedding_hparams)
    position_embedder = position_embedders.SinusoidsSegmentalPositionEmbedder()
    encoder = tx.modules.UnidirectionalRNNEncoder(hparams=encoder_hparams)
    decoder = tx.modules.BasicPositionalRNNDecoder(vocab_size=train_data.vocab.size,
                                                   hparams=decoder_hparams,
                                                   position_embedder=position_embedder)
    decoder_initial_state_size = decoder.cell.state_size
    connector = tx.modules.connectors.ForwardConnector(decoder_initial_state_size)

    start_tokens = tf.ones_like(data_batch['length']) * boa_id
    gumbel_helper = tx.modules.GumbelSoftmaxEmbeddingHelper(
        embedder.embedding, start_tokens, eoa_id, gamma)

    # Creates classifier
    classifier = tx.modules.Conv1DClassifier(hparams=classifier_hparams)
    clas_embedder = tx.modules.WordEmbedder(vocab_size=train_data.vocab.size,
                                            hparams=args.word_embedding_hparams)

    cetp_loss, d_class_loss, g_class_loss = None, None, None
    cur_template_pack = template_pack
    for idx, hole in enumerate(answer_packs):
        template = cur_template_pack['templates']
        template_word_embeds = embedder(template)
        template_length = shape_list(template)[1]
        channels = shape_list(template_word_embeds)[2]
        template_pos_embeds = position_embedder(template_length, channels,
                                                cur_template_pack['segment_ids'],
                                                cur_template_pack['offsets'])
        enc_input_embedded = template_word_embeds + template_pos_embeds

        _, ecdr_states = encoder(
            enc_input_embedded,
            sequence_length=data_batch["length"])

        dcdr_init_states = connector(ecdr_states)

        dec_input = hole['text_ids'][:, :-1]
        dec_input_word_embeds = embedder(dec_input)
        decoder.set_segment_id(1)
        dec_input_embedded = dec_input_word_embeds
        outputs, _, _ = decoder(
            initial_state=dcdr_init_states,
            decoding_strategy="train_greedy",
            inputs=dec_input_embedded,
            sequence_length=hole["lengths"] + 1)
        cur_loss = tx.utils.smoothing_cross_entropy(
            outputs.logits,
            hole['text_ids'][:, 1:],
            train_data.vocab.size,
            loss_hparams['label_confidence'],
        )
        cetp_loss = cur_loss if cetp_loss is None \
            else tf.concat([cetp_loss, cur_loss], -1)

        soft_outputs_, _, soft_length_, = decoder(
            helper=gumbel_helper, initial_state=dcdr_init_states)

        # Classification loss for the classifier
        clas_logits, clas_preds = classifier(
            inputs=clas_embedder(ids=hole['text_ids'][:, 1:]),
            sequence_length=hole["lengths"]+1)
        loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(tf.ones_like(data_batch['length'])), logits=clas_logits)
        d_class_loss = loss_d_clas if d_class_loss is None \
            else tf.concat([d_class_loss, loss_d_clas], -1)

        # Classification loss for the generator, based on soft samples
        soft_logits, soft_preds = classifier(
            inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=soft_length_)
        loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(tf.zeros_like(data_batch['length'])), logits=soft_logits)
        g_class_loss = loss_g_clas if g_class_loss is None \
            else tf.concat([g_class_loss, loss_g_clas], -1)

        cur_template_pack = tx.utils.update_template_pack(cur_template_pack,
                                                          hole['text_ids'][:, 1:],
                                                          mask_id, eoa_id, pad_id)
    cetp_loss = tf.reduce_mean(cetp_loss)
    d_class_loss = tf.reduce_mean(d_class_loss)
    g_class_loss = tf.reduce_mean(g_class_loss)

    global_step = tf.Variable(0, trainable=False)
    if args.learning_rate_strategy == 'static':
        learning_rate = tf.Variable(1e-3, dtype=tf.float32)
    elif args.learning_rate_strategy == 'dynamic':
        fstep = tf.to_float(global_step)
        learning_rate = opt_hparams['lr_constant'] \
                        * args.hidden_dim ** -0.5 \
                        * tf.minimum(fstep ** -0.5, fstep * opt_hparams['warmup_steps'] ** -1.5)
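        # Inverse-square-root ("Noam") schedule: the tf.minimum term yields a
        # linear warm-up for the first `warmup_steps` updates and a
        # 1/sqrt(step) decay afterwards, scaled by hidden_dim ** -0.5.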
    else:
        raise ValueError('Unknown learning_rate_strategy: %s, expecting one of '
                         '[\'static\', \'dynamic\']' % args.learning_rate_strategy)

    g_loss = cetp_loss + lambda_g * g_class_loss
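    # (template-filling cross-entropy plus the adversarial term above)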
    g_vars = tx.utils.collect_trainable_variables(
        [embedder, encoder, connector, decoder])
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        beta1=opt_hparams['Adam_beta1'],
        beta2=opt_hparams['Adam_beta2'],
        epsilon=opt_hparams['Adam_epsilon'],
    )
    train_op = optimizer.minimize(g_loss, global_step, var_list=g_vars)

    d_loss = d_class_loss
    d_vars = tx.utils.collect_trainable_variables([clas_embedder, classifier])
    train_op_d = tx.core.get_train_op(d_loss, d_vars, hparams=d_opt_hparams)

    # Inference
    predictions = []
    cur_test_pack = template_pack
    for idx, hole in enumerate(answer_packs):
        template = cur_test_pack['templates']
        template_word_embeds = embedder(template)
        template_length = shape_list(template)[1]
        channels = shape_list(template_word_embeds)[2]
        template_pos_embeds = position_embedder(template_length, channels,
                                                cur_test_pack['segment_ids'],
                                                cur_test_pack['offsets'])
        enc_input_embedded = template_word_embeds + template_pos_embeds

        _, ecdr_states = encoder(
            enc_input_embedded,
            sequence_length=data_batch["length"])

        dcdr_init_states = connector(ecdr_states)

        decoder.set_segment_id(1)
        outputs_infer, _, _ = decoder(
            decoding_strategy="infer_positional",
            start_tokens=start_tokens,
            end_token=eoa_id,
            embedding=embedder,
            initial_state=dcdr_init_states)
        predictions.append(outputs_infer.sample_id)
        cur_test_pack = tx.utils.update_template_pack(cur_test_pack,
                                                      outputs_infer.sample_id,
                                                      mask_id, eoa_id, pad_id)

    eval_saver = tf.train.Saver(max_to_keep=5)

    def _train_epochs(session, cur_epoch, gamma_, lambda_g_):
        loss_lists, ppl_lists = [], []
        while True:
            try:
                fetches_d = {
                    'train_op_d': train_op_d,
                    'd_loss': d_loss
                }
                feed_d = {
                    iterator.handle: iterator.get_handle(session, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.TRAIN
                }
                rtns_d = session.run(fetches_d, feed_dict=feed_d)
                d_loss_ = rtns_d['d_loss']
                fetches_g = {
                    'template': template_pack,
                    'holes': answer_packs,
                    'train_op': train_op,
                    'step': global_step,
                    'lr': learning_rate,
                    'loss': cetp_loss,
                    'g_loss': g_loss
                }
                feed_g = {
                    iterator.handle: iterator.get_handle(session, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.TRAIN
                }
                rtns = session.run(fetches_g, feed_dict=feed_g)
                step, template_, holes_, cetp_loss_, g_loss_ = \
                    rtns['step'], rtns['template'], rtns['holes'], rtns['loss'], rtns['g_loss']
                ppl = np.exp(cetp_loss_)
                if step % 200 == 1:
                    rst = 'step:%s source:%s g_loss:%f d_loss:%f ppl:%f lr:%f' % \
                          (step, template_['text_ids'].shape, g_loss_, d_loss_, ppl, rtns['lr'])
                    print(rst)
                loss_lists.append(g_loss_)
                ppl_lists.append(ppl)
            except tf.errors.OutOfRangeError:
                break
        return loss_lists[::50], ppl_lists[::50]

    def _test_epoch(cur_sess, cur_epoch, gamma_, lambda_g_, mode='test'):
        def _id2word_map(id_arrays):
            return [' '.join([train_data.vocab._id_to_token_map_py[i]
                              for i in sent]) for sent in id_arrays]

        templates_list, targets_list, hypothesis_list = [], [], []
        cnt = 0
        loss_lists, ppl_lists = [], []
        while True:
            try:
                fetches = {
                    'data_batch': data_batch,
                    'predictions': predictions,
                    'template': template_pack,
                    'step': global_step,
                    'loss': cetp_loss
                }
                feed = {
                    iterator.handle: iterator.get_handle(cur_sess, mode),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }
                rtns = cur_sess.run(fetches, feed_dict=feed)
                real_templates_, templates_, targets_, predictions_ = \
                    rtns['template']['templates'], rtns['template']['text_ids'], \
                    rtns['data_batch']['text_ids'], rtns['predictions']
                loss = rtns['loss']
                ppl = np.exp(loss)
                loss_lists.append(loss)
                ppl_lists.append(ppl)

                filled_templates = \
                    tx.utils.fill_template(template_pack=rtns['template'],
                                           predictions=rtns['predictions'],
                                           eoa_id=eoa_id, pad_id=pad_id, eos_id=eos_id)

                templates, targets, generateds = _id2word_map(real_templates_.tolist()), \
                                                 _id2word_map(targets_), \
                                                 _id2word_map(filled_templates)

                for template, target, generated in zip(templates, targets, generateds):
                    template = template.split('<EOS>')[0].split('<PAD>')[0].strip().split()
                    target = target.split('<EOS>')[0].split('<PAD>')[0].strip().split()
                    got = generated.split('<EOS>')[0].split('<PAD>')[0].strip().split()
                    templates_list.append(template)
                    targets_list.append(target)
                    hypothesis_list.append(got)

                cnt += 1
                if mode != 'test' and cnt >= 60:
                    break
            except tf.errors.OutOfRangeError:
                break

        avg_loss, avg_ppl = np.mean(loss_lists), np.mean(ppl_lists)
        outputs_tmp_filename = args.log_dir + 'epoch{}.beam{}.outputs.tmp'. \
            format(cur_epoch, args.beam_width)
        template_tmp_filename = args.log_dir + 'epoch{}.beam{}.templates.tmp'. \
            format(cur_epoch, args.beam_width)
        refer_tmp_filename = os.path.join(args.log_dir, 'eval_reference.tmp')
        with codecs.open(outputs_tmp_filename, 'w+', 'utf-8') as tmpfile, \
                codecs.open(template_tmp_filename, 'w+', 'utf-8') as tmptpltfile, \
                codecs.open(refer_tmp_filename, 'w+', 'utf-8') as tmpreffile:
            for hyp, tplt, tgt in zip(hypothesis_list, templates_list, targets_list):
                tmpfile.write(' '.join(hyp) + '\n')
                tmptpltfile.write(' '.join(tplt) + '\n')
                tmpreffile.write(' '.join(tgt) + '\n')
        eval_bleu = float(100 * bleu_tool.bleu_wrapper(
            refer_tmp_filename, outputs_tmp_filename, case_sensitive=True))
        template_bleu = float(100 * bleu_tool.bleu_wrapper(
            refer_tmp_filename, template_tmp_filename, case_sensitive=True))
        print('epoch:{} {}_bleu:{} template_bleu:{} {}_loss:{} {}_ppl:{} '.
              format(cur_epoch, mode, eval_bleu, template_bleu, mode, avg_loss, mode, avg_ppl))
        os.remove(outputs_tmp_filename)
        os.remove(template_tmp_filename)
        os.remove(refer_tmp_filename)
        if args.save_eval_output:
            result_filename = \
                args.log_dir + 'epoch{}.beam{}.{}.results.bleu{:.3f}' \
                    .format(cur_epoch, args.beam_width, mode, eval_bleu)
            with codecs.open(result_filename, 'w+', 'utf-8') as resultfile:
                for tmplt, tgt, hyp in zip(templates_list, targets_list, hypothesis_list):
                    resultfile.write("- template: " + ' '.join(tmplt) + '\n')
                    resultfile.write("- expected: " + ' '.join(tgt) + '\n')
                    resultfile.write('- got:      ' + ' '.join(hyp) + '\n\n')
        return {
            'eval': eval_bleu,
            'template': template_bleu
        }, avg_ppl

    def _draw_train_loss(epoch, loss_list, mode):
        plt.figure(figsize=(14, 10))
        plt.plot(loss_list, '--', linewidth=1, label='loss trend')
        plt.ylabel('%s till epoch %s' % (mode, epoch))
        plt.xlabel('every 50 steps, present_rate=%f' % args.present_rate)
        plt.savefig(args.log_dir + '/img/%s_curve.png' % mode)
        plt.close('all')

    def _draw_bleu(epoch, test_bleu, tplt_bleu, train_bleu, train_tplt_bleu):
        plt.figure(figsize=(14, 10))
        legends = []
        plt.plot(test_bleu, '--', linewidth=1, label='test bleu')
        plt.plot(tplt_bleu, '--', linewidth=1, label='template bleu')
        legends.extend(['test bleu', 'template bleu'])
        plt.ylabel('bleu till epoch {}'.format(epoch))
        plt.xlabel('every epoch')
        plt.legend(legends, loc='upper left')
        plt.savefig(args.log_dir + '/img/bleu.png')

        plt.figure(figsize=(14, 10))
        legends = []
        plt.plot(train_bleu, '--', linewidth=1, label='train bleu')
        plt.plot(train_tplt_bleu, '--', linewidth=1, label='train template bleu')
        legends.extend(['train bleu', 'train template bleu'])
        plt.ylabel('bleu till epoch {}'.format(epoch))
        plt.xlabel('every epoch')
        plt.legend(legends, loc='upper left')
        plt.savefig(args.log_dir + '/img/train_bleu.png')
        plt.close('all')

    config_ = tf.ConfigProto(allow_soft_placement=True)
    config_.gpu_options.allow_growth = True

    with tf.Session(config=config_) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        iterator.initialize_dataset(sess)

        loss_list, ppl_list, test_ppl_list = [], [], []
        test_bleu, tplt_bleu, train_bleu, train_tplt_bleu = [], [], [], []
        gamma_, lambda_g_ = 1., 0.
        if args.running_mode == 'train_and_evaluate':
            for epoch in range(70, args.max_train_epoch):
                # Anneals the gumbel-softmax temperature
                if epoch > args.pretrain_epoch:
                    gamma_ = max(0.001, gamma_ * args.gamma_decay)
                    lambda_g_ = args.lambda_g

                # bleu on test set and train set
                if epoch % args.bleu_interval == 0 or epoch == args.max_train_epoch - 1:
                    iterator.restart_dataset(sess, 'test')
                    bleu_scores, test_ppl = _test_epoch(sess, epoch, gamma_, lambda_g_)
                    test_bleu.append(bleu_scores['eval'])
                    tplt_bleu.append(bleu_scores['template'])
                    test_ppl_list.append(test_ppl)
                    _draw_train_loss(epoch, test_ppl_list, mode='test_perplexity')

                    iterator.restart_dataset(sess, 'train_g')
                    train_bleu_scores, _ = _test_epoch(sess, epoch, gamma_, lambda_g_, mode='train_g')
                    train_bleu.append(train_bleu_scores['eval'])
                    train_tplt_bleu.append(train_bleu_scores['template'])
                    _draw_bleu(epoch, test_bleu, tplt_bleu, train_bleu, train_tplt_bleu)
                    eval_saver.save(sess, args.log_dir + 'my-model-latest.ckpt')

                # train
                iterator.restart_dataset(sess, ['train_g', 'train_d'])
                losses, ppls = _train_epochs(sess, epoch, gamma_, lambda_g_)
                loss_list.extend(losses)
                ppl_list.extend(ppls)
                _draw_train_loss(epoch, loss_list, mode='train_loss')
                _draw_train_loss(epoch, ppl_list, mode='perplexity')
                sys.stdout.flush()

                if epoch == args.pretrain_epoch:
                    eval_saver.save(sess, args.log_dir + 'pretrained-model.ckpt')
Example #29
0
    def _build(
            self,  # pylint: disable=arguments-differ
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            inputs=None,
            sequence_length=None,
            decoding_strategy='train_greedy',
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            helper=None,
            mode=None):
        """Performs decoding.

        The interface is very similar to that of RNN decoders
        (:meth:`texar.modules.RNNDecoderBase._build`). In particular,
        the function provides **3 ways** to specify the decoding method, with
        varying flexibility (a usage sketch follows the list):

        1. The :attr:`decoding_strategy` argument.

            - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
              feeding ground truth to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Argument :attr:`inputs` is required for this strategy.
              :attr:`sequence_length` is optional.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
              `generated` sample to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Arguments :attr:`(start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and for each
              step sample is obtained by `random sampling` from the logits.
              Arguments :attr:`(start_tokens, end_token)` are required for this
              strategy, and argument :attr:`max_decoding_length` is optional.

          This argument is used only when arguments :attr:`helper` and
          :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :tf_main:`tf.contrib.seq2seq.Helper <contrib/seq2seq/Helper>`.
           This provides a superset of the decoding strategies above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.modules.RNNDecoderBase._build` for
           detailed usage and examples.

           Note that although a :tf_main:`TrainingHelper
           <contrib/seq2seq/TrainingHelper>` corresponds to the
           "train_greedy" strategy above, the implementation here is *slower*
           than directly setting `decoding_strategy="train_greedy"` (though
           the output results are the same).

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.
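
        Example (a minimal usage sketch; `decoder`, `memory`,
        `memory_sequence_length`, `target_embeds`, `start_tokens`, and
        `eos_id` below are assumed to be defined by the caller):

            .. code-block:: python

                # 1. Teacher-forcing training
                outputs = decoder(
                    memory=memory,
                    memory_sequence_length=memory_sequence_length,
                    inputs=target_embeds,
                    decoding_strategy='train_greedy')

                # 2. Greedy inference
                outputs, lengths = decoder(
                    memory=memory,
                    memory_sequence_length=memory_sequence_length,
                    decoding_strategy='infer_greedy',
                    start_tokens=start_tokens,
                    end_token=eos_id,
                    max_decoding_length=50)

                # 3. Beam-search inference
                bs_results = decoder(
                    memory=memory,
                    memory_sequence_length=memory_sequence_length,
                    beam_width=4,
                    start_tokens=start_tokens,
                    end_token=eos_id,
                    max_decoding_length=50)
                # bs_results['sample_id'], bs_results['log_prob']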

        Args:
            memory (optional): The memory to attend, e.g., the output of an RNN encoder.
                A Tensor of shape `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create the attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                `memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            sequence_length (optional): A Tensor of shape `[batch_size]`,
                containing the sequence length of :attr:`inputs`.
                Tokens beyond the respective sequence length are masked out.
                Used when :attr:`decoding_strategy` is set to
                "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                Set it larger if longer sentences are preferred.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
                Ignored when context is set.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
            context (optional): An int Tensor of shape `[batch_size, length]`,
                containing the starting tokens for decoding.
                If context is set, the start_tokens will be ignored.
            context_sequence_length (optional): specify the length of context.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must be > 0. If
                `None`, 1.0 is used.
                Used when :attr:`decoding_strategy` = "infer_sample".
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                "train_greedy" decoding.
            helper (optional): An instance of
                :tf_main:`Helper <contrib/seq2seq/Helper>` that defines the
                decoding strategy. If given, :attr:`decoding_strategy` is
                ignored.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or\
            decoding with :attr:`helper`, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the highest-probable \
                sample.
                - **"log_prob"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   tf.shape(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy': # Teacher-forcing
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))
            if self._hparams.scale_embeds:
                target_inputs = inputs * self._hparams.dim**0.5
            else:
                target_inputs = inputs

            _, lengths, _ = shape_list(target_inputs)
            positions = tf.expand_dims(tf.range(lengths, dtype=tf.int32), 0)
            pos_embeds = self.position_embedder(positions)

            inputs = target_inputs + pos_embeds

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self.output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length

            self._inputs_to_outputs = self._inputs_to_outputs_fn(
                max_decoding_length + 1)

            if beam_width is None:  #Inference-like decoding
                # Prepare helper
                if helper is not None:
                    # ignore `decoding_strategy`
                    pass
                else:
                    if decoding_strategy == "infer_greedy":
                        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                            self._embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                            self._embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    self.context = tf.pad(
                        self.context,
                        [[0, 0],
                         [0, max_decoding_length - tf.shape(self.context)[1]]])

                outputs, cache, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that
                    # of logits by 1, because an additional start_token is
                    # prepended to the returned sample_id; the start token
                    # should be the first token of the given context.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat([
                            tf.expand_dims(start_tokens, 1), outputs.sample_id
                        ],
                                            axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  #Beam-search decoding
                # ignore `decoding_strategy`
                # assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = tf.shape(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
def beam_search(symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                eos_id,
                states=None,
                stop_early=True):
    """Beam search with length penalties.

    Requires a function that can take the currently decoded symbols and
    return the logits for the next symbol. The implementation is inspired
    by https://arxiv.org/abs/1609.08144.

    When running, the beam search steps can be visualized by using tfdbg to
    watch the operations generating the output ids for each beam step.
    These operations have the pattern:
        (alive|finished)_topk_(seq,scores)

    Operations marked `alive` represent the new beam sequences that will be
    processed in the next step.    Operations marked `finished` represent
    the completed beam sequences, which may be padded with 0s if no beams
    finished.

    Operations marked `seq` store the full beam sequence for the time step.
    Operations marked `scores` store the sequence's final log scores.

    The beam search steps will be processed sequentially in order, so when
    capturing observed tensors from these operations, clients can make
    assumptions about which step is being recorded.

    WARNING: Assumes the 2nd dimension of tensors in `states` is not
    invariant; this means that the shape of the 2nd dimension of these
    tensors will not be available (i.e. set to None) inside
    symbols_to_logits_fn.

    Args:
        symbols_to_logits_fn: Interface to the model, to provide logits.
            Should take [batch_size, decoded_ids] and return
            [batch_size, vocab_size]
        initial_ids: Ids to start off the decoding, this will be the first
            thing handed to symbols_to_logits_fn (after expanding to beam size)
            [batch_size]
        beam_size: Size of the beam.
        decode_length: Number of steps to decode for.
        vocab_size: Size of the vocab, must equal the size of the logits
            returned by symbols_to_logits_fn
        alpha: alpha for length penalty.
        eos_id: ID for end of sentence.
        states: dict (possibly nested) of decoding states.
        stop_early: a boolean - stop once the best sequence is provably
            determined.

    Returns:
        Tuple of
        (decoded beams [batch_size, beam_size, decode_length]
         decoding probabilities [batch_size, beam_size])
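
    Example (a minimal sketch; `my_symbols_to_logits_fn`, `batch_size`,
    `bos_id`, `eos_id`, and `vocab_size` below are assumed to be defined
    by the caller):

    .. code-block:: python

        seqs, scores = beam_search(
            symbols_to_logits_fn=my_symbols_to_logits_fn,
            initial_ids=tf.fill([batch_size], bos_id),
            beam_size=4,
            decode_length=50,
            vocab_size=vocab_size,
            alpha=0.6,
            eos_id=eos_id)
        # seqs: [batch_size, 4, <=decode_length+1]; the first beam is the
        # highest-scoring one.
        best_ids = seqs[:, 0, :]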
    """
    batch_size = shape_list(initial_ids)[0]

    # Assume initial_ids are prob 1.0
    initial_log_probs = tf.constant([[0.] + [-float("inf")] * (beam_size - 1)])
    # Expand to beam_size (batch_size, beam_size)
    alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

    # Expand each batch and state to beam_size
    alive_seq = _expand_to_beam_size(initial_ids, beam_size)
    alive_seq = tf.expand_dims(alive_seq, axis=2)
    #(batch_size, beam_size, 1)
    if states:
        states = nest.map_structure(
            lambda state: _expand_to_beam_size(state, beam_size), states)
    else:
        states = {}

    # Finished will keep track of all the sequences that have finished so
    # far
    # Finished log probs will be negative infinity in the beginning
    # finished_flags will keep track of booleans
    finished_seq = tf.zeros(shape_list(alive_seq), tf.int32)
    # Setting the scores of the initial to negative infinity.
    finished_scores = tf.ones([batch_size, beam_size]) * -INF
    finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

    def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                      curr_scores, curr_finished):
        """Given sequences and scores, will gather the top k=beam size
        sequences.

        Args:
            finished_seq: Current finished sequences.
                [batch_size, beam_size, current_decoded_length]
            finished_scores: scores for each of these sequences.
                [batch_size, beam_size]
            finished_flags: finished bools for each of these sequences.
                [batch_size, beam_size]
            curr_seq: current topk sequence that has been grown by one
                position.
                [batch_size, beam_size, current_decoded_length]
            curr_scores: scores for each of these sequences. [batch_size,
                beam_size]
            curr_finished: Finished flags for each of these sequences.
                [batch_size, beam_size]

        Returns:
            Tuple of
                (Topk sequences based on scores,
                 log probs of these sequences,
                 Finished flags of these sequences)
        """
        # First append a column of 0-ids to finished_seq so it has the same
        # length as curr_seq
        finished_seq = tf.concat(
            [finished_seq,
             tf.zeros([batch_size, beam_size, 1], tf.int32)],
            axis=2)

        # Set the scores of the unfinished seq in curr_seq to large
        # negative values
        curr_scores += (1. - tf.to_float(curr_finished)) * -INF
        # concatenating the sequences and scores along beam axis
        curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
        curr_finished_scores = tf.concat([finished_scores, curr_scores],
                                         axis=1)
        curr_finished_flags = tf.concat([finished_flags, curr_finished],
                                        axis=1)
        return compute_topk_scores_and_seq(curr_finished_seq,
                                           curr_finished_scores,
                                           curr_finished_scores,
                                           curr_finished_flags, beam_size,
                                           batch_size, "grow_finished")

    def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                   states):
        """Given sequences and scores, will gather the top k=beam size
        sequences.

        Args:
            curr_seq: current topk sequence that has been grown by one
                position.
                [batch_size, beam_size, i+1]
            curr_scores: scores for each of these sequences. [batch_size,
                beam_size]
            curr_log_probs: log probs for each of these sequences.
                [batch_size, beam_size]
            curr_finished: Finished flags for each of these sequences.
                [batch_size, beam_size]
            states: dict (possibly nested) of decoding states.

        Returns:
            Tuple of
                (Topk sequences based on scores,
                 log probs of these sequences,
                 Finished flags of these sequences)
        """
        # Set the scores of the finished seq in curr_seq to large negative
        # values
        curr_scores += tf.to_float(curr_finished) * -INF
        return compute_topk_scores_and_seq(curr_seq, curr_scores,
                                           curr_log_probs, curr_finished,
                                           beam_size, batch_size, "grow_alive",
                                           states)

    def grow_topk(i, alive_seq, alive_log_probs, states):
        r"""Inner beam seach loop.

        This function takes the current alive sequences, and grows them to
        topk sequences where k = 2*beam. We use 2*beam because, we could
        have beam_size number of sequences that might hit <EOS> and there
        will be no alive sequences to continue. With 2*beam_size, this
        will not happen. This relies on the assumption the vocab size is >
        beam size. If this is true, we'll have at least beam_size non
        <EOS> extensions if we extract the next top 2*beam words.
        Length penalty is given by = (5+len(decode)/6) ^ -\alpha.
        Pls refer to https://arxiv.org/abs/1609.08144.

        Args:
            i: loop index
            alive_seq: Topk sequences decoded so far [batch_size,
                beam_size, i+1]
            alive_log_probs: probabilities of these sequences.
                [batch_size, beam_size]
            states: dict (possibly nested) of decoding states.

        Returns:
            Tuple of
                (Topk sequences extended by the next word,
                 The log probs of these sequences,
                 The scores with length penalty of these sequences,
                 Flags indicating which of these sequences have finished
                 decoding, dict of transformed decoding states)
        """
        # Get the logits for all the possible next symbols
        flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

        # (batch_size * beam_size, decoded_length)
        if states:
            flat_states = nest.map_structure(_merge_beam_dim, states)
            flat_logits, flat_states = symbols_to_logits_fn(
                flat_ids, i, flat_states)
            states = nest.map_structure(
                lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
                flat_states)
        else:
            flat_logits = symbols_to_logits_fn(flat_ids)
        logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

        # Convert logits to normalized log probs
        candidate_log_probs = log_prob_from_logits(logits)

        # Add the candidate log-probs to the running log-probs of the beam
        # (equivalent to multiplying the probabilities).
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                         axis=2)

        length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty
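        # e.g., with alpha = 0.6 at step i = 4 (5 tokens decoded so far):
        # ((5 + 5) / 6) ** 0.6 ~= 1.36; dividing the (negative) log-probs by
        # this growing factor makes longer hypotheses score relatively better
        # when alpha > 0.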
        # Flatten out (beam_size, vocab_size) probs into a list of
        # possibilities
        flat_curr_scores = tf.reshape(curr_scores,
                                      [-1, beam_size * vocab_size])

        topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

        # Recovering the log probs because we will need to send them back
        topk_log_probs = topk_scores * length_penalty

        # Work out what beam the top probs are in.
        topk_beam_index = topk_ids // vocab_size
        topk_ids %= vocab_size  # Unflatten the ids

        # The next three steps are to create coordinates for tf.gather_nd
        # to pull out the correct sequences from the ids that we need to grow.
        # We will also use the coordinates to gather the booleans of the
        # beam items that survived.
        batch_pos = compute_batch_indices(batch_size, beam_size * 2)

        # top beams will give us the actual coordinates to do the gather.
        # stacking will create a tensor of dimension batch * beam * 2,
        # where the last dimension contains the i,j gathering coordinates.
        topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

        # Gather up the most probable 2*beams both for the ids and
        # finished_in_alive bools
        topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
        if states:
            states = nest.map_structure(
                lambda state: tf.gather_nd(state, topk_coordinates), states)

        # Append the most probable alive
        topk_seq = tf.concat(
            [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

        topk_finished = tf.equal(topk_ids, eos_id)

        return topk_seq, topk_log_probs, topk_scores, topk_finished, states

    def inner_loop(i, alive_seq, alive_log_probs, finished_seq,
                   finished_scores, finished_flags, states):
        """Inner beam seach loop.

        There are three groups of tensors, alive, finished, and topk.
        The alive group contains information about the current alive
        sequences. The topk group contains information about alive + topk
        current decoded words the finished group contains information
        about finished sentences, that is, the ones that have decoded to
        <EOS>. These are what we return.
        The general beam search algorithm is as follows:
        While we haven't terminated (pls look at termination condition)
            1. Grow the current alive to get beam*2 topk sequences
            2. Among the topk, keep the top beam_size ones that haven't
            reached EOS into alive
            3. Among the topk, keep the top beam_size ones have reached
            EOS into finished
        Repeat
        To make things simple with using fixed size tensors, we will end
        up inserting unfinished sequences into finished in the beginning.
        To stop that we add -ve INF to the score of the unfinished
        sequence so that when a true finished sequence does appear, it
        will have a higher score than all the unfinished ones.

        Args:
            i: loop index
            alive_seq: Topk sequences decoded so far [batch_size,
                beam_size, i+1]
            alive_log_probs: probabilities of the beams. [batch_size,
                beam_size]
            finished_seq: Current finished sequences.
                [batch_size, beam_size, i+1]
            finished_scores: scores for each of these sequences.
                [batch_size, beam_size]
            finished_flags: finished bools for each of these sequences.
                [batch_size, beam_size]
            states: dict (possibly nested) of decoding states.

        Returns:
            Tuple of
                (Incremented loop index
                 New alive sequences,
                 Log probs of the alive sequences,
                 New finished sequences,
                 Scores of the new finished sequences,
                 Flags indicating which sequences in finished have reached
                 EOS,
                 dict of final decoding states)
        """

        # Each inner loop, we carry out three steps:
        # 1. Get the current topk items.
        # 2. Extract the ones that have finished and haven't finished
        # 3. Recompute the contents of finished based on scores.
        topk_seq, topk_log_probs, topk_scores, topk_finished, states =\
            grow_topk(i, alive_seq, alive_log_probs, states)
        alive_seq, alive_log_probs, _, states = grow_alive(
            topk_seq, topk_scores, topk_log_probs, topk_finished, states)
        finished_seq, finished_scores, finished_flags, _ = grow_finished(
            finished_seq, finished_scores, finished_flags, topk_seq,
            topk_scores, topk_finished)

        return (i + 1, alive_seq, alive_log_probs, finished_seq,
                finished_scores, finished_flags, states)

    def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                     finished_scores, finished_in_finished, unused_states):
        """Checking termination condition.

        We terminate when we have decoded up to decode_length, or when the
        lowest scoring item in finished has a greater score than the
        highest-probability item in alive divided by the max length penalty.

        Args:
            i: loop index
            alive_log_probs: probabilities of the beams. [batch_size,
                beam_size]
            finished_scores: scores for each of these sequences.
                [batch_size, beam_size]
            finished_in_finished: finished bools for each of these
                sequences. [batch_size, beam_size]

        Returns:
            Bool.
        """
        if not stop_early:
            return tf.less(i, decode_length)
        max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) \
            / 6.), alpha)
        # The best possible score of the most likely alive sequence
        lower_bound_alive_scores = alive_log_probs[:, 0] /\
            max_length_penalty

        # Now to compute the lowest score of a finished sequence in
        # finished
        # If the sequence isn't finished, we multiply its score by 0.
        # since scores are all -ve, taking the min will give us the score
        # of the lowest finished item.
        lowest_score_of_finished_in_finished = tf.reduce_min(
            finished_scores * tf.to_float(finished_in_finished), axis=1)
        # If none of the sequences have finished, then the min will be 0
        # and we have to replace it by -ve INF if it is. The score of any
        # seq in alive will be much higher than -ve INF and the
        # termination condition will not be met.
        lowest_score_of_finished_in_finished += (
            (1. - tf.to_float(tf.reduce_any(finished_in_finished, 1))) * -INF)

        bound_is_met = tf.reduce_all(
            tf.greater(lowest_score_of_finished_in_finished,
                       lower_bound_alive_scores))

        return tf.logical_and(tf.less(i, decode_length),
                              tf.logical_not(bound_is_met))

    (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
     finished_flags, _) = tf.while_loop(
         _is_finished,
         inner_loop, [
             tf.constant(0), alive_seq, alive_log_probs, finished_seq,
             finished_scores, finished_flags, states
         ],
         shape_invariants=[
             tf.TensorShape([]),
             tf.TensorShape([None, None, None]),
             alive_log_probs.get_shape(),
             tf.TensorShape([None, None, None]),
             finished_scores.get_shape(),
             finished_flags.get_shape(),
             nest.map_structure(get_state_shape_invariants, states),
         ],
         parallel_iterations=1,
         back_prop=False)

    alive_seq.set_shape((None, beam_size, None))
    finished_seq.set_shape((None, beam_size, None))

    # Accounting for corner case: It's possible that no sequence in alive
    # for a particular batch item ever reached EOS. In that case, we
    # should just copy the contents of alive for that batch item.
    # tf.reduce_any(finished_flags, 1) is False in that case, meaning no
    # sequence for that batch index reached EOS. We need to do the same
    # for the scores as well.
    finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq,
                            alive_seq)
    finished_scores = tf.where(tf.reduce_any(finished_flags, 1),
                               finished_scores, alive_log_probs)
    return finished_seq, finished_scores