Example #1
    def __build_encoder(self, params, keep_prob):
        with variable_scope.variable_scope("encoder") as scope:
            iterator = self.iterator
            encoder_embedded_inputs = tf.nn.embedding_lookup(params=self.embeddings, ids=iterator.sources)

            if params.encoder_type == "uni":
                cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, self.num_layers,
                                               use_residual=params.residual,
                                               input_keep_prob=keep_prob,
                                               devices=self.round_robin.assign(self.num_layers))

                encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell, inputs=encoder_embedded_inputs,
                                                                   sequence_length=iterator.source_sequence_lengths,
                                                                   dtype=self.dtype,
                                                                   swap_memory=True)
                return encoder_outputs, encoder_state
            elif params.encoder_type == "bi":
                num_bi_layers = params.num_layers // 2

                fw_cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, num_bi_layers,
                                                  use_residual=params.residual,
                                                  input_keep_prob=keep_prob,
                                                  devices=self.round_robin.assign(num_bi_layers))
                bw_cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, num_bi_layers,
                                                  use_residual=params.residual,
                                                  input_keep_prob=keep_prob,
                                                  devices=self.round_robin.assign(num_bi_layers, self.device_manager.num_available_gpus()-1))

                encoder_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
                    fw_cell,
                    bw_cell,
                    encoder_embedded_inputs,
                    dtype=self.dtype,
                    sequence_length=iterator.source_sequence_lengths,
                    swap_memory=True)

                if num_bi_layers == 1:
                    encoder_state = bi_state
                else:
                    # alternatively concat forward and backward states
                    encoder_state = []
                    for layer_id in range(num_bi_layers):
                        encoder_state.append(bi_state[0][layer_id])  # forward
                        encoder_state.append(bi_state[1][layer_id])  # backward
                    encoder_state = tuple(encoder_state)

                return encoder_outputs, encoder_state
            else:
                raise ValueError("Unknown encoder type: %s" % params.encoder_type)
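
The bidirectional branch interleaves the per-layer forward and backward states into one flat tuple for the decoder. A minimal, self-contained sketch of that interleaving (all sizes invented for illustration; TF 1.x APIs as used above):

import tensorflow as tf

# Toy dimensions, chosen only for illustration.
batch, time, depth, units, num_bi_layers = 4, 7, 16, 32, 2
inputs = tf.random_normal([batch, time, depth])
lengths = tf.fill([batch], time)

def stacked_cell():
    return tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.GRUCell(units) for _ in range(num_bi_layers)])

outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
    stacked_cell(), stacked_cell(), inputs,
    sequence_length=lengths, dtype=tf.float32)

# Interleave forward/backward states layer by layer, as the example does.
encoder_state = tuple(
    s for layer in range(num_bi_layers)
    for s in (fw_state[layer], bw_state[layer]))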
Example #2
    def __build_context(self, params, encoder_results, initial_state, keep_prob, device):
        with variable_scope.variable_scope("context") as scope:
            context_seq_length = tf.fill([self.batch_size], self.num_turns)

            with tf.device(device):
                context_inputs = tf.stack([inp for _, inp in encoder_results], axis=0)

                if params.context_type == "uni":
                    cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, params.num_layers,
                                                   use_residual=params.residual,
                                                   input_keep_prob=keep_prob,
                                                   devices=self.round_robin.assign(params.num_layers))

                    _, context_state = tf.nn.dynamic_rnn(cell,
                                                         initial_state=initial_state,
                                                         inputs=context_inputs,
                                                         sequence_length=context_seq_length,
                                                         time_major=True,
                                                         dtype=scope.dtype,
                                                         swap_memory=True)
                    return context_state
                elif params.context_type == "bi":
                    num_bi_layers = params.num_layers // 2
                    fw_cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, num_bi_layers,
                                                      use_residual=params.residual,
                                                      input_keep_prob=keep_prob,
                                                      devices=self.round_robin.assign(num_bi_layers))
                    bw_cell = rnn_factory.create_cell(params.cell_type, params.hidden_units, num_bi_layers,
                                                      use_residual=params.residual,
                                                      input_keep_prob=keep_prob,
                                                      devices=self.round_robin.assign(num_bi_layers,
                                                                                      self.device_manager.num_available_gpus() - 1))

                    _, context_state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell,
                                                                       context_inputs,
                                                                       initial_state_fw=initial_state[0],
                                                                       initial_state_bw=initial_state[1],
                                                                       sequence_length=context_seq_length,
                                                                       time_major=True,
                                                                       dtype=scope.dtype,
                                                                       swap_memory=True)
                    fw_state, bw_state = context_state
                    return self._merge_bidirectional_states(num_bi_layers, fw_state, bw_state)
                else:
                    raise ValueError("Unknown context type: %s" % params.context_type)
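
Each turn's encoder summary becomes one time step of a second, dialogue-level RNN. A hedged sketch of that time-major call (sizes invented; a single GRU layer stands in for rnn_factory.create_cell):

import tensorflow as tf

# Invented sizes: one summary vector per dialogue turn.
num_turns, batch, units = 3, 4, 32
summaries = [tf.random_normal([batch, units]) for _ in range(num_turns)]
context_inputs = tf.stack(summaries, axis=0)  # [num_turns, batch, units]

_, context_state = tf.nn.dynamic_rnn(
    tf.nn.rnn_cell.GRUCell(units),
    context_inputs,
    sequence_length=tf.fill([batch], num_turns),
    time_major=True,  # inputs carry time on axis 0
    dtype=tf.float32)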
Example #3
    def __build_decoder_cell(self, params, mode, encoder_outputs,
                             encoder_state, keep_prob):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       self.num_layers,
                                       use_residual=params.residual,
                                       input_keep_prob=keep_prob,
                                       devices=self.round_robin.assign(
                                           self.num_layers))

        if mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width
            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=params.beam_width)
            memory = tf.contrib.seq2seq.tile_batch(
                encoder_outputs, multiplier=params.beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.source_sequence_lengths,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = encoder_state
            memory = encoder_outputs
            source_sequence_length = self.iterator.source_sequence_lengths

        try:
            attn_mechanism = attention_helper.create_attention_mechanism(
                params.attention_type, params.hidden_units, memory,
                source_sequence_length)

            alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
            cell = tf.contrib.seq2seq.AttentionWrapper(
                cell,
                attention_mechanism=attn_mechanism,
                attention_layer_size=params.hidden_units,
                alignment_history=alignment_history,
                output_attention=True,
                name="vanilla_attention")

            decoder_initial_state = cell.zero_state(
                batch_size, self.dtype).clone(cell_state=decoder_initial_state)
        except ValueError:
            # No attention mechanism could be created (e.g. an unsupported
            # attention_type); fall back to the plain cell and encoder state.
            pass

        return cell, decoder_initial_state
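
For beam search, everything the attention mechanism reads is tiled beam_width times before the cell is wrapped, and the wrapper's zero state is cloned around the tiled encoder state. A standalone sketch of that pattern (LuongAttention stands in for whatever attention_helper builds; sizes are arbitrary):

import tensorflow as tf

batch, time, units, beam_width = 4, 7, 32, 5
memory = tf.random_normal([batch, time, units])
memory_lengths = tf.fill([batch], time)
encoder_state = tf.random_normal([batch, units])  # flat GRU-style state

# Tile memory, memory lengths, and state once per beam hypothesis.
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(memory_lengths, multiplier=beam_width)
tiled_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

attention = tf.contrib.seq2seq.LuongAttention(
    units, tiled_memory, memory_sequence_length=tiled_lengths)
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(units), attention, attention_layer_size=units)

# Wrap the tiled encoder state in the AttentionWrapper's state structure.
initial_state = cell.zero_state(batch * beam_width, tf.float32).clone(
    cell_state=tiled_state)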
Example #4
    def __build_decoder_cell(self, params, context_outputs, context_state,
                             input_keep_prob):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       use_residual=params.residual,
                                       num_layers=params.num_layers,
                                       input_keep_prob=input_keep_prob,
                                       devices=self.round_robin.assign(
                                           params.num_layers))

        topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                    self.iterator.topic)

        max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)

        expanded_context_state = tf.tile(
            tf.expand_dims(
                context_state[-1] if params.num_layers > 1 else context_state,
                axis=1), [1, max_topic_length, 1])
        topical_embeddings = tf.concat(
            [expanded_context_state, topical_embeddings], axis=2)

        context_sequence_length = tf.fill([self.batch_size], self.num_turns)
        batch_majored_context_outputs = tf.transpose(context_outputs,
                                                     [1, 0, 2])

        if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width

            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                context_state, multiplier=params.beam_width)

            memory = tf.contrib.seq2seq.tile_batch(
                batch_majored_context_outputs, multiplier=params.beam_width)
            topical_embeddings = tf.contrib.seq2seq.tile_batch(
                topical_embeddings, multiplier=params.beam_width)
            context_sequence_length = tf.contrib.seq2seq.tile_batch(
                context_sequence_length, multiplier=params.beam_width)
            topic_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.topic_sequence_length,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = context_state
            memory = batch_majored_context_outputs
            topic_sequence_length = self.iterator.topic_sequence_length

        context_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, memory,
            context_sequence_length)

        topical_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, topical_embeddings,
            topic_sequence_length)

        alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism=(context_attention, topical_attention),
            attention_layer_size=(params.hidden_units, params.hidden_units),
            alignment_history=alignment_history,
            output_attention=True,
            name="joint_attention")

        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=decoder_initial_state)

        return cell, decoder_initial_state
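
The topical memory glues the (last-layer) context state onto every topic-word embedding, making attention over topics context-aware. A toy sketch of that broadcast-and-concat with made-up shapes:

import tensorflow as tf

batch, topic_len, emb_dim, units = 4, 6, 16, 32
topic_embeddings = tf.random_normal([batch, topic_len, emb_dim])
context_state = tf.random_normal([batch, units])

# Repeat the context state once per topic position, then concat on depth.
expanded = tf.tile(tf.expand_dims(context_state, axis=1), [1, topic_len, 1])
topical_memory = tf.concat([expanded, topic_embeddings], axis=2)
# topical_memory: [batch, topic_len, units + emb_dim]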
Example #5
    def __build_context(self, params, encoder_results, initial_state,
                        keep_prob, device):
        with variable_scope.variable_scope("context") as scope:
            with tf.device(device):
                context_seq_length = tf.fill([self.batch_size], self.num_turns)
                context_inputs = tf.stack(
                    [state for _, state in encoder_results], axis=0)

                # message_attention = attention_helper.create_attention_mechanism(params.attention_type,
                #                                                                 params.hidden_units,
                #                                                                 context_inputs)

                if params.context_type == "uni":
                    cell = rnn_factory.create_cell(
                        params.cell_type,
                        params.hidden_units,
                        use_residual=params.residual,
                        num_layers=params.num_layers,
                        input_keep_prob=keep_prob)

                    # cell = tf.contrib.seq2seq.AttentionWrapper(
                    #     cell,
                    #     msg_attn_mechanism,
                    #     attention_layer_size=params.hidden_units,
                    #     alignment_history=False,
                    #     output_attention=True,
                    #     name="message_attention")
                    context_outputs, context_state = tf.nn.dynamic_rnn(
                        cell,
                        initial_state=initial_state,
                        inputs=context_inputs,
                        sequence_length=context_seq_length,
                        time_major=True,
                        dtype=self.dtype,
                        swap_memory=True)

                    return context_outputs, context_state
                elif params.context_type == "bi":
                    num_bi_layers = params.num_layers // 2
                    fw_cell = rnn_factory.create_cell(
                        params.cell_type,
                        params.hidden_units,
                        num_bi_layers,
                        use_residual=params.residual,
                        input_keep_prob=keep_prob,
                        devices=self.round_robin.assign(num_bi_layers))
                    bw_cell = rnn_factory.create_cell(
                        params.cell_type,
                        params.hidden_units,
                        num_bi_layers,
                        use_residual=params.residual,
                        input_keep_prob=keep_prob,
                        devices=self.round_robin.assign(
                            num_bi_layers,
                            self.device_manager.num_available_gpus() - 1))

                    context_outputs, context_state = tf.nn.bidirectional_dynamic_rnn(
                        fw_cell,
                        bw_cell,
                        context_inputs,
                        initial_state_fw=initial_state[0],
                        initial_state_bw=initial_state[1],
                        sequence_length=context_seq_length,
                        time_major=True,
                        dtype=scope.dtype,
                        swap_memory=True)

                    fw_state, bw_state = context_state
                    fw_output, bw_output = context_outputs
                    context_outputs = tf.concat([fw_output, bw_output],
                                                axis=-1)
                    return context_outputs, self._merge_bidirectional_states(
                        num_bi_layers, fw_state, bw_state)
                else:
                    raise ValueError("Unknown context type: %s" %
                                     params.context_type)
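
In the bidirectional branch the outputs come back as a (forward, backward) pair and are fused on the last axis, doubling the feature size. A minimal sketch with arbitrary sizes:

import tensorflow as tf

batch, time, depth, units = 4, 7, 16, 32
inputs = tf.random_normal([batch, time, depth])

(fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
    tf.nn.rnn_cell.GRUCell(units),
    tf.nn.rnn_cell.GRUCell(units),
    inputs, dtype=tf.float32)

combined = tf.concat([fw_out, bw_out], axis=-1)  # [batch, time, 2 * units]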
Example #6
    def __build_encoder(self, params, keep_prob):
        encoder_cell = {}

        if params.encoder_type == "uni":
            encoder_cell['uni'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                use_residual=params.residual,
                num_layers=params.num_layers,
                input_keep_prob=keep_prob)
        elif params.encoder_type == "bi":
            num_bi_layers = params.num_layers // 2
            encoder_cell['fw'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                use_residual=params.residual,
                num_layers=num_bi_layers,
                input_keep_prob=keep_prob)
            encoder_cell['bw'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                use_residual=params.residual,
                num_layers=num_bi_layers,
                input_keep_prob=keep_prob)
        else:
            raise ValueError("Unknown encoder type: '%s'" %
                             params.encoder_type)

        encoding_devices = self.round_robin.assign(self.num_turns)

        encoder_results, next_initial_state = [], None
        for t in range(self.num_turns):
            with variable_scope.variable_scope("encoder") as scope:
                if t > 0:
                    scope.reuse_variables()

                with tf.device(encoding_devices[t]):
                    encoder_embedded_inputs = tf.nn.embedding_lookup(
                        params=self.embeddings, ids=self.iterator.sources[t])

                    if params.encoder_type == "bi":
                        encoder_outputs, states = tf.nn.bidirectional_dynamic_rnn(
                            encoder_cell['fw'],
                            encoder_cell['bw'],
                            inputs=encoder_embedded_inputs,
                            dtype=self.dtype,
                            sequence_length=self.iterator.source_sequence_lengths[t],
                            swap_memory=True)

                        fw_state, bw_state = states
                        num_bi_layers = params.num_layers // 2
                        if t == 0:
                            if params.context_type == "uni":
                                next_initial_state = self._merge_bidirectional_states(
                                    num_bi_layers, fw_state, bw_state)
                            else:
                                if num_bi_layers > 1:
                                    initial_state_fw, initial_state_bw = [], []
                                    for layer_id in range(num_bi_layers):
                                        initial_state_fw.append(
                                            fw_state[layer_id])
                                        initial_state_bw.append(
                                            bw_state[layer_id])

                                    next_initial_state = (
                                        tuple(initial_state_fw),
                                        tuple(initial_state_bw))
                                else:
                                    next_initial_state = (fw_state, bw_state)

                        if num_bi_layers > 1:
                            next_input = tf.concat(
                                [fw_state[-1], bw_state[-1]], axis=1)
                        else:
                            next_input = tf.concat([fw_state, bw_state],
                                                   axis=1)
                    else:
                        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                            encoder_cell['uni'],
                            inputs=encoder_embedded_inputs,
                            sequence_length=self.iterator.source_sequence_lengths[t],
                            dtype=self.dtype,
                            swap_memory=True,
                            scope=scope)

                        if t == 0:
                            if params.context_type == "uni":
                                next_initial_state = encoder_state
                            else:
                                num_bi_layers = params.num_layers // 2
                                initial_state_fw, initial_state_bw = [], []
                                for layer_id in range(num_bi_layers):
                                    initial_state_fw.append(
                                        encoder_state[2 * layer_id])
                                    initial_state_bw.append(
                                        encoder_state[2 * layer_id + 1])
                                next_initial_state = (tuple(initial_state_fw),
                                                      tuple(initial_state_bw))

                        if params.num_layers > 1:
                            next_input = encoder_state[-1]
                        else:
                            next_input = encoder_state

                    # msg_attn_mechanism = attention_helper.create_attention_mechanism(
                    #     params.attention_type,
                    #     params.hidden_units,
                    #     encoder_outputs,
                    #     self.iterator.source_sequence_lengths[t])

                    encoder_results.append((encoder_outputs, next_input))

        return encoder_results, next_initial_state
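
A single encoder is applied to every turn, so all turns after the first reuse the same variables. A stripped-down sketch of that reuse pattern (shapes invented):

import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(32)
turns = [tf.random_normal([4, 7, 16]) for _ in range(3)]

summaries = []
for t, turn_inputs in enumerate(turns):
    with tf.variable_scope("encoder") as scope:
        if t > 0:
            scope.reuse_variables()  # share one set of encoder weights across turns
        _, state = tf.nn.dynamic_rnn(cell, turn_inputs, dtype=tf.float32)
        summaries.append(state)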
Example #7
    def __build_decoder_cell(self, params, encoder_outputs, encoder_state,
                             keep_prob):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       self.num_layers,
                                       use_residual=params.residual,
                                       input_keep_prob=keep_prob,
                                       devices=self.round_robin.assign(
                                           self.num_layers))

        topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                    self.iterator.topic)

        max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)

        # Flatten a per-layer state tuple into one vector per batch entry.
        aggregated_state = encoder_state
        if isinstance(encoder_state, tuple):
            aggregated_state = tf.concat(list(encoder_state), axis=1)

        if isinstance(encoder_outputs, tuple):
            # Fuse forward/backward outputs along the feature axis before
            # using them as attention memory.
            encoder_outputs = tf.concat(list(encoder_outputs), axis=-1)

        expanded_encoder_state = tf.tile(
            tf.expand_dims(aggregated_state, axis=1), [1, max_topic_length, 1])
        topical_embeddings = tf.concat(
            [expanded_encoder_state, topical_embeddings], axis=2)

        if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width

            if isinstance(encoder_state, tuple):
                decoder_initial_state = tuple([
                    tf.contrib.seq2seq.tile_batch(state,
                                                  multiplier=params.beam_width)
                    for state in encoder_state
                ])
            else:
                decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=params.beam_width)

            memory = tf.contrib.seq2seq.tile_batch(
                encoder_outputs, multiplier=params.beam_width)
            topical_embeddings = tf.contrib.seq2seq.tile_batch(
                topical_embeddings, multiplier=params.beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.source_sequence_lengths,
                multiplier=params.beam_width)
            topic_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.topic_sequence_length,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = encoder_state
            memory = encoder_outputs
            source_sequence_length = self.iterator.source_sequence_lengths
            topic_sequence_length = self.iterator.topic_sequence_length

        message_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, memory,
            source_sequence_length)

        topical_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, topical_embeddings,
            topic_sequence_length)

        alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism=(message_attention, topical_attention),
            attention_layer_size=(params.hidden_units, params.hidden_units),
            alignment_history=alignment_history,
            output_attention=True,
            name="joint_attention")

        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=decoder_initial_state)

        return cell, decoder_initial_state
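
When the encoder state is a per-layer tuple, tile_batch is applied element-wise so each layer's state is repeated once per beam. A small sketch, assuming a two-layer GRU-style state:

import tensorflow as tf

batch, units, beam_width = 4, 32, 5
encoder_state = (tf.random_normal([batch, units]),
                 tf.random_normal([batch, units]))

tiled_state = tuple(
    tf.contrib.seq2seq.tile_batch(s, multiplier=beam_width)
    for s in encoder_state)  # each element: [batch * beam_width, units]

tile_batch also accepts nested structures directly, so the explicit per-element loop mirrors the example above rather than being strictly required.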
Example #8
    def __build_decoder(self, params, mode, initial_state, input_keep_prob,
                        device):

        iterator = self.iterator

        decoder_cell = rnn_factory.create_cell(params.cell_type,
                                               params.hidden_units,
                                               params.num_layers,
                                               use_residual=params.residual,
                                               input_keep_prob=input_keep_prob,
                                               devices=self.round_robin.assign(
                                                   params.num_layers))

        with variable_scope.variable_scope("decoder") as scope:
            with tf.device(device):
                if mode != tf.contrib.learn.ModeKeys.INFER:
                    # decoder_emb_inp: [max_time, batch_size, num_units]
                    decoder_emb_inp = tf.nn.embedding_lookup(
                        self.embeddings, iterator.target_input)

                    # Helper
                    if self.use_scheduled_sampling:
                        helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                            decoder_emb_inp, iterator.target_sequence_length,
                            self.embeddings, self.sampling_probability)
                    else:
                        helper = tf.contrib.seq2seq.TrainingHelper(
                            decoder_emb_inp, iterator.target_sequence_length)

                    # Decoder
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        decoder_cell, helper, initial_state)

                    # Dynamic decoding
                    outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                        my_decoder, swap_memory=True, scope=scope)

                    sample_id = outputs.sample_id

                    # Note: there's a subtle difference here between train and inference.
                    # We could have set output_layer when create my_decoder
                    #   and shared more code between train and inference.
                    # We chose to apply the output_layer to all timesteps for speed:
                    #   10% improvements for small models & 20% for larger ones.
                    # If memory is a concern, we should apply output_layer per timestep.
                    logits = self.output_layer(outputs.rnn_output)

                ### Inference
                else:
                    beam_width = params.beam_width
                    start_tokens = tf.fill([self.batch_size], vocab.SOS_ID)
                    end_token = vocab.EOS_ID

                    maximum_iterations = self._get_decoder_max_iterations(
                        params)

                    if beam_width > 0:
                        initial_state = tf.contrib.seq2seq.tile_batch(
                            initial_state, multiplier=beam_width)

                        my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                            cell=decoder_cell,
                            embedding=self.embeddings,
                            start_tokens=start_tokens,
                            end_token=end_token,
                            initial_state=initial_state,
                            beam_width=beam_width,
                            output_layer=self.output_layer,
                            length_penalty_weight=params.length_penalty_weight)
                    else:
                        # Helper
                        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                            self.embeddings, start_tokens, end_token)

                        # Decoder
                        my_decoder = tf.contrib.seq2seq.BasicDecoder(
                            decoder_cell,
                            helper,
                            initial_state,
                            output_layer=self.output_layer)  # applied per timestep

                    # Dynamic decoding
                    outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                        my_decoder,
                        maximum_iterations=maximum_iterations,
                        swap_memory=True,
                        scope=scope)

                    if beam_width > 0:
                        # Beam search emits predicted ids only; no per-step logits.
                        logits = tf.no_op()
                        sample_id = outputs.predicted_ids
                    else:
                        logits = outputs.rnn_output
                        sample_id = outputs.sample_id

        return logits, sample_id, final_decoder_state
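
The inference branch switches between beam search and greedy decoding. A compact, self-contained sketch of both paths (token ids, sizes, and the output layer are invented):

import tensorflow as tf

vocab_size, units, batch, beam_width = 100, 32, 4, 5
embeddings = tf.get_variable("emb", [vocab_size, units])
cell = tf.nn.rnn_cell.GRUCell(units)
output_layer = tf.layers.Dense(vocab_size, use_bias=False)
start_tokens = tf.fill([batch], 1)  # hypothetical SOS id
end_token = 2                       # hypothetical EOS id

if beam_width > 0:
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=tf.contrib.seq2seq.tile_batch(
            cell.zero_state(batch, tf.float32), multiplier=beam_width),
        beam_width=beam_width,
        output_layer=output_layer)
else:
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embeddings, start_tokens, end_token)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell, helper, cell.zero_state(batch, tf.float32),
        output_layer=output_layer)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=20)
# Beam search: outputs.predicted_ids, shape [batch, time, beam_width];
# greedy: outputs.sample_id and outputs.rnn_output (logits).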