Example #1
0
    def __init__(self, is_training=True, is_predict=False):
        """Record the mode flags, resolve the end id and build the RNN cell(s).

        Args:
            is_training: stored on the instance and forwarded to
                ``melt.create_rnn_cell``.
            is_predict: stored on the instance; not otherwise used here.

        Raises:
            AssertionError: if the resolved end id equals the vocabulary's
                unk id.
        """
        super(RnnEncoder, self).__init__()
        self.is_training = is_training
        self.is_predict = is_predict

        vocabulary.init()

        # NOTICE: NUM_RESERVED_IDS must be >= 3 when the go id is used (TODO).
        self.end_id = (vocabulary.end_id()
                       if FLAGS.encoder_end_mark == '</S>'
                       else vocabulary.go_id())
        unk_id = vocabulary.vocab.unk_id()
        assert self.end_id != unk_id, 'input vocab generated without end id'

        # Arguments shared by every cell built below; only the initializer
        # seed differs between the forward and the backward cell.
        build_cell = melt.create_rnn_cell
        shared_cell_args = {
            'num_units': FLAGS.rnn_hidden_size,
            'is_training': is_training,
            'keep_prob': FLAGS.keep_prob,
            'num_layers': FLAGS.num_layers,
            'cell_type': FLAGS.cell,
        }

        def seeded_cell(seed):
            # Follows models/textsum (per the original note).
            uniform_init = tf.random_uniform_initializer(-0.1, 0.1, seed=seed)
            return build_cell(initializer=uniform_init, **shared_cell_args)

        self.cell = seeded_cell(123)

        # A backward cell is only created for the bidirectional method.
        is_bidirectional = (
            FLAGS.rnn_method == melt.rnn.EncodeMethod.bidirectional)
        self.bwcell = seeded_cell(113) if is_bidirectional else None
Example #2
0
def get_decoder_start_id():
    """Return the start id selected by FLAGS, or None when none applies.

    None is returned when ``FLAGS.input_with_start_mark`` is set or
    ``FLAGS.add_text_start`` is not set. Otherwise the result is 0 when
    ``FLAGS.zero_as_text_start`` is set, ``vocabulary.go_id()`` when
    ``FLAGS.go_as_text_start`` is set, and ``vocabulary.start_id()`` in the
    remaining case.
    """
    # No start id is chosen unless add_text_start is on and the input does
    # not already carry a start mark.
    if FLAGS.input_with_start_mark or not FLAGS.add_text_start:
        return None

    if FLAGS.zero_as_text_start:
        return 0
    if FLAGS.go_as_text_start:
        return vocabulary.go_id()
    return vocabulary.start_id()
Example #3
0
 def get_start_id(self):
     """Pick the start id from FLAGS, store it on ``self.start_id``, return it.

     The value is None when ``FLAGS.input_with_start_mark`` is set or
     ``FLAGS.add_text_start`` is not set; otherwise it is 0, the vocabulary
     go id, or the vocabulary start id, checked in that order of flags.
     """
     if FLAGS.input_with_start_mark or not FLAGS.add_text_start:
         chosen = None
     elif FLAGS.zero_as_text_start:
         chosen = 0
     elif FLAGS.go_as_text_start:
         chosen = vocabulary.go_id()
     else:
         chosen = vocabulary.start_id()

     # Store the result on the instance as well as returning it.
     self.start_id = chosen
     return chosen
Example #4
0
def get_decodes(shuffle_then_decode, dynamic_batch_length):
    """Return the ``(inputs, decode)`` pair for the requested pipeline order.

    Also initialises the vocabulary and sets the module-level
    ``encoder_end_id`` as a side effect.

    Args:
        shuffle_then_decode: when truthy, pair ``melt.shuffle_then_decode.inputs``
            with a callable wrapping ``decode_examples``; otherwise pair
            ``melt.decode_then_shuffle.inputs`` with one wrapping
            ``decode_example``.
        dynamic_batch_length: passed as the second argument to the wrapped
            decode function on every call.

    Returns:
        A tuple ``(inputs, decode)`` where ``decode`` takes one argument.

    Raises:
        AssertionError: if the resolved end id equals the vocabulary's unk id.
    """
    global encoder_end_id

    vocabulary.init()

    # NOTICE: NUM_RESERVED_IDS must be >= 3 when the go id is used (TODO).
    encoder_end_id = (vocabulary.end_id()
                      if FLAGS.encoder_end_mark == '</S>'
                      else vocabulary.go_id())
    unk_id = vocabulary.vocab.unk_id()
    assert encoder_end_id != unk_id, 'input vocab generated without end id'

    if shuffle_then_decode:
        return (melt.shuffle_then_decode.inputs,
                lambda x: decode_examples(x, dynamic_batch_length))
    return (melt.decode_then_shuffle.inputs,
            lambda x: decode_example(x, dynamic_batch_length))