Example #1
    def _build_decoder(self, encoder_outputs, encoder_state, hparams):
        """Build and run a RNN decoder with a final projection layer.

        Args:
          encoder_outputs: The outputs of encoder for every time step.
          encoder_state: The final state of the encoder.
          hparams: The hyperparameter configurations.

        Returns:
          A tuple (logits, decoder_cell_outputs, sample_id, final_context_state),
          where logits has size [time, batch_size, vocab_size] when
          time_major=True.
        """
        tgt_sos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
        iterator = self.iterator

        # maximum_iterations: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, iterator.source_sequence_length)

        # Decoder.
        with tf.variable_scope("decoder") as decoder_scope:
            cell, decoder_initial_state = self._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                iterator.source_sequence_length)

            # Optional ops depend on which mode we are in and which loss
            # function we are using.
            logits = tf.no_op()
            decoder_cell_outputs = None

            # Train or eval
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                # decoder_emb_inp: [max_time, batch_size, num_units]
                target_input = iterator.target_input
                if self.time_major:
                    target_input = tf.transpose(target_input)
                decoder_emb_inp = tf.nn.embedding_lookup(
                    self.embedding_decoder, target_input)

                # Helper
                helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_emb_inp,
                    iterator.target_sequence_length,
                    time_major=self.time_major)

                # Decoder
                my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state,
                )

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                sample_id = outputs.sample_id

                if self.num_sampled_softmax > 0:
                    # Note: this is required when using sampled_softmax_loss.
                    decoder_cell_outputs = outputs.rnn_output

                # Note: there's a subtle difference here between train and inference.
                # We could have set output_layer when create my_decoder
                #   and shared more code between train and inference.
                # We chose to apply the output_layer to all timesteps for speed:
                #   10% improvements for small models & 20% for larger ones.
                # If memory is a concern, we should apply output_layer per timestep.
                num_layers = self.num_decoder_layers
                num_gpus = self.num_gpus
                device_id = num_layers if num_layers < num_gpus else (
                    num_layers - 1)
                # Colocate output layer with the last RNN cell if there is no extra GPU
                # available. Otherwise, put last layer on a separate GPU.
                with tf.device(model_helper.get_device_str(
                        device_id, num_gpus)):
                    logits = self.output_layer(outputs.rnn_output)

                if self.num_sampled_softmax > 0:
                    logits = tf.no_op()  # unused when using sampled softmax loss.

            # Inference
            else:
                infer_mode = hparams.infer_mode
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id
                utils.print_out(
                    "  decoder: infer_mode=%s, beam_width=%d, length_penalty=%f"
                    % (infer_mode, hparams.beam_width,
                       hparams.length_penalty_weight))

                if infer_mode == "beam_search":
                    beam_width = hparams.beam_width
                    length_penalty_weight = hparams.length_penalty_weight

                    my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.embedding_decoder,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=length_penalty_weight)
                elif infer_mode == "sample":
                    # Helper
                    sampling_temperature = hparams.sampling_temperature
                    assert sampling_temperature > 0.0, (
                        "sampling_temperature must be greater than 0.0 when"
                        " using the sample decoder.")
                    helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                        self.embedding_decoder,
                        start_tokens,
                        end_token,
                        softmax_temperature=sampling_temperature,
                        seed=self.random_seed)
                elif infer_mode == "greedy":
                    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                        self.embedding_decoder, start_tokens, end_token)
                else:
                    raise ValueError("Unknown infer_mode '%s'" % infer_mode)

                if infer_mode != "beam_search":
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell,
                        helper,
                        decoder_initial_state,
                        output_layer=self.output_layer  # applied per timestep
                    )

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    maximum_iterations=maximum_iterations,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                if infer_mode == "beam_search":
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id

        return logits, decoder_cell_outputs, sample_id, final_context_state
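
A minimal, self-contained sketch of the train-branch pipeline above
(TrainingHelper -> BasicDecoder -> dynamic_decode -> projection). The shapes
and placeholder tensors are illustrative assumptions, not values taken from
the model:

    import tensorflow as tf

    batch_size, max_time, num_units, vocab_size = 32, 20, 128, 1000

    # Assumed stand-ins for the embedded target inputs and their lengths.
    decoder_emb_inp = tf.random_normal([batch_size, max_time, num_units])
    target_sequence_length = tf.fill([batch_size], max_time)

    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    output_layer = tf.layers.Dense(vocab_size, use_bias=False)

    helper = tf.contrib.seq2seq.TrainingHelper(
        decoder_emb_inp, target_sequence_length, time_major=False)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell, helper, cell.zero_state(batch_size, tf.float32))
    outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

    # As in _build_decoder, the projection is applied to all timesteps at once.
    logits = output_layer(outputs.rnn_output)  # [batch, time, vocab_size]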
Example #2
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a RNN cell with attention mechanism that can be used by decoder."""
        # No Attention
        if not self.has_attention:
            return super(AttentionModel, self)._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                source_sequence_length)
        elif hparams.attention_architecture != "standard":
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

        num_units = hparams.num_units
        num_layers = self.num_decoder_layers
        num_residual_layers = self.num_decoder_residual_layers
        infer_mode = hparams.infer_mode

        dtype = tf.float32

        # Ensure memory is batch-major
        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        if self.mode == tf.contrib.learn.ModeKeys.INFER and infer_mode == "beam_search":
            memory, source_sequence_length, encoder_state, batch_size = (
                self._prepare_beam_search_decoder_inputs(
                    hparams.beam_width, memory, source_sequence_length,
                    encoder_state))
        else:
            batch_size = self.batch_size

        # Attention
        attention_mechanism = self.attention_mechanism_fn(
            hparams.attention, num_units, memory, source_sequence_length,
            self.mode)

        cell = model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn)

        # Only generate alignment history in non-beam-search INFER modes.
        alignment_history = (self.mode == tf.contrib.learn.ModeKeys.INFER
                             and infer_mode != "beam_search")
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=alignment_history,
            output_attention=hparams.output_attention,
            name="attention")

        # TODO(thangluong): do we need num_layers, num_gpus?
        cell = tf.contrib.rnn.DeviceWrapper(
            cell, model_helper.get_device_str(num_layers - 1, self.num_gpus))

        if hparams.pass_hidden_state:
            decoder_initial_state = cell.zero_state(
                batch_size, dtype).clone(cell_state=encoder_state)
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)

        return cell, decoder_initial_state
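
A minimal sketch of the AttentionWrapper construction above, using
LuongAttention as one concrete choice for attention_mechanism_fn (an
assumption; the actual mechanism depends on hparams.attention). Shapes are
illustrative:

    import tensorflow as tf

    batch_size, src_time, num_units = 32, 15, 128

    memory = tf.random_normal([batch_size, src_time, num_units])  # batch-major
    source_sequence_length = tf.fill([batch_size], src_time)

    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units, memory, memory_sequence_length=source_sequence_length)

    cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units),
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=True,
        name="attention")

    # As with pass_hidden_state=True above: start from the wrapper's zero
    # state, then clone in the encoder's final state.
    encoder_state = tf.nn.rnn_cell.LSTMStateTuple(
        tf.zeros([batch_size, num_units]), tf.zeros([batch_size, num_units]))
    decoder_initial_state = cell.zero_state(
        batch_size, tf.float32).clone(cell_state=encoder_state)

    # For beam search, memory, source_sequence_length, and encoder_state would
    # first be tiled beam_width times with tf.contrib.seq2seq.tile_batch,
    # which is presumably what _prepare_beam_search_decoder_inputs does.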
Example #3
    def build_graph(self, hparams, scope=None):
        """Subclass must implement this method.

        Creates a sequence-to-sequence model with dynamic RNN decoder API.
        Args:
          hparams: Hyperparameter configurations.
          scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

        Returns:
          A tuple of the form (logits, loss, final_context_state, sample_id),
          where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: the total loss / batch_size.
            final_context_state: the final state of the decoder RNN.
            sample_id: sampling indices.

        Raises:
          ValueError: if encoder_type is neither "mono" nor "bi", or
            attention_option is not (luong | scaled_luong |
            bahdanau | normed_bahdanau).
        """
        utils.print_out("# Creating %s graph ..." % self.mode)

        # Projection
        if not self.extract_encoder_layers:
            with tf.variable_scope(scope or "build_network"):
                with tf.variable_scope("decoder/output_projection"):
                    if hparams.projection_type == "sparse":
                        self.output_layer = core_layers.MaskedFullyConnected(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    elif hparams.projection_type == "dense":
                        self.output_layer = tf.layers.Dense(
                            hparams.tgt_vocab_size,
                            use_bias=False,
                            name="output_projection")
                    else:
                        raise ValueError("Unknown projection type %s!" %
                                         hparams.projection_type)

        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
            # Encoder
            if hparams.language_model:  # no encoder for language modeling
                utils.print_out("  language modeling: no encoder")
                self.encoder_outputs = None
                encoder_state = None
            else:
                self.encoder_outputs, encoder_state = self._build_encoder(
                    hparams)

            # Skip decoder if extracting only encoder layers
            if self.extract_encoder_layers:
                return

            # Decoder
            logits, decoder_cell_outputs, sample_id, final_context_state = (
                self._build_decoder(self.encoder_outputs, encoder_state,
                                    hparams))

            # Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                with tf.device(
                        model_helper.get_device_str(
                            self.num_encoder_layers - 1, self.num_gpus)):
                    loss = self._compute_loss(logits, decoder_cell_outputs)
            else:
                loss = tf.constant(0.0)

            # Model pruning
            if hparams.pruning_hparams is not None:
                pruning_hparams = pruning.get_pruning_hparams().parse(
                    hparams.pruning_hparams)
                self.p = pruning.Pruning(pruning_hparams,
                                         global_step=self.global_step)
                self.mask_update_op = self.p.conditional_mask_update_op()
                masks = pruning.get_masks()
                thresholds = pruning.get_thresholds()
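                # Summarize each mask's sparsity and threshold so pruning
                # progress is visible in TensorBoard.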
                masks_s = []
                for index, mask in enumerate(masks):
                    masks_s.append(
                        tf.summary.scalar(mask.name + '/sparsity',
                                          tf.nn.zero_fraction(mask)))
                    masks_s.append(
                        tf.summary.scalar(
                            thresholds[index].op.name + '/threshold',
                            thresholds[index]))
                    masks_s.append(
                        tf.summary.histogram(mask.name + '/mask_tensor', mask))
                self.pruning_summary = tf.summary.merge([
                    tf.summary.scalar('sparsity', self.p._sparsity),
                    tf.summary.scalar('last_mask_update_step',
                                      self.p._last_update_step)
                ] + masks_s)
            else:
                self.mask_update_op = tf.no_op()
                self.pruning_summary = tf.no_op()

            return logits, loss, final_context_state, sample_id
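
A minimal sketch of the masked cross-entropy that _compute_loss typically
reduces to (the method itself is not shown here, so treat the details as
assumptions; shapes are illustrative):

    import tensorflow as tf

    batch_size, max_time, vocab_size = 32, 20, 1000

    # Assumed stand-ins for decoder logits and target ids/lengths.
    logits = tf.random_normal([batch_size, max_time, vocab_size])
    target_output = tf.zeros([batch_size, max_time], dtype=tf.int32)
    target_sequence_length = tf.fill([batch_size], max_time)

    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_output, logits=logits)
    # Zero out positions past each sequence's true length.
    target_weights = tf.sequence_mask(
        target_sequence_length, max_time, dtype=logits.dtype)
    # Matches the docstring above: loss = the total loss / batch_size.
    loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(batch_size)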