Example #1
File: hred_model.py  Project: shark-3/THRED
    def __build_context(self, params, encoder_results, keep_prob, device):
        with variable_scope.variable_scope("context") as scope:
            context_seq_length = tf.fill([self.batch_size], self.num_turns)

            with tf.device(device):
                context_inputs = tf.stack(
                    [state for outputs, state in encoder_results], axis=0)

                if params.context_type == "uni":
                    cell = rnn_factory.create_cell(params.cell_type,
                                                   params.decoder_hidden_units,
                                                   num_layers=1,
                                                   input_keep_prob=keep_prob,
                                                   devices=[device])

                    context_outputs, context_state = tf.nn.dynamic_rnn(
                        cell,
                        inputs=context_inputs,
                        sequence_length=context_seq_length,
                        time_major=True,
                        dtype=scope.dtype,
                        swap_memory=True)
                    return context_outputs, context_state
                elif params.context_type == "bi":
                    fw_cell = rnn_factory.create_cell(
                        params.cell_type,
                        params.decoder_hidden_units,
                        num_layers=1,
                        input_keep_prob=keep_prob,
                        devices=[device])
                    bw_cell = rnn_factory.create_cell(
                        params.cell_type,
                        params.decoder_hidden_units,
                        num_layers=1,
                        input_keep_prob=keep_prob,
                        devices=[device])

                    context_outputs, context_state = tf.nn.bidirectional_dynamic_rnn(
                        fw_cell,
                        bw_cell,
                        context_inputs,
                        sequence_length=context_seq_length,
                        time_major=True,
                        dtype=scope.dtype,
                        swap_memory=True)
                    fw_state, bw_state = context_state
                    context_state = tf.concat([fw_state, bw_state], axis=1)
                    return context_outputs, context_state
                else:
                    raise ValueError("Unknown encoder type: %s" %
                                     params.encoder_type)
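A minimal, self-contained sketch of the context-RNN idea above, assuming TensorFlow 1.x (all names and sizes here are illustrative, not part of THRED): each per-turn encoder's final state is stacked into a time-major [num_turns, batch_size, units] tensor, and a single-layer RNN is run over the turn axis.

import tensorflow as tf

num_turns, batch_size, units = 3, 4, 8
# stand-ins for the final states produced by the per-turn utterance encoders
encoder_states = [tf.random_normal([batch_size, units]) for _ in range(num_turns)]

context_inputs = tf.stack(encoder_states, axis=0)  # [num_turns, batch, units]
seq_length = tf.fill([batch_size], num_turns)      # every dialogue spans num_turns turns

cell = tf.nn.rnn_cell.GRUCell(units)
context_outputs, context_state = tf.nn.dynamic_rnn(
    cell, context_inputs, sequence_length=seq_length,
    time_major=True, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(context_state).shape)  # (4, 8)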
Example #2
    def __build_decoder_cell(self, params, mode, encoder_outputs,
                             encoder_state, keep_prob):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       self.num_layers,
                                       input_keep_prob=keep_prob,
                                       devices=self.round_robin.assign(
                                           self.num_layers))

        if mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width
            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=params.beam_width)
            memory = tf.contrib.seq2seq.tile_batch(
                encoder_outputs, multiplier=params.beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.source_sequence_lengths,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = encoder_state
            memory = encoder_outputs
            source_sequence_length = self.iterator.source_sequence_lengths

        try:
            attn_mechanism = attention_helper.create_attention_mechanism(
                params.attention_type, params.hidden_units, memory,
                source_sequence_length)

            alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
            cell = tf.contrib.seq2seq.AttentionWrapper(
                cell,
                attention_mechanism=attn_mechanism,
                attention_layer_size=params.hidden_units,
                alignment_history=alignment_history,
                output_attention=True,
                name="vanilla_attention")

            decoder_initial_state = cell.zero_state(
                batch_size, self.dtype).clone(cell_state=decoder_initial_state)
        except ValueError:
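            # Presumably create_attention_mechanism raises ValueError for an
            # unsupported attention type; in that case the decoder falls back
            # to the plain, non-attentional cell built above.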
            pass

        return cell, decoder_initial_state
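A minimal sketch of the beam-search tiling pattern used above, assuming TensorFlow 1.x with tf.contrib (shapes are illustrative): for BeamSearchDecoder, every batch-major tensor that attention reads (the memory, its sequence lengths, and the initial state) must be repeated beam_width times along the batch axis with tile_batch.

import tensorflow as tf

batch_size, max_time, units, beam_width = 2, 5, 8, 3
encoder_outputs = tf.random_normal([batch_size, max_time, units])
source_lengths = tf.constant([5, 3])

memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
lengths = tf.contrib.seq2seq.tile_batch(source_lengths, multiplier=beam_width)

print(memory.shape)  # (6, 5, 8): batch_size * beam_width rows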
Example #3
    def __build_context(self, params, encoder_results, keep_prob, device):
        with variable_scope.variable_scope("context"):
            with tf.device(device):
                context_seq_length = tf.fill([self.batch_size], self.num_turns)

                if params.context_direction == 'backward':
                    context_inputs = tf.stack(
                        [state for _, state in reversed(encoder_results)],
                        axis=0)
                else:
                    context_inputs = tf.stack(
                        [state for _, state in encoder_results], axis=0)

                # message_attention = attention_helper.create_attention_mechanism(params.attention_type,
                #                                                                 params.hidden_units,
                #                                                                 context_inputs)

                cell = rnn_factory.create_cell(params.cell_type,
                                               params.hidden_units,
                                               num_layers=1,
                                               input_keep_prob=keep_prob,
                                               devices=[device])

                # cell = tf.contrib.seq2seq.AttentionWrapper(
                #     cell,
                #     msg_attn_mechanism,
                #     attention_layer_size=params.hidden_units,
                #     alignment_history=False,
                #     output_attention=True,
                #     name="message_attention")

                context_outputs, context_state = tf.nn.dynamic_rnn(
                    cell,
                    inputs=context_inputs,
                    sequence_length=context_seq_length,
                    time_major=True,
                    dtype=self.dtype,
                    swap_memory=True)
                return context_outputs, context_state
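A small illustration of the 'backward' context direction above (pure Python, placeholder values): the per-turn encoder states are simply stacked in reverse turn order before the context RNN consumes them.

encoder_results = [("out0", "state0"), ("out1", "state1"), ("out2", "state2")]

forward = [state for _, state in encoder_results]
backward = [state for _, state in reversed(encoder_results)]

print(forward)   # ['state0', 'state1', 'state2']
print(backward)  # ['state2', 'state1', 'state0']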
Example #4
    def __build_decoder_cell(self, params, context_outputs, context_state,
                             input_keep_prob, device):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       num_layers=1,
                                       input_keep_prob=input_keep_prob,
                                       devices=[device])

        topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                    self.iterator.topic)

        max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)

        expanded_context_state = tf.tile(tf.expand_dims(context_state, axis=1),
                                         [1, max_topic_length, 1])
        topical_embeddings = tf.concat(
            [expanded_context_state, topical_embeddings], axis=2)

        context_sequence_length = tf.fill([self.batch_size], self.num_turns)
        batch_majored_context_outputs = tf.transpose(context_outputs,
                                                     [1, 0, 2])

        if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width

            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                context_state, multiplier=params.beam_width)

            memory = tf.contrib.seq2seq.tile_batch(
                batch_majored_context_outputs, multiplier=params.beam_width)
            topical_embeddings = tf.contrib.seq2seq.tile_batch(
                topical_embeddings, multiplier=params.beam_width)
            context_sequence_length = tf.contrib.seq2seq.tile_batch(
                context_sequence_length, multiplier=params.beam_width)
            topic_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.topic_sequence_length,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = context_state
            memory = batch_majored_context_outputs
            topic_sequence_length = self.iterator.topic_sequence_length

        context_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, memory,
            context_sequence_length)

        topical_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, topical_embeddings,
            topic_sequence_length)

        alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism=(context_attention, topical_attention),
            attention_layer_size=(params.hidden_units, params.hidden_units),
            alignment_history=alignment_history,
            output_attention=True,
            name="joint_attention")

        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=decoder_initial_state)

        return cell, decoder_initial_state
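A minimal sketch of the joint-attention wiring above, assuming TensorFlow 1.x with tf.contrib (names and sizes are illustrative, and LuongAttention merely stands in for whatever attention_helper.create_attention_mechanism builds): AttentionWrapper accepts a tuple of mechanisms together with a matching tuple of attention_layer_size values, and concatenates the resulting attention vectors.

import tensorflow as tf

batch_size, units = 2, 8
context_memory = tf.random_normal([batch_size, 4, units])
topic_memory = tf.random_normal([batch_size, 6, units])

context_attn = tf.contrib.seq2seq.LuongAttention(units, context_memory)
topic_attn = tf.contrib.seq2seq.LuongAttention(units, topic_memory)

cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(units),
    attention_mechanism=(context_attn, topic_attn),
    attention_layer_size=(units, units),
    name="joint_attention")
initial_state = cell.zero_state(batch_size, tf.float32)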
Example #5
    def __build_encoder(self, params, keep_prob, device):
        encoder_cell = {}

        if params.encoder_type == "uni":
            log.print_out("  build unidirectional encoder")
            encoder_cell['uni'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                num_layers=1,
                input_keep_prob=keep_prob,
                devices=[device])
        elif params.encoder_type == "bi":
            log.print_out("  build bidirectional encoder")
            encoder_cell['fw'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                num_layers=1,
                input_keep_prob=keep_prob,
                devices=[device])
            encoder_cell['bw'] = rnn_factory.create_cell(
                params.cell_type,
                params.hidden_units,
                num_layers=1,
                input_keep_prob=keep_prob,
                devices=[device])
        else:
            raise ValueError("Unknown encoder type: '%s'" %
                             params.encoder_type)

        encoding_devices = self.round_robin.assign(self.num_turns)

        encoder_results = []

        for t in range(self.num_turns):
            scope_name = "encoder%d" % t if params.disable_encoder_var_sharing else "encoder"
            with variable_scope.variable_scope(scope_name) as scope:
                if t > 0 and not params.disable_encoder_var_sharing:
                    scope.reuse_variables()
                with tf.device(encoding_devices[t]):
                    encoder_embedded_inputs = tf.nn.embedding_lookup(
                        params=self.embeddings, ids=self.iterator.sources[t])

                    if params.encoder_type == "bi":
                        encoder_outputs, states = tf.nn.bidirectional_dynamic_rnn(
                            encoder_cell['fw'],
                            encoder_cell['bw'],
                            inputs=encoder_embedded_inputs,
                            dtype=self.dtype,
                            sequence_length=self.iterator.source_sequence_lengths[t],
                            swap_memory=True)

                        fw_state, bw_state = states
                        encoder_state = tf.concat([fw_state, bw_state], axis=1)
                    else:
                        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                            encoder_cell['uni'],
                            inputs=encoder_embedded_inputs,
                            sequence_length=self.iterator.source_sequence_lengths[t],
                            dtype=self.dtype,
                            swap_memory=True,
                            scope=scope)

                    # msg_attn_mechanism = attention_helper.create_attention_mechanism(
                    #     params.attention_type,
                    #     params.hidden_units,
                    #     encoder_outputs,
                    #     self.iterator.source_sequence_lengths[t])

                    encoder_results.append((encoder_outputs, encoder_state))

        return encoder_results
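A minimal sketch of the turn-wise variable sharing above, assuming TensorFlow 1.x: unless sharing is disabled, every turn re-enters the same "encoder" scope and calls reuse_variables(), so one set of encoder weights serves all turns.

import tensorflow as tf

num_turns = 3
weights = []
for t in range(num_turns):
    with tf.variable_scope("encoder") as scope:
        if t > 0:
            scope.reuse_variables()  # turns 1.. reuse turn 0's weights
        weights.append(tf.get_variable("w", shape=[4, 4]))

print(weights[0] is weights[1])  # True: the same Variable object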
Example #6
    def __build_decoder_cell(self, params, encoder_outputs, encoder_state,
                             keep_prob):
        cell = rnn_factory.create_cell(params.cell_type,
                                       params.hidden_units,
                                       self.num_layers,
                                       input_keep_prob=keep_prob,
                                       devices=self.round_robin.assign(
                                           self.num_layers))

        topical_embeddings = tf.nn.embedding_lookup(self.embeddings,
                                                    self.iterator.topic)

        max_topic_length = tf.reduce_max(self.iterator.topic_sequence_length)

        aggregated_state = encoder_state
        if isinstance(encoder_state, tuple):
            aggregated_state = encoder_state[0]
            for state in encoder_state[1:]:
                aggregated_state = tf.concat([aggregated_state, state], axis=1)

        if isinstance(encoder_outputs, tuple):
            aggregated_outputs = encoder_outputs[0]
            for output in encoder_outputs[1:]:
                aggregated_outputs = tf.concat([aggregated_outputs, output],
                                               axis=1)

            encoder_outputs = aggregated_outputs

        expanded_encoder_state = tf.tile(
            tf.expand_dims(aggregated_state, axis=1), [1, max_topic_length, 1])
        topical_embeddings = tf.concat(
            [expanded_encoder_state, topical_embeddings], axis=2)

        if self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width > 0:
            batch_size = self.batch_size * params.beam_width

            if isinstance(encoder_state, tuple):
                decoder_initial_state = tuple([
                    tf.contrib.seq2seq.tile_batch(state,
                                                  multiplier=params.beam_width)
                    for state in encoder_state
                ])
            else:
                decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=params.beam_width)

            memory = tf.contrib.seq2seq.tile_batch(
                encoder_outputs, multiplier=params.beam_width)
            topical_embeddings = tf.contrib.seq2seq.tile_batch(
                topical_embeddings, multiplier=params.beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.source_sequence_lengths,
                multiplier=params.beam_width)
            topic_sequence_length = tf.contrib.seq2seq.tile_batch(
                self.iterator.topic_sequence_length,
                multiplier=params.beam_width)
        else:
            batch_size = self.batch_size
            decoder_initial_state = encoder_state
            memory = encoder_outputs
            source_sequence_length = self.iterator.source_sequence_lengths
            topic_sequence_length = self.iterator.topic_sequence_length

        message_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, memory,
            source_sequence_length)

        topical_attention = attention_helper.create_attention_mechanism(
            params.attention_type, params.hidden_units, topical_embeddings,
            topic_sequence_length)

        alignment_history = self.mode == tf.contrib.learn.ModeKeys.INFER and params.beam_width == 0
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism=(message_attention, topical_attention),
            attention_layer_size=(params.hidden_units, params.hidden_units),
            alignment_history=alignment_history,
            output_attention=True,
            name="joint_attention")

        decoder_initial_state = cell.zero_state(
            batch_size, self.dtype).clone(cell_state=decoder_initial_state)

        return cell, decoder_initial_state
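A minimal sketch of the state aggregation above (illustrative shapes): a multi-layer encoder returns a tuple of per-layer states, which are concatenated along the feature axis before being tiled against the topic words.

import tensorflow as tf

batch_size, units = 2, 4
encoder_state = (tf.ones([batch_size, units]), tf.zeros([batch_size, units]))

aggregated_state = encoder_state[0]
for state in encoder_state[1:]:
    aggregated_state = tf.concat([aggregated_state, state], axis=1)

print(aggregated_state.shape)  # (2, 8): all layers flattened into one vector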
Example #7
    def __build_encoder(self, params, keep_prob):
        with variable_scope.variable_scope("encoder"):
            iterator = self.iterator
            encoder_embedded_inputs = tf.nn.embedding_lookup(
                params=self.embeddings, ids=iterator.sources)

            if params.encoder_type == "uni":
                log.print_out(
                    "  build unidirectional encoder num_layers = %d" %
                    params.num_layers)
                cell = rnn_factory.create_cell(params.cell_type,
                                               params.hidden_units,
                                               self.num_layers,
                                               input_keep_prob=keep_prob,
                                               devices=self.round_robin.assign(
                                                   self.num_layers))

                encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                    cell,
                    inputs=encoder_embedded_inputs,
                    sequence_length=iterator.source_sequence_lengths,
                    dtype=self.dtype,
                    swap_memory=True)
                return encoder_outputs, encoder_state
            elif params.encoder_type == "bi":
                num_bi_layers = params.num_layers // 2
                log.print_out("  build bidirectional encoder num_layers = %d" %
                              params.num_layers)

                fw_cell = rnn_factory.create_cell(
                    params.cell_type,
                    params.hidden_units,
                    num_bi_layers,
                    input_keep_prob=keep_prob,
                    devices=self.round_robin.assign(num_bi_layers))
                bw_cell = rnn_factory.create_cell(
                    params.cell_type,
                    params.hidden_units,
                    num_bi_layers,
                    input_keep_prob=keep_prob,
                    devices=self.round_robin.assign(
                        num_bi_layers,
                        self.device_manager.num_available_gpus() - 1))

                encoder_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
                    fw_cell,
                    bw_cell,
                    encoder_embedded_inputs,
                    dtype=self.dtype,
                    sequence_length=iterator.source_sequence_lengths,
                    swap_memory=True)

                if num_bi_layers == 1:
                    encoder_state = bi_state
                else:
                    # interleave per-layer forward and backward states
                    encoder_state = []
                    for layer_id in range(num_bi_layers):
                        encoder_state.append(bi_state[0][layer_id])  # forward
                        encoder_state.append(bi_state[1][layer_id])  # backward
                    encoder_state = tuple(encoder_state)

                return encoder_outputs, encoder_state
            else:
                raise ValueError("Unknown encoder type: %s" %
                                 params.encoder_type)
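A pure-Python illustration of the state interleaving above: with num_bi_layers > 1, bidirectional_dynamic_rnn returns one state tuple per direction, and the loop interleaves them (fw0, bw0, fw1, bw1, ...) so the decoder sees a flat tuple of 2 * num_bi_layers states.

fw_states = ("fw0", "fw1")
bw_states = ("bw0", "bw1")
bi_state = (fw_states, bw_states)

encoder_state = []
for layer_id in range(len(fw_states)):
    encoder_state.append(bi_state[0][layer_id])  # forward
    encoder_state.append(bi_state[1][layer_id])  # backward

print(tuple(encoder_state))  # ('fw0', 'bw0', 'fw1', 'bw1')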
Example #8
File: hred_model.py  Project: shark-3/THRED
    def __build_decoder(self, params, mode, context_outputs, context_state,
                        input_keep_prob, device):

        iterator = self.iterator

        decoder_cell = rnn_factory.create_cell(params.cell_type,
                                               params.decoder_hidden_units,
                                               num_layers=1,
                                               input_keep_prob=input_keep_prob,
                                               devices=[device])

        with variable_scope.variable_scope("decoder") as scope:
            with tf.device(device):
                initial_state = context_outputs[-1]

                if mode != tf.contrib.learn.ModeKeys.INFER:
                    # decoder_emb_inp: [max_time, batch_size, num_units]
                    decoder_emb_inp = tf.nn.embedding_lookup(
                        self.embeddings, iterator.target_input)

                    # Helper
                    if self.sampling_probability == 0.0:
                        helper = tf.contrib.seq2seq.TrainingHelper(
                            decoder_emb_inp, iterator.target_sequence_length)
                    else:
                        helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                            decoder_emb_inp, iterator.target_sequence_length,
                            self.embeddings, self.sampling_probability)

                    # Decoder
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        decoder_cell, helper, initial_state)

                    # Dynamic decoding
                    outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                        my_decoder, swap_memory=True, scope=scope)

                    sample_id = outputs.sample_id

                    # Note: there's a subtle difference here between train and inference.
                    # We could have set output_layer when create my_decoder
                    #   and shared more code between train and inference.
                    # We chose to apply the output_layer to all timesteps for speed:
                    #   10% improvements for small models & 20% for larger ones.
                    # If memory is a concern, we should apply output_layer per timestep.
                    logits = self.output_layer(outputs.rnn_output)

                ### Inference
                else:
                    beam_width = params.beam_width
                    start_tokens = tf.fill([self.batch_size], vocab.SOS_ID)
                    end_token = vocab.EOS_ID

                    maximum_iterations = self._get_decoder_max_iterations(
                        params)

                    if beam_width > 0:
                        initial_state = tf.contrib.seq2seq.tile_batch(
                            context_outputs[-1], multiplier=params.beam_width)

                        my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                            cell=decoder_cell,
                            embedding=self.embeddings,
                            start_tokens=start_tokens,
                            end_token=end_token,
                            initial_state=initial_state,
                            beam_width=beam_width,
                            output_layer=self.output_layer,
                            length_penalty_weight=params.length_penalty_weight)
                    else:
                        # Helper
                        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                            self.embeddings, start_tokens, end_token)

                        # Decoder
                        my_decoder = tf.contrib.seq2seq.BasicDecoder(
                            decoder_cell,
                            helper,
                            initial_state,
                            output_layer=self.output_layer)  # applied per timestep

                    # Dynamic decoding
                    outputs, final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
                        my_decoder,
                        maximum_iterations=maximum_iterations,
                        swap_memory=True,
                        scope=scope)

                    if beam_width > 0:
                        logits = tf.no_op()
                        sample_id = outputs.predicted_ids
                    else:
                        logits = outputs.rnn_output
                        sample_id = outputs.sample_id

        return logits, sample_id, final_decoder_state
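A minimal sketch of the greedy inference path above, assuming TensorFlow 1.x with tf.contrib.seq2seq (vocabulary size, units, and token ids are illustrative): with beam_width == 0, a GreedyEmbeddingHelper feeds each predicted token back in as the next input, and the projection layer is applied per timestep.

import tensorflow as tf

batch_size, vocab_size, units = 2, 10, 8
SOS_ID, EOS_ID = 1, 2

embeddings = tf.get_variable("emb", [vocab_size, units])
cell = tf.nn.rnn_cell.GRUCell(units)
output_layer = tf.layers.Dense(vocab_size, use_bias=False)

helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embeddings, tf.fill([batch_size], SOS_ID), EOS_ID)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell, helper, cell.zero_state(batch_size, tf.float32),
    output_layer=output_layer)  # projection applied per timestep

outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)
sample_id = outputs.sample_id  # [batch_size, <=20] greedy token ids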