コード例 #1
0
ファイル: rnn_decoder.py プロジェクト: buptpriswang/hasky
 def output(self, rnn_output):
     """Apply the output projection to an RNN output.

     Args:
         rnn_output: Value produced by the RNN cell; passed straight to
             ``melt.dense``.

     Returns:
         The result of ``melt.dense`` applied to ``rnn_output`` with the
         projection weights ``self.w`` and biases ``self.v``.
     """
     weights, biases = self.w, self.v
     return melt.dense(rnn_output, weights, biases)
コード例 #2
0
    def __init__(self, is_training=True, is_predict=False):
        """Build the decoder's vocabulary info, projection, cell and output fns.

        Args:
            is_training: Forwarded to ``melt.create_rnn_cell``.
            is_predict: Stored on the instance and forwarded to the sampled
                softmax loss factory. When it is truthy together with
                ``FLAGS.predict_no_sample``, ``num_sampled`` is forced to 0.

        Side effects:
            Calls ``vocabulary.init()``, creates the ``output_projection``
            variables, prints the selected mode, and, when any of
            ``FLAGS.gen_copy_switch`` / ``FLAGS.gen_copy`` / ``FLAGS.copy_only``
            is set, rewrites ``FLAGS.gen_only``, ``FLAGS.gen_copy`` and
            ``FLAGS.copy_only`` so that only the highest-priority copy mode
            stays enabled.

        Raises:
            AssertionError: if both ``FLAGS.decode_copy`` and
                ``FLAGS.decode_use_alignment`` are set, if the vocabulary end
                id equals its unk id, or if a copy mode is requested while
                ``FLAGS.use_attention`` is not ``True``.
        """
        self.scope = 'rnn'
        self.is_training = is_training
        self.is_predict = is_predict

        # These two decode flags must never be enabled at the same time.
        both_decode_flags = FLAGS.decode_copy and FLAGS.decode_use_alignment
        assert not both_decode_flags

        vocabulary.init()
        self.vocab_size = vocab_size = vocabulary.get_vocab_size()

        self.end_id = vocabulary.end_id()

        # start_id is reset here before get_start_id() is invoked.
        self.start_id = None
        self.get_start_id()

        unk_id = vocabulary.vocab.unk_id()
        assert self.end_id != unk_id, 'input vocab generated without end id'

        self.emb_dim = FLAGS.emb_dim

        # The projection is stored as w_t ([vocab_size, num_units]) and
        # transposed to obtain w, for performance reasons; see
        # https://github.com/tensorflow/tensorflow/issues/4138
        num_units = FLAGS.rnn_hidden_size
        self.num_units = num_units
        with tf.variable_scope('output_projection'):
            w_t_shape = [vocab_size, num_units]
            self.w_t = melt.variable.get_weights_truncated(
                'w_t', w_t_shape, stddev=FLAGS.weight_stddev)
            # Weights used by the dense output layer.
            self.w = tf.transpose(self.w_t)
            # Biases used by the dense output layer.
            v_shape = [vocab_size]
            self.v = melt.variable.get_weights_truncated(
                'v', v_shape, stddev=FLAGS.weight_stddev)

        # TODO https://github.com/tensorflow/tensorflow/issues/6761 : tf 1.0
        # fails unless scope='rnn' is the same as when self.cell is used.
        cell_options = dict(num_units=num_units,
                            is_training=is_training,
                            keep_prob=FLAGS.keep_prob,
                            num_layers=FLAGS.num_layers,
                            cell_type=FLAGS.cell)
        self.cell = melt.create_rnn_cell(**cell_options)

        # No sampling when predicting with FLAGS.predict_no_sample set.
        if is_predict and FLAGS.predict_no_sample:
            num_sampled = 0
        else:
            num_sampled = FLAGS.num_sampled
        self.num_sampled = num_sampled

        # softmax_loss_function left as None means sampling is not needed.
        # NOTE: FLAGS.gen_only is read here before the mode selection below
        # may overwrite it.
        self.softmax_loss_function = None
        if FLAGS.gen_only:
            make_sampled_loss = melt.seq2seq.gen_sampled_softmax_loss_function
            self.softmax_loss_function = make_sampled_loss(
                num_sampled,
                self.vocab_size,
                weights=self.w_t,
                biases=self.v,
                log_uniform_sample=FLAGS.log_uniform_sample,
                is_predict=self.is_predict,
                sample_seed=FLAGS.predict_sample_seed,
                vocabulary=vocabulary)

        if FLAGS.use_attention:
            print('----attention_option:', FLAGS.attention_option)

        # Mode selection, priority: gen_copy_switch > gen_copy > copy_only.
        copy_requested = (FLAGS.gen_copy_switch or FLAGS.gen_copy
                          or FLAGS.copy_only)
        if not copy_requested:
            print('--------gen only mode')
        else:
            assert FLAGS.use_attention is True, 'must use attention if not gen_only mode seq2seq'
            FLAGS.gen_only = False
            if FLAGS.gen_copy_switch:
                print('-------gen copy switch mode!')
                FLAGS.gen_copy = FLAGS.copy_only = False
            elif FLAGS.gen_copy:
                print('-------gen copy mode !')
                FLAGS.copy_only = False
            else:
                print('-------copy only mode !')

        # In any copy mode the score is used as the alignment (no softmax).
        self.score_as_alignment = not FLAGS.gen_only

        # Output function for gen-only mode: dense projection of the cell
        # output with the weights and biases created above.
        def gen_output(cell_output):
            return melt.dense(cell_output, self.w, self.v)

        self.output_fn = gen_output

        def copy_output(indices, batch_size, cell_output, cell_state):
            """Scatter the state's alignments into [batch_size, vocab_size]."""
            attention_weights = cell_state.alignments
            target_shape = [batch_size, self.vocab_size]
            return tf.scatter_nd(indices, attention_weights,
                                 shape=target_shape)

        self.copy_output_fn = copy_output

        def merge_gen_and_copy(gen_logits, copy_logits, cell_state):
            """Combine generation and copy scores for the active copy mode."""
            if not FLAGS.gen_copy_switch:
                return gen_logits + copy_logits
            # Switch mode: blend the two softmax distributions using the
            # state's gen_probability.
            # Presumably [batch_size, 1] * [batch_size, vocab_size].
            gen_probability = cell_state.gen_probability
            gen_part = gen_probability * tf.nn.softmax(gen_logits)
            copy_part = (1 - gen_probability) * tf.nn.softmax(copy_logits)
            return gen_part + copy_part

        # One known problem: large memory use for a large vocabulary.
        def gen_copy_output(indices, batch_size, cell_output, cell_state):
            """Combined generation + copy output over the whole vocabulary."""
            gen_logits = self.output_fn(cell_output)
            copy_logits = copy_output(indices, batch_size, cell_output,
                                      cell_state)
            return merge_gen_and_copy(gen_logits, copy_logits, cell_state)

        self.gen_copy_output_fn = gen_copy_output

        def gen_copy_output_train(time, indices, targets, sampled_values,
                                  batch_size, cell_output, cell_state):
            """Training-time combined output, using sampled logits if enabled."""
            if self.softmax_loss_function is None:
                gen_logits = self.output_fn(cell_output)
            else:
                # Labels for the current time step only.
                labels = tf.slice(targets, [0, time], [-1, 1])

                # Narrow true_expected_count to the current time step; the
                # other two sampled values are passed through unchanged.
                sampled, true_expected_count, sampled_expected_count = sampled_values
                true_count_now = tf.slice(
                    tf.reshape(true_expected_count, [batch_size, -1]),
                    [0, time], [-1, 1])
                step_sampled_values = (sampled, true_count_now,
                                       sampled_expected_count)

                sampled_ids, sampled_logits = melt.nn.compute_sampled_ids_and_logits(
                    weights=self.w_t,
                    biases=self.v,
                    labels=labels,
                    inputs=cell_output,
                    num_sampled=self.num_sampled,
                    num_classes=self.vocab_size,
                    sampled_values=step_sampled_values,
                    remove_accidental_hits=False)
                # Scatter the sampled logits into [batch_size, vocab_size].
                gen_indices = melt.batch_values_to_indices(
                    tf.to_int32(sampled_ids))
                gen_logits = tf.scatter_nd(gen_indices,
                                           sampled_logits,
                                           shape=[batch_size, self.vocab_size])

            copy_logits = copy_output(indices, batch_size, cell_output,
                                      cell_state)
            return merge_gen_and_copy(gen_logits, copy_logits, cell_state)

        self.gen_copy_output_train_fn = gen_copy_output_train