Example #1
  def _build_individual_encoder_layers(self, bi_encoder_outputs,
                                       num_uni_layers, dtype, hparams):
    """Run each of the encoder layer separately, not used in general seq2seq."""
    uni_cell_lists = model_helper._cell_list(
        unit_type=hparams.unit_type,
        num_units=hparams.num_units,
        num_layers=num_uni_layers,
        num_residual_layers=self.num_encoder_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        num_gpus=self.num_gpus,
        base_gpu=1,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn)

    encoder_inp = bi_encoder_outputs
    encoder_states = []
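    # The first two entries are the forward and backward halves of the
    # bidirectional encoder output.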
    self.encoder_state_list = [bi_encoder_outputs[:, :, :hparams.num_units],
                               bi_encoder_outputs[:, :, hparams.num_units:]]
    with tf.variable_scope("rnn/multi_rnn_cell"):
      for i, cell in enumerate(uni_cell_lists):
        with tf.variable_scope("cell_%d" % i) as scope:
          encoder_inp, encoder_state = tf.nn.dynamic_rnn(
              cell,
              encoder_inp,
              dtype=dtype,
              sequence_length=self.iterator.source_sequence_length,
              time_major=self.time_major,
              scope=scope)
          encoder_states.append(encoder_state)
          self.encoder_state_list.append(encoder_inp)

    encoder_state = tuple(encoder_states)
    encoder_outputs = self.encoder_state_list[-1]
    return encoder_state, encoder_outputs
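The helper above leans on GNMT's model_helper, which is not reproduced in this listing. The same per-layer unrolling pattern can be sketched standalone; a minimal sketch, assuming plain LSTM layers (run_layers_individually and its arguments are hypothetical names, not part of the code above):

import tensorflow as tf  # TF 1.x APIs, matching the example above


def run_layers_individually(inputs, seq_len, num_layers, num_units,
                            dtype=tf.float32):
  """Unroll each RNN layer with its own dynamic_rnn call so that every
  layer's output and final state can be collected separately."""
  per_layer_states = []
  layer_input = inputs
  with tf.variable_scope("rnn/multi_rnn_cell"):
    for i in range(num_layers):
      cell = tf.nn.rnn_cell.LSTMCell(num_units)
      with tf.variable_scope("cell_%d" % i) as scope:
        # Each layer gets its own dynamic_rnn call; a single MultiRNNCell
        # would hide the intermediate per-layer outputs.
        layer_input, state = tf.nn.dynamic_rnn(
            cell, layer_input, sequence_length=seq_len,
            dtype=dtype, scope=scope)
      per_layer_states.append(state)
  # tuple(per_layer_states) mirrors the state layout of a MultiRNNCell.
  return layer_input, tuple(per_layer_states)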
Example #2
  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Build a RNN cell with GNMT attention architecture."""
    # Standard attention
    if not self.is_gnmt_attention:
      return super(GNMTModel, self)._build_decoder_cell(
          hparams, encoder_outputs, encoder_state, source_sequence_length)

    # GNMT attention
    attention_option = hparams.attention
    attention_architecture = hparams.attention_architecture
    num_units = hparams.num_units
    infer_mode = hparams.infer_mode

    dtype = self.dtype

    if self.time_major:
      memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
      memory = encoder_outputs

    if (self.mode == tf.contrib.learn.ModeKeys.INFER and
        infer_mode == "beam_search"):
      memory, source_sequence_length, encoder_state, batch_size = (
          self._prepare_beam_search_decoder_inputs(
              hparams.beam_width, memory, source_sequence_length,
              encoder_state))
    else:
      batch_size = self.batch_size

    attention_mechanism = self.attention_mechanism_fn(
        attention_option, num_units, memory, source_sequence_length, self.mode)

    cell_list = model_helper._cell_list(  # pylint: disable=protected-access
        unit_type=hparams.unit_type,
        num_units=num_units,
        num_layers=self.num_decoder_layers,
        num_residual_layers=self.num_decoder_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn,
        residual_fn=gnmt_residual_fn,
        global_step=self.global_step)

    # Only wrap the bottom layer with the attention mechanism.
    attention_cell = cell_list.pop(0)

    # Only generate alignment in greedy INFER mode.
    alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
                         infer_mode != "beam_search")
    attention_cell = attention_utils.AttentionWrapper(
        attention_cell,
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=alignment_history,
        name="attention")

    if attention_architecture == "gnmt":
      cell = GNMTAttentionMultiCell(
          attention_cell, cell_list)
    elif attention_architecture == "gnmt_v2":
      cell = GNMTAttentionMultiCell(
          attention_cell, cell_list, use_new_attention=True)
    else:
      raise ValueError(
          "Unknown attention_architecture %s" % attention_architecture)

    if hparams.pass_hidden_state:
      decoder_initial_state = tuple(
          zs.clone(cell_state=es)
          if isinstance(zs, attention_utils.AttentionWrapperState) else es
          for zs, es in zip(
              cell.zero_state(batch_size, dtype), encoder_state))
    else:
      decoder_initial_state = cell.zero_state(batch_size, dtype)

    return cell, decoder_initial_state
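GNMTAttentionMultiCell itself (from the TensorFlow nmt repository) is not part of this listing. As a rough sketch of the routing it performs in the gnmt_v2 configuration, where each upper layer consumes the bottom cell's fresh attention vector, consider the simplified, hypothetical stand-in below (SimpleGNMTMultiCell is not the repo's class):

import tensorflow as tf  # TF 1.x


class SimpleGNMTMultiCell(tf.nn.rnn_cell.MultiRNNCell):
  """Sketch: the bottom cell computes attention; every upper layer gets
  the new attention vector concatenated to its input (gnmt_v2 style)."""

  def __init__(self, attention_cell, cells):
    super(SimpleGNMTMultiCell, self).__init__(
        [attention_cell] + cells, state_is_tuple=True)

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or "simple_gnmt_multi_cell"):
      new_states = []
      # Run the attention-wrapped bottom cell first.
      attention_cell = self._cells[0]
      cur_inp, new_attn_state = attention_cell(inputs, state[0])
      new_states.append(new_attn_state)
      # Feed [layer output; attention] to each of the upper layers.
      for i, cell in enumerate(self._cells[1:], start=1):
        with tf.variable_scope("cell_%d" % i):
          cur_inp = tf.concat([cur_inp, new_attn_state.attention], -1)
          cur_inp, new_state = cell(cur_inp, state[i])
          new_states.append(new_state)
      return cur_inp, tuple(new_states)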
Example #3
    def __init__(self,
                 inputs,
                 vocab_table,
                 reverse_vocab_table,
                 vocab_size,
                 gen_vocab_size,
                 embed_size,
                 num_units,
                 lr,
                 mode,
                 dropout,
                 grad_clip=None):
        self.mode = mode
        self.source = inputs.source
        self.source_length = inputs.source_length
        self.vocab_table = vocab_table
        self.reverse_vocab_table = reverse_vocab_table
        self.tgt_sos_id = tf.cast(self.vocab_table.lookup(tf.constant(SOS)),
                                  tf.int32)
        self.tgt_eos_id = tf.cast(self.vocab_table.lookup(tf.constant(EOS)),
                                  tf.int32)
        if self.mode == 'TRAIN' or self.mode == 'EVAL':
            self.target_input = inputs.target_input
            self.target_output = inputs.target_output
            self.target_length = inputs.target_length
            self.dropout = dropout
        elif self.mode == 'INFER':
            self.dropout = 0
        else:
            raise NotImplementedError

        self.num_units = num_units

        self.lr = lr
        self.grad_clip = grad_clip

        batch_size = tf.shape(self.source)[0]

        with tf.variable_scope("Model", reuse=tf.AUTO_REUSE) as scope:

            with tf.variable_scope("Embedding") as scope:
                self.embedding_matrix = tf.get_variable(
                    "shared_embedding_matrix", [vocab_size, embed_size],
                    dtype=tf.float32)

            with tf.variable_scope("Encoder") as scope:
                self.encoder_emb_inp = tf.nn.embedding_lookup(
                    self.embedding_matrix, self.source)
                bi_inputs = self.encoder_emb_inp
                fw_cells = _cell_list(num_units,
                                      num_layers=2,
                                      dropout=self.dropout)
                bw_cells = _cell_list(num_units,
                                      num_layers=2,
                                      dropout=self.dropout)
                # bi_outputs: [batch_size, L, 2*num_units]
                # bi_fw_state: num_layers * [batch_size, num_units]
                # bi_bw_state: num_layers * [batch_size, num_units]
                bi_outputs, bi_fw_state, bi_bw_state = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw=fw_cells,
                    cells_bw=bw_cells,
                    inputs=bi_inputs,
                    sequence_length=self.source_length,
                    dtype=tf.float32)

            with tf.variable_scope("Decoder") as scope:

                self.encoder_final_state = tuple([
                    tf.concat((bi_fw_state[i], bi_bw_state[i]), axis=-1)
                    for i in range(2)
                ])

                self.decoder_cell = tf.contrib.rnn.MultiRNNCell(
                    _cell_list(num_units * 2, num_layers=2, dropout=0))
                self.attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                    num_units=num_units * 2,
                    memory=bi_outputs,
                    memory_sequence_length=self.source_length)
                self.atten_cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell=self.decoder_cell,
                    attention_mechanism=self.attention_mechanism,
                    attention_layer_size=num_units)

                atten_zero_state = self.atten_cell.zero_state(
                    batch_size=batch_size, dtype=tf.float32)
                self.decoder_initial_state = atten_zero_state.clone(
                    cell_state=self.encoder_final_state)

                #CopyNet
                with tf.variable_scope("CopyNet") as scope:

                    self.copy_cell = copynet.CopyNetWrapper(
                        self.atten_cell, bi_outputs, self.source, vocab_size,
                        gen_vocab_size)  #,encoder_state_size=num_units*2)

                    copy_zero_state = self.copy_cell.zero_state(
                        batch_size=batch_size, dtype=tf.float32)
                    self.decoder_initial_state = copy_zero_state.clone(
                        cell_state=self.decoder_initial_state)

                    self.final_cell = self.copy_cell
                    self.output_layer = None
                #CopyNet end

                #self.output_layer = tf.layers.Dense(vocab_size, use_bias=False, name="output_projection")
                #self.final_cell = self.atten_cell

                if self.mode == 'TRAIN' or self.mode == 'EVAL':
                    self.decoder_emb_inp = tf.nn.embedding_lookup(
                        self.embedding_matrix, self.target_input)
                    helper = tf.contrib.seq2seq.TrainingHelper(
                        inputs=self.decoder_emb_inp,
                        sequence_length=self.target_length)
                    decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell=self.final_cell,
                        helper=helper,
                        initial_state=self.decoder_initial_state,
                        output_layer=self.output_layer)
                    final_outputs, final_state, final_seq_lengths = tf.contrib.seq2seq.dynamic_decode(
                        decoder=decoder, swap_memory=True)

                elif self.mode == 'INFER':

                    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                        embedding=self.embedding_matrix,
                        start_tokens=tf.fill([batch_size], self.tgt_sos_id),
                        end_token=self.tgt_eos_id)
                    decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell=self.final_cell,
                        helper=helper,
                        initial_state=self.decoder_initial_state,
                        output_layer=self.output_layer)
                    final_outputs, final_state, final_seq_lengths = tf.contrib.seq2seq.dynamic_decode(
                        decoder=decoder,
                        swap_memory=True,
                        maximum_iterations=MAX_DECODE_STEP)

                else:
                    raise NotImplementedError
                # fw_cell.zero_state shape: num_layers * (batch_size, num_units)

            self.final_state = final_state
            self.logits = final_outputs.rnn_output
            self.sample_id = final_outputs.sample_id

            # build loss
            if self.mode == 'TRAIN' or self.mode == 'EVAL':
                with tf.variable_scope("Loss") as scope:
                    max_time = tf.shape(self.target_output)[1]
                    self.crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=self.target_output, logits=self.logits)
                    self.target_weights = tf.sequence_mask(
                        self.target_length, max_time, dtype=self.logits.dtype)
                    self.loss = tf.reduce_mean(
                        tf.reduce_sum(self.crossent * self.target_weights,
                                      axis=-1) /
                        tf.to_float(self.target_length))
                    '''self.cost = tf.losses.sparse_softmax_cross_entropy(
                        labels = self.target_output,
                        logits = self.logits
                    )   # default is reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS
                    '''
                self.true_word = self.reverse_vocab_table.lookup(
                    tf.to_int64(self.target_output))

            if self.mode == 'TRAIN':

                #build train_op
                optimizer = tf.train.AdamOptimizer(learning_rate=lr)
                if self.grad_clip is None:
                    self.train_op = optimizer.minimize(self.loss)
                else:
                    tvars = tf.trainable_variables()
                    grads, _ = tf.clip_by_global_norm(
                        tf.gradients(self.loss, tvars),
                        tf.constant(self.grad_clip, dtype=tf.float32))
                    self.train_op = optimizer.apply_gradients(zip(
                        grads, tvars))

            self.probs = tf.nn.softmax(self.logits)
            self.predict = self.sample_id

            self.sample_word = self.reverse_vocab_table.lookup(
                tf.to_int64(self.sample_id))

            self.saver = tf.train.Saver(tf.global_variables(),
                                        max_to_keep=MAX_TO_KEEP)
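For orientation, here is one hypothetical way to construct this model. The enclosing class name (Model), the vocab file path, and the BatchedInput fields are assumptions inferred from how __init__ reads its arguments; SOS, EOS, MAX_DECODE_STEP, and MAX_TO_KEEP must already be defined at module level:

import collections
import tensorflow as tf

# Hypothetical input container; field names match what __init__ reads.
BatchedInput = collections.namedtuple(
    "BatchedInput",
    ["source", "source_length",
     "target_input", "target_output", "target_length"])

# Token-id tensors would normally come from a tf.data pipeline.
src = tf.placeholder(tf.int32, [None, None])
src_len = tf.placeholder(tf.int32, [None])
tgt_in = tf.placeholder(tf.int32, [None, None])
tgt_out = tf.placeholder(tf.int32, [None, None])
tgt_len = tf.placeholder(tf.int32, [None])

vocab_table = tf.contrib.lookup.index_table_from_file(
    "vocab.txt", default_value=0)
reverse_vocab_table = tf.contrib.lookup.index_to_string_table_from_file(
    "vocab.txt", default_value="<unk>")

model = Model(  # "Model" is an assumed class name for the __init__ above
    BatchedInput(src, src_len, tgt_in, tgt_out, tgt_len),
    vocab_table, reverse_vocab_table,
    vocab_size=30000, gen_vocab_size=20000,
    embed_size=256, num_units=256,
    lr=1e-3, mode='TRAIN', dropout=0.2, grad_clip=5.0)

Note that tf.tables_initializer() must run before the first session step, since the lookup tables are stateful.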
Example #4
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a RNN cell with GNMT attention architecture."""
        attention_option = hparams.attention
        attention_architecture = hparams.attention_architecture
        num_units = hparams.num_units
        num_layers = hparams.num_layers
        num_residual_layers = hparams.num_residual_layers
        beam_width = hparams.beam_width

        dtype = tf.float32

        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if self.mode == tf.contrib.learn.ModeKeys.INFER and beam_width > 0:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                source_sequence_length, multiplier=beam_width)
            encoder_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        attention_mechanism = self.attention_mechanism_fn(
            attention_option, num_units, memory, source_sequence_length,
            self.mode)

        cell_list = model_helper._cell_list(  # pylint: disable=protected-access
            unit_type=hparams.unit_type,
            num_units=num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=hparams.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn,
            residual_fn=gnmt_residual_fn)

        # Only wrap the bottom layer with the attention mechanism.
        attention_cell = cell_list.pop(0)

        # Only generate alignment in greedy INFER mode.
        alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                             and beam_width == 0)
        attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            attention_cell,
            attention_mechanism,
            attention_layer_size=None,  # don't use attention layer.
            output_attention=False,
            alignment_history=alignment_history,
            name="attention")

        if attention_architecture == "gnmt":
            cell = GNMTAttentionMultiCell(attention_cell, cell_list)
        elif attention_architecture == "gnmt_v2":
            cell = GNMTAttentionMultiCell(attention_cell,
                                          cell_list,
                                          use_new_attention=True)
        else:
            raise ValueError("Unknown attention_architecture %s" %
                             attention_architecture)

        if hparams.pass_hidden_state:
            decoder_initial_state = tuple(
                zs.clone(cell_state=es) if isinstance(
                    zs, tf.contrib.seq2seq.AttentionWrapperState) else es
                for zs, es in zip(cell.zero_state(batch_size, dtype),
                                  encoder_state))
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)

        return cell, decoder_initial_state
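The beam-search branch above is just batch tiling; a minimal standalone sketch with toy shapes:

import tensorflow as tf  # TF 1.x

beam_width = 3
memory = tf.zeros([2, 4, 8])      # [batch=2, time=4, units=8]
lengths = tf.constant([4, 2])

tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(lengths, multiplier=beam_width)
# tiled_memory: [6, 4, 8] -- each batch entry is repeated beam_width times,
# in the interleaved order BeamSearchDecoder expects for its attention
# memory, memory lengths, and initial state.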
Example #5
  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Build a RNN cell with GNMT attention architecture."""
    # GNMT attention
    assert self.is_gnmt_attention
    attention_option = hparams.attention
    attention_architecture = hparams.attention_architecture
    assert attention_option == "normed_bahdanau"
    assert attention_architecture == "gnmt_v2"

    num_units = hparams.num_units
    infer_mode = hparams.infer_mode
    dtype = tf.float16 if hparams.use_fp16 else tf.float32

    if self.time_major:
      memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
      memory = encoder_outputs

    if (self.mode == tf.contrib.learn.ModeKeys.INFER and
        infer_mode == "beam_search"):
      memory, source_sequence_length, encoder_state, batch_size = (
          self._prepare_beam_search_decoder_inputs(
              hparams.beam_width, memory, source_sequence_length,
              encoder_state))
    else:
      batch_size = self.batch_size

    attention_mechanism = model.create_attention_mechanism(
        num_units, memory, source_sequence_length, dtype=dtype)

    cell_list = model_helper._cell_list(  # pylint: disable=protected-access
        unit_type=hparams.unit_type,
        num_units=num_units,
        num_layers=self.num_decoder_layers,
        num_residual_layers=self.num_decoder_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        mode=self.mode,
        dtype=dtype,
        single_cell_fn=self.single_cell_fn,
        residual_fn=gnmt_residual_fn,
        use_block_lstm=hparams.use_block_lstm)

    # Only wrap the bottom layer with the attention mechanism.
    attention_cell = cell_list.pop(0)

    # Only generate alignment in greedy INFER mode.
    alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER and
                         infer_mode != "beam_search")
    attention_cell = attention_wrapper.AttentionWrapper(
        attention_cell,
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=alignment_history,
        name="attention")
    cell = GNMTAttentionMultiCell(attention_cell, cell_list)

    if hparams.pass_hidden_state:
      decoder_initial_state = tuple(
          zs.clone(cell_state=es)
          if isinstance(zs, attention_wrapper.AttentionWrapperState) else es
          for zs, es in zip(
              cell.zero_state(batch_size, dtype), encoder_state))
    else:
      decoder_initial_state = cell.zero_state(batch_size, dtype)

    return cell, decoder_initial_state
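_prepare_beam_search_decoder_inputs is called here and in Example #2 but never shown. Judging from the inline tiling in Example #4, it plausibly reduces to the following reconstruction (a sketch, not verbatim repo code):

  def _prepare_beam_search_decoder_inputs(
      self, beam_width, memory, source_sequence_length, encoder_state):
    # Tile every encoder tensor beam_width times along the batch axis so
    # the BeamSearchDecoder sees batch_size * beam_width entries.
    memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
    source_sequence_length = tf.contrib.seq2seq.tile_batch(
        source_sequence_length, multiplier=beam_width)
    encoder_state = tf.contrib.seq2seq.tile_batch(
        encoder_state, multiplier=beam_width)
    batch_size = self.batch_size * beam_width
    return memory, source_sequence_length, encoder_state, batch_size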