Esempio n. 1
0
  def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run a RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: The Hyperparameters configurations.

    Returns:
      A 3-tuple whose contents depend on `self.mode`:
        * TRAIN / EVAL: `(loss, None, None)`, where `loss` is the
          `tf.reduce_sum` of the per-step values that `self._compute_loss`
          produces over the chunked, time-major decoder outputs.
        * INFER: `(logits, decoder_cell_outputs, predicted_ids)`, where
          `logits` is a `tf.no_op()` placeholder, `decoder_cell_outputs` is
          None and `predicted_ids` is whatever `decoder.dynamic_decode`
          returns for the beam-search decoder.
    """

    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
      cell, decoder_initial_state = self._build_decoder_cell(
          hparams, encoder_outputs, encoder_state,
          self.features["source_sequence_length"])

      # Optional ops depends on which mode we are in and which loss function we
      # are using.
      logits = tf.no_op()
      decoder_cell_outputs = None

      ## Train or eval
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        # target_input is transposed only for time-major runs, i.e. the
        # features are expected batch-major.
        # decoder_emb_inp: [max_time, batch_size, num_units] when time_major.
        target_input = self.features["target_input"]
        if self.time_major:
          target_input = tf.transpose(target_input)
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
          decoder_emb_inp = self._emb_lookup(
              self.embedding_decoder, target_input, is_decoder=True)
        else:
          decoder_emb_inp = tf.cast(
              tf.nn.embedding_lookup(self.embedding_decoder, target_input),
              self.dtype)

        if hparams.use_dynamic_rnn:
          final_rnn_outputs, _ = tf.nn.dynamic_rnn(
              cell,
              decoder_emb_inp,
              sequence_length=self.features["target_sequence_length"],
              initial_state=decoder_initial_state,
              dtype=self.dtype,
              scope=decoder_scope,
              time_major=self.time_major)
        else:
          final_rnn_outputs, _ = tf.contrib.recurrent.functional_rnn(
              cell,
              decoder_emb_inp,
              sequence_length=self.features["target_sequence_length"],
              initial_state=decoder_initial_state,
              dtype=self.dtype,
              scope=decoder_scope,
              time_major=self.time_major,
              use_tpu=hparams.use_tpu)

        # The chunked loss below treats the leading axis as time and pairs it
        # with the transposed (time-major) target_output. Batch-major RNN
        # outputs must therefore be made time-major first; otherwise the
        # reshape silently mixes batch and time positions.
        if not self.time_major:
          final_rnn_outputs = tf.transpose(final_rnn_outputs, [1, 0, 2])

        # 512 batch dimension yields best tpu efficiency: fold `factor`
        # consecutive time steps into one step of `factored_batch` rows.
        factor = tf.maximum(1, 512 // self.batch_size)
        factored_batch = self.batch_size * factor
        input1 = tf.reshape(final_rnn_outputs,
                            [-1, factored_batch, self.num_units])
        input2 = tf.reshape(
            tf.transpose(self.features["target_output"]),
            [-1, factored_batch, 1])
        # Number of folded steps that still contain real tokens:
        # ceil(max_target_length / factor).
        max_length = tf.reduce_max(self.features["target_sequence_length"])
        max_length = (max_length + factor - 1) // factor
        # Outputs and target ids travel through Recurrent as one tensor,
        # concatenated along the feature axis.
        inputs = tf.concat(
            [tf.cast(input1, tf.float32),
             tf.cast(input2, tf.float32)], 2)

        loss, _ = tf.contrib.recurrent.Recurrent(
            theta=self.output_layer,
            state0=tf.zeros([512], tf.float32),
            inputs=inputs,
            cell_fn=self._compute_loss,
            max_input_length=max_length,
            use_tpu=True)

        return tf.reduce_sum(loss), None, None

      ## Inference
      else:
        assert hparams.infer_mode == "beam_search"
        start_tokens = tf.fill([self.batch_size], hparams.tgt_sos_id)
        end_token = hparams.tgt_eos_id
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        coverage_penalty_weight = hparams.coverage_penalty_weight

        # maximum_iteration: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, self.features["source_sequence_length"])

        mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_BEAM_SIZE,
                              value=beam_width)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_MAX_SEQ_LEN,
                              value=maximum_iterations)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_LEN_NORM_FACTOR,
                              value=length_penalty_weight)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_HP_COV_PENALTY_FACTOR,
                              value=coverage_penalty_weight)
        my_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=cell,
            embedding=self.embedding_decoder,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=beam_width,
            output_layer=self.output_layer,
            max_tgt=maximum_iterations,
            length_penalty_weight=length_penalty_weight,
            coverage_penalty_weight=coverage_penalty_weight,
            dtype=self.dtype)

        # Dynamic decoding
        predicted_ids = decoder.dynamic_decode(
            my_decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=self.time_major,
            swap_memory=True,
            scope=decoder_scope)

    return logits, decoder_cell_outputs, predicted_ids
Esempio n. 2
0
    def _build_decoder(self, encoder_outputs, hparams):
        """Run the attention + LSTM decoder stack and return loss or predictions.

        Args:
          encoder_outputs: The outputs of encoder for every time step.
          hparams: The Hyperparameters configurations.

        Returns:
          A 3-tuple whose contents depend on `self.mode`:
            * TRAIN / EVAL: `(loss, None, None)`, where `loss` is the
              `tf.reduce_sum` of what `self._compute_loss` produces (either
              directly for small static shapes, or chunked through
              `tf.contrib.recurrent.Recurrent`).
            * INFER: `(logits, decoder_cell_outputs, predicted_ids)`, where
              `logits` is a `tf.no_op()` placeholder, `decoder_cell_outputs`
              is None and `predicted_ids` is the output of
              `decoder.dynamic_decode` on a beam-search decoder.
        """
        mode_keys = tf.contrib.learn.ModeKeys
        is_training = self.mode == mode_keys.TRAIN
        # Layer 0 is the attention cell; layers 1..3 are plain LSTM cells.
        num_layers = 4

        ## Decoder.
        with tf.variable_scope("decoder",
                               reuse=tf.AUTO_REUSE) as decoder_scope:
            # Placeholders that are only returned on the inference path.
            logits = tf.no_op()
            decoder_cell_outputs = None
            # Training builds the stack for a single hypothesis; eval and
            # inference use the configured beam width.
            rnn_beam_width = 1 if is_training else hparams.beam_width
            theta, input_kernels, state0 = build_atten_rnn(
                encoder_outputs, self.features["source_sequence_length"],
                hparams.num_units, rnn_beam_width, "multi_rnn_cell")

            ## Inference
            if self.mode == mode_keys.INFER:
                assert hparams.infer_mode == "beam_search"
                start_tokens = tf.fill([self.batch_size], hparams.tgt_sos_id)
                eos_id = hparams.tgt_eos_id
                search_width = hparams.beam_width
                len_penalty = hparams.length_penalty_weight
                cov_penalty = hparams.coverage_penalty_weight

                # Upper bound on the number of decoding steps.
                maximum_iterations = self._get_infer_maximum_iterations(
                    hparams, self.features["source_sequence_length"])

                def _decoder_step(inputs, state):
                    """Advance every layer of the decoder stack by one step."""
                    first_state, _ = attention_cell(
                        theta[0], state[0],
                        {"rnn": tf.matmul(inputs, input_kernels[0])})
                    context = first_state["attention"]
                    hidden = first_state["h"]
                    step_states = [first_state]
                    for layer in range(1, num_layers):
                        layer_in = tf.concat([hidden, context], -1)
                        layer_state, _ = lstm_cell(
                            theta[layer], state[layer],
                            {"rnn": tf.matmul(layer_in, input_kernels[layer])})
                        step_states.append(layer_state)
                        # Residual connections apply from layer index 2 on.
                        if layer == 1:
                            hidden = layer_state["h"]
                        else:
                            hidden = layer_state["h"] + hidden
                    return step_states, hidden

                beam_decoder = beam_search_decoder.BeamSearchDecoder(
                    cell=_decoder_step,
                    embedding=self.embedding_decoder,
                    start_tokens=start_tokens,
                    end_token=eos_id,
                    initial_state=state0,
                    beam_width=search_width,
                    output_layer=self.output_layer,
                    max_tgt=maximum_iterations,
                    length_penalty_weight=len_penalty,
                    coverage_penalty_weight=cov_penalty,
                    dtype=self.dtype)

                # Dynamic decoding
                predicted_ids = decoder.dynamic_decode(
                    beam_decoder,
                    maximum_iterations=maximum_iterations,
                    swap_memory=True,
                    scope=decoder_scope)
                return logits, decoder_cell_outputs, predicted_ids

            ## Train or eval
            def _maybe_dropout(x):
                # Multiplicative dropout mask, only applied while training.
                if not is_training:
                    return x
                return x * dropout(x.shape, x.dtype, 1.0 - hparams.dropout)

            tgt_ids = self.features["target_input"]
            static_batch, static_time = tgt_ids.shape
            # Time-major ids -> time-major embeddings.
            time_major_ids = tf.transpose(tgt_ids)
            embedded = self._emb_lookup(self.embedding_decoder,
                                        time_major_ids,
                                        is_decoder=True)

            tgt_len = self.features["target_sequence_length"]
            # Time-major mask that is 0 past each sequence's target length.
            mask = tf.transpose(
                tf.sequence_mask(tgt_len, time_major_ids.shape[0],
                                 embedded.dtype))
            steps = tf.reduce_max(tgt_len)

            layer_in = _maybe_dropout(embedded)
            first = build_rnn(
                theta[0], state0[0],
                {"rnn": input_projection(layer_in, input_kernels[0], steps)},
                attention_cell, attention_cell_grad, steps)
            context = first["attention"]
            hidden = first["h"]
            for layer in range(1, num_layers):
                layer_in = _maybe_dropout(tf.concat([hidden, context], -1))
                result = build_rnn(
                    theta[layer], state0[layer],
                    {"rnn": input_projection(layer_in, input_kernels[layer],
                                             steps)},
                    lstm_cell, lstm_cell_grad, steps)
                hidden = (result["h"] if layer == 1 else
                          result["h"] + hidden)

            masked = hidden * tf.expand_dims(mask, 2)

            # Small static problems skip the chunked Recurrent loop entirely.
            if static_batch * static_time < 1024:
                flat_out = tf.reshape(masked, [-1, self.num_units])
                flat_tgt = tf.transpose(self.features["target_output"])
                small_loss = self._compute_loss(self.output_layer,
                                                [flat_out, flat_tgt])[0]
                return tf.reduce_sum(small_loss), None, None

            # 512 batch dimension yields best tpu efficiency.
            factor = tf.maximum(1, 512 // self.batch_size)
            factored_batch = self.batch_size * factor
            chunked_out = tf.reshape(masked,
                                     [-1, factored_batch, self.num_units])
            chunked_tgt = tf.reshape(
                tf.transpose(self.features["target_output"]),
                [-1, factored_batch, 1])
            longest = tf.reduce_max(self.features["target_sequence_length"])
            # ceil(longest / factor) chunks contain real tokens.
            num_chunks = tf.where(tf.equal(longest % factor, 0),
                                  longest // factor,
                                  longest // factor + 1)

            # tf.contrib.recurrent.Recurrent calls cell_fn(theta, state0,
            # inputs), so `state` here receives the per-step inputs slice.
            # NOTE(review): confirm against the _compute_loss signature.
            def _cell_fn(theta, _, state):
                return self._compute_loss(theta, state, 512)

            loss, _ = tf.contrib.recurrent.Recurrent(
                theta=self.output_layer,
                state0=tf.zeros([512], tf.float32),
                inputs=[chunked_out, chunked_tgt],
                cell_fn=_cell_fn,
                max_input_length=num_chunks,
                use_tpu=True)

            return tf.reduce_sum(loss), None, None
Esempio n. 3
0
    def make_model(self):
        """Build the graph-sequence encoder/decoder model and its training loss.

        Creates the input placeholders, runs the graph model (or substitutes
        zeros when `params['use_graph']` is false), and compresses the node
        representations into one vector per window step via `compressGr`. It
        then encodes that sequence with a stacked LSTM and decodes with a
        second stacked LSTM through `BasicDecoder` / `dynamic_decode`. The
        masked sigmoid cross-entropy against `target_values` is stored in
        `self.ops['loss']`.

        Several intermediate tensors (`static_gr_rep`, `enc_output`,
        `dec_output`, `state_Dec`) are stored in `self.placeholders` even
        though they are not placeholders. NOTE(review): presumably so callers
        can fetch them by name; confirm.
        """
        # Targets and decoder inputs are flattened
        # max_num_vertices x max_num_vertices matrices (presumably adjacency
        # matrices; TODO confirm).
        self.placeholders['target_values'] = tf.placeholder(
            tf.float32, [None, self.max_num_vertices * self.max_num_vertices],
            name='target_values')
        self.placeholders['decoder_input'] = tf.placeholder(
            tf.float32, [None, self.max_num_vertices * self.max_num_vertices],
            name='decoder_input')
        # Extra initial state handed to the (project-local) BasicDecoder below.
        self.placeholders['ini_state'] = tf.placeholder(
            tf.float32, [None, self.params['hidden_size']], name='ini_state')
        self.placeholders['num_graphs'] = tf.placeholder(tf.int32, [],
                                                         name='num_graphs')

        with tf.variable_scope("graph_model"):
            self.prepare_specific_graph_model()
            if self.params['use_graph']:
                self.ops[
                    'final_node_representations'] = self.compute_final_node_representations(
                    )
            else:
                # Graph propagation disabled: feed zeros of the same shape as
                # the initial node representation.
                self.ops['final_node_representations'] = tf.zeros_like(
                    self.placeholders['initial_node_representation'])

        self.ops['losses'] = []
        # Per-example sequence length, used for both the encoder and the
        # decoder below.
        self.placeholders['target_sequence_length'] = tf.placeholder(
            tf.int32, [None], name='target_sequence_length')
        grRep = self.compressGr(self.ops['final_node_representations'])  # [?, h]
        # Group the compressed graph vectors into windows of params["win"]
        # steps for the sequence encoder.
        grRep = tf.reshape(
            grRep,
            [-1, self.params["win"], self.params['hidden_size']])  # b,w,h
        self.placeholders['static_gr_rep'] = grRep

        # Encoder: num_layers LSTM cells, each wrapped in dropout.
        stacked_cells = tf.contrib.rnn.MultiRNNCell([
            tf.contrib.rnn.DropoutWrapper(
                tf.contrib.rnn.LSTMCell(self.params['hidden_size']),
                self.params["dropout_keep_prob"])
            for _ in range(self.params["num_layers"])
        ])
        # Batch-major encoding (dynamic_rnn's default time_major=False).
        outputs_enc, encoder_state = tf.nn.dynamic_rnn(
            stacked_cells,
            grRep,
            sequence_length=self.placeholders['target_sequence_length'],
            dtype=tf.float32)
        self.placeholders['enc_output'] = outputs_enc
        # Decoder inputs are the flattened matrices themselves (no embedding
        # lookup), reshaped to [batch, win, V*V].
        dec_embed_input = tf.reshape(
            self.placeholders['decoder_input'],
            shape=(-1, self.params["win"],
                   self.max_num_vertices * self.max_num_vertices))
        # Decoder: num_layers LSTM cells with dropout applied to the stack's
        # output only.
        cells = tf.contrib.rnn.MultiRNNCell([
            tf.contrib.rnn.LSTMCell(self.params['hidden_size'])
            for _ in range(self.params["num_layers"])
        ])
        dec_cell = tf.contrib.rnn.DropoutWrapper(
            cells, output_keep_prob=self.params["dropout_keep_prob"])
        max_target_len = tf.reduce_max(
            self.placeholders['target_sequence_length'])

        # "th" selects TrainingHelper (feeds the ground-truth decoder inputs).
        # Otherwise ScheduledOutputTrainingHelper with sampling_probability=1.0
        # always feeds the model's own outputs back as the next inputs.
        if (self.params["helper"] == "th"):
            helper = tf.contrib.seq2seq.TrainingHelper(
                dec_embed_input, self.placeholders['target_sequence_length'])
        else:
            helper = tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
                dec_embed_input,
                self.placeholders['target_sequence_length'],
                sampling_probability=1.0)

        output_layer = Dense(self.max_num_vertices * self.max_num_vertices)
        # NOTE(review): BasicDecoder / dynamic_decode here are not the stock
        # tf.contrib.seq2seq signatures. BasicDecoder takes an extra initial
        # state and dynamic_decode returns four values; confirm against the
        # project-local implementations.
        decoder = BasicDecoder(dec_cell, helper, encoder_state,
                               self.placeholders['ini_state'], output_layer)
        outputs_dec, state_Dec, _, full_States = dynamic_decode(
            decoder,
            max_target_len,
            self.params["hidden_size"],
            impute_finished=True,
            maximum_iterations=max_target_len)
        self.placeholders['dec_output'] = outputs_dec
        self.placeholders['state_Dec'] = full_States

        masks = tf.sequence_mask(self.placeholders['target_sequence_length'],
                                 max_target_len,
                                 dtype=tf.float32,
                                 name='masks')  # [batch, max_target_len]
        masks = tf.expand_dims(masks, 2)
        training_logits = tf.identity(outputs_dec.rnn_output, name='logits')
        # Zero the logits at padded steps. NOTE(review): this blocks gradients
        # there, but sigmoid cross-entropy at logit 0 is log(2) per element,
        # so padded steps still add a constant to the reported loss value.
        training_logits = training_logits * masks
        # NOTE(review): after this reshape the row count is
        # batch * max_target_len; target_values must be fed with exactly that
        # many rows.
        training_logits = tf.reshape(
            training_logits,
            (-1, self.max_num_vertices * self.max_num_vertices))
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=training_logits, labels=self.placeholders['target_values'])
        # Sum over the V*V entries of each row, then average over rows.
        loss = tf.reduce_sum(loss, axis=1)
        loss = tf.reduce_mean(loss)

        self.ops['loss'] = loss