Example #1
    def _encode(self, input_dict):
        """
    Encodes data into representation
    :param input_dict: a Python dictionary.
    Must define:
      * src_inputs - a Tensor of shape [batch_size, time] or [time, batch_size]
                     (depending on time_major param)
      * src_lengths - a Tensor of shape [batch_size]
    :return: a Python dictionary with:
      * encoder_outputs - a Tensor of shape
                          [batch_size, time, representation_dim]
      or [time, batch_size, representation_dim]
      * encoder_state - a Tensor of shape [batch_size, dim]
      * src_lengths - (copy ref from input) a Tensor of shape [batch_size]
    """
        # TODO: make a separate level of config for cell_params?
        cell_params = copy.deepcopy(self.params)
        cell_params["num_units"] = self.params['encoder_cell_units']

        self._enc_emb_w = tf.get_variable(
            name="EncoderEmbeddingMatrix",
            shape=[self._src_vocab_size, self._src_emb_size],
            dtype=tf.float32)

        if self._mode == "train":
            dp_input_keep_prob = self.params['encoder_dp_input_keep_prob']
            dp_output_keep_prob = self.params['encoder_dp_output_keep_prob']
        else:
            dp_input_keep_prob = 1.0
            dp_output_keep_prob = 1.0

        self._encoder_cell_fw = create_rnn_cell(
            cell_type=self.params['encoder_cell_type'],
            cell_params=cell_params,
            num_layers=self.params['encoder_layers'],
            dp_input_keep_prob=dp_input_keep_prob,
            dp_output_keep_prob=dp_output_keep_prob,
            residual_connections=self.params['encoder_use_skip_connections'],
        )

        time_major = self.params.get("time_major", False)
        use_swap_memory = self.params.get("use_swap_memory", False)

        embedded_inputs = tf.cast(
            tf.nn.embedding_lookup(
                self._enc_emb_w,
                input_dict['src_sequence'],
            ), self.params['dtype'])

        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            cell=self._encoder_cell_fw,
            inputs=embedded_inputs,
            sequence_length=input_dict['src_length'],
            time_major=time_major,
            swap_memory=use_swap_memory,
            dtype=embedded_inputs.dtype,
        )
        return {
            'outputs': encoder_outputs,
            'state': encoder_state,
            'src_lengths': input_dict['src_length'],
            'encoder_input': input_dict['src_sequence']
        }
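
For reference, a minimal usage sketch of this encoder, assuming `encoder` is an instance of the class above built with matching params (the placeholder names below are hypothetical):

import tensorflow as tf

src = tf.placeholder(tf.int32, shape=[None, None], name='src_sequence')
src_len = tf.placeholder(tf.int32, shape=[None], name='src_length')
enc_out = encoder._encode({'src_sequence': src, 'src_length': src_len})
# enc_out['outputs']: [batch_size, time, encoder_cell_units]
# (or [time, batch_size, encoder_cell_units] when time_major=True)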
Example #2
  def _decode(self, input_dict):
    """
    Decodes representation into data
    :param input_dict: Python dictionary with inputs to decoder
    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                     or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
    Does not need tgt_inputs and tgt_lengths
    :return: a Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim] or
                        [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time] or
                                 [time, batch_size]
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']

    self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32
    )

    self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
    )

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['decoder_cell_units']

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    if self.params['attention_type'].startswith('gnmt'):
      residual_connections = False
      wrap_to_multi_rnn = False
    else:
      residual_connections = self.params['decoder_use_skip_connections']
      wrap_to_multi_rnn = True

    self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
    )

    tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(
      encoder_outputs,
      multiplier=self._beam_width,
    )
    tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(
      enc_src_lengths,
      multiplier=self._beam_width,
    )
    attention_mechanism = self._build_attention(
      tiled_enc_outputs,
      tiled_enc_src_lengths,
    )

    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        name="gnmt_attention")
      attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell, self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
      attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
      )
    batch_size_tensor = tf.constant(self._batch_size)
    embedding_fn = lambda ids: tf.cast(
      tf.nn.embedding_lookup(self._dec_emb_w, ids),
      dtype=self.params['dtype'])
    # BeamSearchDecoder is used in place of tf.contrib.seq2seq.BeamSearchDecoder
    decoder = BeamSearchDecoder(
      cell=attentive_decoder_cell,
      embedding=embedding_fn,
      start_tokens=tf.tile([self.GO_SYMBOL], [self._batch_size]),
      end_token=self.END_SYMBOL,
      initial_state=attentive_decoder_cell.zero_state(
        dtype=encoder_outputs.dtype,
        batch_size=batch_size_tensor * self._beam_width,
      ),
      beam_width=self._beam_width,
      output_layer=self._output_projection_layer,
      length_penalty_weight=self._length_penalty_weight
    )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=tf.reduce_max(enc_src_lengths) * 2,
        swap_memory=use_swap_memory,
        output_time_major=time_major,
      )

    # beam search produces token ids rather than logits; both keys hold the
    # ids of the best (first) beam
    return {'logits': final_outputs.predicted_ids[:, :, 0],
            'samples': final_outputs.predicted_ids[:, :, 0],
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
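
The `tile_batch` calls above are what make beam search work: each batch entry's encoder memory must be repeated `beam_width` times so that every beam hypothesis attends over the same source. A minimal sketch of the behavior (toy values, assuming TF 1.x contrib is available):

import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])  # a batch of 2 rows
tiled = tf.contrib.seq2seq.tile_batch(x, multiplier=3)
# tiled has shape [6, 2], each row repeated 3 times in order:
# [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]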
Example #3
    def _encode(self, input_dict):
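        """
        Encodes data into representation, GNMT-style: the first layer is
        bi-directional, the remaining layers are uni-directional with
        residual connections starting from the third layer.
        :param input_dict: a Python dictionary.
        Must define:
          * src_sequence - a Tensor of shape [batch_size, time] or
                           [time, batch_size] (depending on time_major param)
          * src_length - a Tensor of shape [batch_size]
        :return: a Python dictionary with:
          * outputs - encoder output Tensor
          * state - encoder final state
          * src_lengths - (copy ref from input) a Tensor of shape [batch_size]
          * encoder_input - (copy ref from input) the source id Tensor
        """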

        self._enc_emb_w = tf.get_variable(
            name="EncoderEmbeddingMatrix",
            shape=[self._src_vocab_size, self._src_emb_size],
            dtype=tf.float32)

        if self.params['encoder_layers'] < 2:
            raise ValueError("GNMT encoder must have at least 2 layers")

        cell_params = copy.deepcopy(self.params)
        cell_params["num_units"] = self.params['encoder_cell_units']

        with tf.variable_scope("Level1FW"):
            self._encoder_l1_cell_fw = create_rnn_cell(
                cell_type=self.params['encoder_cell_type'],
                cell_params=cell_params,
                num_layers=1,
                dp_input_keep_prob=1.0,
                dp_output_keep_prob=1.0,
                residual_connections=False,
            )
        with tf.variable_scope("Level1BW"):
            self._encoder_l1_cell_bw = create_rnn_cell(
                cell_type=self.params['encoder_cell_type'],
                cell_params=cell_params,
                num_layers=1,
                dp_input_keep_prob=1.0,
                dp_output_keep_prob=1.0,
                residual_connections=False,
            )

        if self._mode == "train":
            dp_input_keep_prob = self.params['encoder_dp_input_keep_prob']
            dp_output_keep_prob = self.params['encoder_dp_output_keep_prob']
        else:
            dp_input_keep_prob = 1.0
            dp_output_keep_prob = 1.0

        with tf.variable_scope("UniDirLevel"):
            self._encoder_cells = create_rnn_cell(
                cell_type=self.params['encoder_cell_type'],
                cell_params=cell_params,
                num_layers=self.params['encoder_layers'] - 1,
                dp_input_keep_prob=dp_input_keep_prob,
                dp_output_keep_prob=dp_output_keep_prob,
                residual_connections=False,
                wrap_to_multi_rnn=False,
            )
            # add residual connections starting from the third layer
            for idx, cell in enumerate(self._encoder_cells):
                if idx > 0:
                    self._encoder_cells[idx] = tf.contrib.rnn.ResidualWrapper(
                        cell)

        time_major = self.params.get("time_major", False)
        use_swap_memory = self.params.get("use_swap_memory", False)
        embedded_inputs = tf.cast(
            tf.nn.embedding_lookup(
                self._enc_emb_w,
                input_dict['src_sequence'],
            ), self.params['dtype'])

        # first bi-directional layer
        _encoder_output, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=self._encoder_l1_cell_fw,
            cell_bw=self._encoder_l1_cell_bw,
            inputs=embedded_inputs,
            sequence_length=input_dict['src_length'],
            swap_memory=use_swap_memory,
            time_major=time_major,
            dtype=embedded_inputs.dtype,
        )
        encoder_l1_outputs = tf.concat(_encoder_output, 2)

        # stack of unidirectional layers
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            cell=tf.contrib.rnn.MultiRNNCell(self._encoder_cells),
            inputs=encoder_l1_outputs,
            sequence_length=input_dict['src_length'],
            swap_memory=use_swap_memory,
            time_major=time_major,
            dtype=encoder_l1_outputs.dtype,
        )

        return {
            'outputs': encoder_outputs,
            'state': encoder_state,
            'src_lengths': input_dict['src_length'],
            'encoder_input': input_dict['src_sequence']
        }
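
The layer layout above can be reproduced in isolation: one bi-directional layer whose forward/backward outputs are concatenated, followed by a uni-directional stack with residual connections from the third overall layer on. A self-contained sketch under assumed values (num_units=512, 4 total layers, plain LSTM cells):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, shape=[None, None, 512])
lengths = tf.placeholder(tf.int32, shape=[None])

# layer 1: bi-directional; concatenation doubles the feature size
fw = tf.nn.rnn_cell.LSTMCell(512)
bw = tf.nn.rnn_cell.LSTMCell(512)
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    fw, bw, inputs, sequence_length=lengths, dtype=tf.float32)
l1_outputs = tf.concat([out_fw, out_bw], axis=2)  # [batch, time, 1024]

# layers 2-4: uni-directional; residual wrappers start where input and
# output sizes match (512 -> 512), i.e. from the third overall layer
cells = [tf.nn.rnn_cell.LSTMCell(512) for _ in range(3)]
cells = [cells[0]] + [tf.contrib.rnn.ResidualWrapper(c) for c in cells[1:]]
outputs, state = tf.nn.dynamic_rnn(
    tf.contrib.rnn.MultiRNNCell(cells), l1_outputs,
    sequence_length=lengths, dtype=tf.float32)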
Example #4
  def _decode(self, input_dict):
    """
    Decodes representation into data
    :param input_dict: Python dictionary with inputs to decoder
    Must define:
      * src_inputs - decoder input Tensor of shape [batch_size, time, dim]
                     or [time, batch_size, dim]
      * src_lengths - decoder input lengths Tensor of shape [batch_size]
      * tgt_inputs - Only during training. labels Tensor of the
                     shape [batch_size, time] or [time, batch_size]
      * tgt_lengths - Only during training. labels lengths
                      Tensor of the shape [batch_size]
    :return: a Python dictionary with:
      * final_outputs - tensor of shape [batch_size, time, dim]
                        or [time, batch_size, dim]
      * final_state - tensor with decoder final state
      * final_sequence_lengths - tensor of shape [batch_size, time]
                                 or [time, batch_size]
    """
    encoder_outputs = input_dict['encoder_output']['outputs']
    enc_src_lengths = input_dict['encoder_output']['src_lengths']
    # tgt_sequence/tgt_length are only present during training
    tgt_inputs = input_dict.get('tgt_sequence', None)
    tgt_lengths = input_dict.get('tgt_length', None)

    self._dec_emb_w = tf.get_variable(
      name='DecoderEmbeddingMatrix',
      shape=[self._tgt_vocab_size, self._tgt_emb_size],
      dtype=tf.float32,
    )

    self._output_projection_layer = tf.layers.Dense(
      self._tgt_vocab_size, use_bias=False,
    )

    cell_params = copy.deepcopy(self.params)
    cell_params["num_units"] = self.params['decoder_cell_units']

    if self._mode == "train":
      dp_input_keep_prob = self.params['decoder_dp_input_keep_prob']
      dp_output_keep_prob = self.params['decoder_dp_output_keep_prob']
    else:
      dp_input_keep_prob = 1.0
      dp_output_keep_prob = 1.0

    if self.params['attention_type'].startswith('gnmt'):
      residual_connections = False
      wrap_to_multi_rnn = False
    else:
      residual_connections = self.params['decoder_use_skip_connections']
      wrap_to_multi_rnn = True

    self._decoder_cells = create_rnn_cell(
      cell_type=self.params['decoder_cell_type'],
      cell_params=cell_params,
      num_layers=self.params['decoder_layers'],
      dp_input_keep_prob=dp_input_keep_prob,
      dp_output_keep_prob=dp_output_keep_prob,
      residual_connections=residual_connections,
      wrap_to_multi_rnn=wrap_to_multi_rnn,
    )

    attention_mechanism = self._build_attention(
      encoder_outputs,
      enc_src_lengths,
    )
    if self.params['attention_type'].startswith('gnmt'):
      attention_cell = self._decoder_cells.pop(0)
      # AttentionWrapper is used in place of tf.contrib.seq2seq.AttentionWrapper
      attention_cell = AttentionWrapper(
        attention_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=None,
        output_attention=False,
        name="gnmt_attention")
      attentive_decoder_cell = GNMTAttentionMultiCell(
        attention_cell, self._add_residual_wrapper(self._decoder_cells),
        use_new_attention=(self.params['attention_type'] == 'gnmt_v2'))
    else:
      attentive_decoder_cell = AttentionWrapper(
        cell=self._decoder_cells,
        attention_mechanism=attention_mechanism,
      )
    if self._mode == "train":
      input_vectors = tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, tgt_inputs),
        dtype=self.params['dtype'])
      helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=input_vectors,
        sequence_length=tgt_lengths)
      decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        output_layer=self._output_projection_layer,
        initial_state=attentive_decoder_cell.zero_state(
          self._batch_size, dtype=encoder_outputs.dtype,
        ),
      )
    elif self._mode == "infer" or self._mode == "eval":
      embedding_fn = lambda ids: tf.cast(
        tf.nn.embedding_lookup(self._dec_emb_w, ids),
        dtype=self.params['dtype'])
      helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding_fn,
        start_tokens=tf.fill([self._batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL)
      decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attentive_decoder_cell,
        helper=helper,
        initial_state=attentive_decoder_cell.zero_state(
          batch_size=self._batch_size, dtype=encoder_outputs.dtype,
        ),
        output_layer=self._output_projection_layer,
      )
    else:
      raise ValueError(
        "Unknown mode for decoder: {}".format(self._mode)
      )

    time_major = self.params.get("time_major", False)
    use_swap_memory = self.params.get("use_swap_memory", False)
    if self._mode == 'train':
      maximum_iterations = tf.reduce_max(tgt_lengths)
    else:
      maximum_iterations = tf.reduce_max(enc_src_lengths) * 2

    final_outputs, final_state, final_sequence_lengths = \
      tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        # impute_finished must be False for beam-search decoding; this is a
        # BasicDecoder, so imputing finished steps is safe
        impute_finished=True,
        maximum_iterations=maximum_iterations,
        swap_memory=use_swap_memory,
        output_time_major=time_major,
      )

    return {'logits': final_outputs.rnn_output,
            'samples': tf.argmax(final_outputs.rnn_output, axis=-1),
            'final_state': final_state,
            'final_sequence_lengths': final_sequence_lengths}
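
In train mode the returned `logits` are genuine distributions over the target vocabulary, so they can feed a standard sequence loss directly. A minimal consumption sketch (hypothetical names: `model`, `input_dict`, `tgt_labels`, `tgt_lengths`):

import tensorflow as tf

dec = model._decode(input_dict)
loss = tf.contrib.seq2seq.sequence_loss(
    logits=dec['logits'],   # [batch_size, time, tgt_vocab_size]
    targets=tgt_labels,     # [batch_size, time] int token ids
    weights=tf.sequence_mask(tgt_lengths, dtype=tf.float32))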