Example No. 1
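These snippets are methods excerpted from OpenSeq2Seq-style encoder/decoder classes, so they rely on names defined elsewhere in that codebase. A minimal sketch of the surrounding imports, assuming the stock tf.contrib classes as stand-ins for the project-local variants (the commented-out lines inside the snippets show that correspondence):

import copy
import tensorflow as tf
# Stock stand-ins; the project substitutes its own subclasses of these.
from tensorflow.contrib.seq2seq import AttentionWrapper, BeamSearchDecoder
# create_rnn_cell(...) and GNMTAttentionMultiCell are project-local helpers
# with no tf.contrib equivalent; they must come from the project itself.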
  def _decode(self, input_dict):
    """
    Decodes representation into data.
    :param input_dict: Python dictionary with inputs to the decoder.
    Must define:
      * encoder_output - a dictionary produced by the encoder, containing:
        - outputs - encoder output Tensor of shape [batch_size, time, dim]
                    or [time, batch_size, dim]
        - src_lengths - input lengths Tensor of shape [batch_size]
    Does not need tgt_inputs and tgt_lengths.
    :return: a Python dictionary with:
      * logits - Tensor with the ids of the top beam,
                 shape [batch_size, time] or [time, batch_size]
      * samples - a single-element list containing the same ids Tensor
      * final_state - decoder final state
      * final_sequence_lengths - Tensor with decoded sequence lengths
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']

    self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32
    )

    self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
    )

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['decoder_cell_units']

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    if self.params['attention_type'].startswith('gnmt'):
      residual_connections = False
      wrap_to_multi_rnn = False
    else:
      residual_connections = self.params['decoder_use_skip_connections']
      wrap_to_multi_rnn = True

    self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
    )

    tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs,
      multiplier=self._beam_width,
    )
    tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
      enc_src_lengths,
      multiplier=self._beam_width,
    )
    attention_mechanism = self._build_attention(
      tiled_enc_outputs,
      tiled_enc_src_lengths,
    )

    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        name="gnmt_attention")
      attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell, self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
      attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
      )
    batch_size_tensor = tf.constant(self._batch_size)
    embedding_fn = lambda ids: tf.cast(
      tf.nn.embedding_lookup(self._dec_emb_w, ids),
      dtype=self.params['dtype'])
    # BeamSearchDecoder here is a project-local drop-in replacement for
    # tf.contrib.seq2seq.BeamSearchDecoder
    decoder = BeamSearchDecoder(
      cell=attentive_decoder_cell,
      embedding=embedding_fn,
      start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
      end_token=self.END_SYMBOL,
      initial_state=attentive_decoder_cell.zero_state(
        dtype=encoder_outputs.dtype,
        batch_size=batch_size_tensor * self._beam_width,
      ),
      beam_width=self._beam_width,
      output_layer=self._output_projection_layer,
      length_penalty_weight=self._length_penalty_weight
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
        swap_memory=use_swap_memory,
        output_time_major=time_major,
      )

    return {'logits': final_outputs.predicted_ids[:, :, 0],
            'samples': [final_outputs.predicted_ids[:, :, 0]],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
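The beam-search path above tiles the encoder outputs and lengths beam_width times and sizes the decoder's initial state at batch_size * beam_width. A self-contained sketch of that pattern built from the stock tf.contrib.seq2seq classes; all sizes, the Luong attention choice, and the symbol ids below are hypothetical stand-ins, not values from the snippet:

import tensorflow as tf

# Hypothetical sizes and symbol ids, for illustration only.
batch_size, beam_width, vocab_size, emb_size, num_units = 8, 4, 100, 32, 64
GO_SYMBOL, END_SYMBOL = 1, 2

embedding = tf.get_variable("emb", [vocab_size, emb_size], dtype=tf.float32)

# Stand-ins for encoder outputs / lengths, tiled beam_width times
# exactly as in the _decode() above.
memory = tf.random_normal([batch_size, 10, num_units])
memory_lengths = tf.fill([batch_size], 10)
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(memory_lengths,
                                              multiplier=beam_width)

attention = tf.contrib.seq2seq.LuongAttention(
    num_units=num_units, memory=tiled_memory,
    memory_sequence_length=tiled_lengths)
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(num_units), attention)

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=cell,
    embedding=embedding,  # a matrix or a lookup callable both work
    start_tokens=tf.fill([batch_size], GO_SYMBOL),
    end_token=END_SYMBOL,
    # the initial state is sized batch_size * beam_width,
    # matching the tiled memory above
    initial_state=cell.zero_state(batch_size * beam_width, tf.float32),
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size, use_bias=False),
    length_penalty_weight=0.0)

outputs, state, lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)
best_ids = outputs.predicted_ids[:, :, 0]  # keep only the top beam, as above

Note that dynamic_decode is left at its default impute_finished=False here; the commented-out line in Example No. 3 suggests it must stay False for beam search.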
Example No. 2
  def _encode(self, input_dict):
    """
    Encodes data into representation.
    :param input_dict: a Python dictionary.
    Must define:
      * source_tensors - a list [src_inputs, src_lengths], where src_inputs
        is a Tensor of shape [batch_size, time] or [time, batch_size]
        (depending on the time_major param) and src_lengths is a Tensor of
        shape [batch_size]
    :return: a Python dictionary with:
      * outputs - a Tensor of shape [batch_size, time, representation_dim]
                  or [time, batch_size, representation_dim]
      * state - encoder final state
      * src_lengths - (copied from the input) a Tensor of shape [batch_size]
      * encoder_input - (copied from the input) the source sequence Tensor
    """
    # TODO: make a separate level of config for cell_params?
    source_sequence = input_dict['source_tensors'][0]
    source_length = input_dict['source_tensors'][1]

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['encoder_cell_units']

    self._enc_emb_w = tf.get_variable(
      name="EncoderEmbeddingMatrix",
      shape=[self._src_vocab_size, self._src_emb_size],
      dtype=tf.float32
    )

    if self._mode == "train":
      dp_input_keep_prob = self.params['encoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['encoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    self._encoder_cell_fw = create_rnn_cell(
      cell_type=self.params['encoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['encoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=self.params['encoder_use_skip_connections'],
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)

    embedded_inputs = tf.cast(tf.nn.embedding_lookup(
      self._enc_emb_w,
      source_sequence,
    ), self.params['dtype'])

    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
      cell=self._encoder_cell_fw,
      inputs=embedded_inputs,
      sequence_length=source_length,
      time_major=time_major,
      swap_memory=use_swap_memory,
      dtype=embedded_inputs.dtype,
    )
    return {'outputs': encoder_outputs,
            'state': encoder_state,
            'src_lengths': source_length,
            'encoder_input': source_sequence}
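Stripped of configuration, the encoder body is: embed the token ids, then run the cell with tf.nn.dynamic_rnn, letting sequence_length mask the padded steps. A minimal standalone sketch with hypothetical sizes and a plain LSTM cell standing in for create_rnn_cell:

import tensorflow as tf

src_vocab_size, src_emb_size, num_units = 100, 32, 64
source_sequence = tf.placeholder(tf.int32, [None, None])  # [batch, time]
source_length = tf.placeholder(tf.int32, [None])          # [batch]

enc_emb_w = tf.get_variable(
    "EncoderEmbeddingMatrix", [src_vocab_size, src_emb_size], tf.float32)
embedded_inputs = tf.nn.embedding_lookup(enc_emb_w, source_sequence)

# outputs: [batch, time, num_units]; steps past each sequence_length are
# zeroed, and state is taken at each sequence's last valid step
outputs, state = tf.nn.dynamic_rnn(
    tf.nn.rnn_cell.LSTMCell(num_units), embedded_inputs,
    sequence_length=source_length, dtype=tf.float32)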
Example No. 3
  def _decode(self, input_dict):
    """
    Decodes representation into data.
    :param input_dict: Python dictionary with inputs to the decoder.
    Must define:
      * encoder_output - a dictionary produced by the encoder, containing:
        - outputs - encoder output Tensor of shape [batch_size, time, dim]
                    or [time, batch_size, dim]
        - src_lengths - input lengths Tensor of shape [batch_size]
      * target_tensors - only during training: a list
        [tgt_inputs, tgt_lengths], where tgt_inputs is a labels Tensor of
        shape [batch_size, time] or [time, batch_size] and tgt_lengths is
        a labels lengths Tensor of shape [batch_size]
    :return: a Python dictionary with:
      * logits - decoder output Tensor of shape [batch_size, time, vocab_size]
                 or [time, batch_size, vocab_size]
      * samples - a single-element list with the argmax of the logits
      * final_state - decoder final state
      * final_sequence_lengths - Tensor with decoded sequence lengths
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']
    if 'target_tensors' in input_dict:
      tgt_inputs = input_dict['target_tensors'][0]
      tgt_lengths = input_dict['target_tensors'][1]
    else:
      tgt_inputs, tgt_lengths = None, None

    self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32,
    )

    self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
    )

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['decoder_cell_units']

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    if self.params['attention_type'].startswith('gnmt'):
      residual_connections = False
      wrap_to_multi_rnn = False
    else:
      residual_connections = self.params['decoder_use_skip_connections']
      wrap_to_multi_rnn = True

    self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
    )

    attention_mechanism = self._build_attention(
      encoder_outputs,
      enc_src_lengths,
    )
    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      # AttentionWrapper here is a project-local drop-in replacement for
      # tf.contrib.seq2seq.AttentionWrapper
      attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,
        output_attention=False,
        name="gnmt_attention")
      attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell, self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
      # again the project-local AttentionWrapper
      attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
      )
    if self._mode == "train":
      input_vectors = tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
        dtype=self.params['dtype'])
      helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=input_vectors,
        sequence_length=tgt_lengths)
      decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        output_layer=self._output_projection_layer,
        initial_state=attentive_decoder_cell.zero_state(
          self._batch_size, dtype=encoder_outputs.dtype,
        ),
      )
    elif self._mode == "infer" or self._mode == "eval":
      embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'])
      helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding_fn,
        start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL)
      decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        initial_state=attentive_decoder_cell.zero_state(
          batch_size=self._batch_size, dtype=encoder_outputs.dtype,
        ),
        output_layer=self._output_projection_layer,
      )
    else:
      raise ValueError(
        "Unknown mode for decoder: {}".format(self._mode)
      )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    if self._mode == 'train':
      maximum_iterations = tf.reduce_max(tgt_lengths)
    else:
      maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

    final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        impute_finished=True,
        maximum_iterations=maximum_iterations,
        swap_memory=use_swap_memory,
        output_time_major=time_major,
      )

    return {'logits': final_outputs.rnn_output,
            'samples': [tf.argmax(final_outputs.rnn_output, axis=-1)],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
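The mode switch above reduces to two helpers driving the same BasicDecoder: TrainingHelper feeds ground-truth target embeddings (teacher forcing), while GreedyEmbeddingHelper embeds each step's argmax until end_token is produced. A minimal sketch with hypothetical sizes and symbol ids:

import tensorflow as tf

vocab_size, emb_size, num_units, batch_size = 100, 32, 64, 8
GO_SYMBOL, END_SYMBOL = 1, 2

emb_w = tf.get_variable("emb", [vocab_size, emb_size], tf.float32)
cell = tf.nn.rnn_cell.LSTMCell(num_units)
projection = tf.layers.Dense(vocab_size, use_bias=False)

# training: teacher forcing with the ground-truth target sequence
tgt_inputs = tf.placeholder(tf.int32, [batch_size, None])
tgt_lengths = tf.placeholder(tf.int32, [batch_size])
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=tf.nn.embedding_lookup(emb_w, tgt_inputs),
    sequence_length=tgt_lengths)

# inference: feed back the argmax of each step's logits
infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=emb_w,  # a matrix or a lookup callable both work
    start_tokens=tf.fill([batch_size], GO_SYMBOL),
    end_token=END_SYMBOL)

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell,
    helper=train_helper,  # or infer_helper, depending on mode as above
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=projection)
outputs, state, lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, impute_finished=True,
    maximum_iterations=tf.reduce_max(tgt_lengths))
logits = outputs.rnn_output  # [batch, time, vocab_size]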
Example No. 4
  def _encode(self, input_dict):
    source_sequence = input_dict['source_tensors'][0]
    source_length = input_dict['source_tensors'][1]
    self._enc_emb_w = tf.get_variable(
      name="EncoderEmbeddingMatrix",
      shape=[self._src_vocab_size, self._src_emb_size],
      dtype=tf.float32
    )

    if self.params['encoder_layers'] < 2:
      raise ValueError("GNMT encoder must have at least 2 layers")

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['encoder_cell_units']

    with tf.variable_scope("Level1FW"):
      self._encoder_l1_cell_fw = create_rnn_cell(
        cell_type=self.params['encoder_cell_type'],
        cell_params=cell_params,
        num_layers=1,
        dp_input_keep_prob=1.0,
        dp_output_keep_prob=1.0,
        residual_connections=False,
      )
    with tf.variable_scope("Level1BW"):
      self._encoder_l1_cell_bw = create_rnn_cell(
        cell_type=self.params['encoder_cell_type'],
        cell_params=cell_params,
        num_layers=1,
        dp_input_keep_prob=1.0,
        dp_output_keep_prob=1.0,
        residual_connections=False,
      )

    if self._mode == "train":
      dp_input_keep_prob = self.params['encoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['encoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    with tf.variable_scope("UniDirLevel"):
      self._encoder_cells = create_rnn_cell(
        cell_type=self.params['encoder_cell_type'],
        cell_params=cell_params,
        num_layers=self.params['encoder_layers'] - 1,
        dp_input_keep_prob=dp_input_keep_prob,
        dp_output_keep_prob=dp_output_keep_prob,
        residual_connections=False,
        wrap_to_multi_rnn=False,
      )
      # add residual connections starting from the third layer
      for idx, cell in enumerate(self._encoder_cells):
        if idx > 0:
          self._encoder_cells[idx] = tf.contrib.rnn.ResidualWrapper(cell)

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    embedded_inputs = tf.cast(tf.nn.embedding_lookup(
      self._enc_emb_w,
      source_sequence,
    ), self.params['dtype'])

    # first bi-directional layer
    _encoder_output, _ = tf.nn.bidirectional_dynamic_rnn(
      cell_fw=self._encoder_l1_cell_fw,
      cell_bw=self._encoder_l1_cell_bw,
      inputs=embedded_inputs,
      sequence_length=source_length,
      swap_memory=use_swap_memory,
      time_major=time_major,
      dtype=embedded_inputs.dtype,
    )
    encoder_l1_outputs = tf.concat(_encoder_output, 2)

    # stack of unidirectional layers
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
      cell=tf.contrib.rnn.MultiRNNCell(self._encoder_cells),
      inputs=encoder_l1_outputs,
      sequence_length=source_length,
      swap_memory=use_swap_memory,
      time_major=time_major,
      dtype=encoder_l1_outputs.dtype,
    )

    return {'outputs': encoder_outputs,
            'state': encoder_state,
            'src_lengths': source_length,
            'encoder_input': source_sequence}
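The GNMT encoder layout above is one bidirectional layer whose forward and backward outputs are concatenated along the depth axis, followed by a unidirectional stack that adds residual wrappers from the third layer on (a ResidualWrapper needs matching input and output depths, which the first layers can't guarantee). A minimal sketch of that layout with hypothetical sizes and plain LSTM cells:

import tensorflow as tf

num_units = 64
inputs = tf.random_normal([8, 10, 32])  # [batch, time, dim] stand-in
lengths = tf.fill([8], 10)

# layer 1: bidirectional; fw/bw outputs concatenated -> depth 2 * num_units
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    tf.nn.rnn_cell.LSTMCell(num_units, name="l1_fw"),
    tf.nn.rnn_cell.LSTMCell(num_units, name="l1_bw"),
    inputs, sequence_length=lengths, dtype=tf.float32)
l1_outputs = tf.concat([out_fw, out_bw], 2)

# layers 2..4: unidirectional; residual connections start at the third
# layer overall (idx > 0 in the unidirectional stack), as above
cells = [tf.nn.rnn_cell.LSTMCell(2 * num_units, name="l%d" % (i + 2))
         for i in range(3)]
cells = [tf.contrib.rnn.ResidualWrapper(c) if i > 0 else c
         for i, c in enumerate(cells)]
outputs, state = tf.nn.dynamic_rnn(
    tf.contrib.rnn.MultiRNNCell(cells), l1_outputs,
    sequence_length=lengths, dtype=tf.float32)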