Example #1
def generate_sample(checkpoint, length=-1):
    output = ''
    with tf.Session() as sess:
        # load the checkpoint
        new_saver = tf.train.import_meta_graph(checkpoint+'.meta')
        new_saver.restore(sess, checkpoint)
        # the initial input to feed the network
        x = util.encode_text('<t>')[0]
        x = np.array([[x]])  # shape [BATCHSIZE, SEQLEN] with BATCHSIZE=1 and SEQLEN=1

        # initial values
        y = x
        h = np.zeros([1, INTERNALSIZE * NLAYERS], dtype=np.float32)  # [ BATCHSIZE, INTERNALSIZE * NLAYERS]

        i = 0
        while (i < length) or (length == -1):
            # generate probabilities for the next character and the state of the network
            yo, h = sess.run(['Yo:0', 'H:0'], feed_dict={'X:0': y, 'pkeep:0': 1., 'Hin:0': h, 'batchsize:0': 1})

            # pick a character and decode it
            c = util.sample_from_probabilities(yo, topn=2)
            y = np.array([[c]])  # shape [BATCHSIZE, SEQLEN] with BATCHSIZE=1 and SEQLEN=1
            c = util.decode_character(c)
            # and return it
            i += 1
            gen_input = yield c
            # stop generating if the caller sends us a stop signal
            if gen_input is not None:
                # raising StopIteration inside a generator is a RuntimeError under PEP 479 (Python 3.7+),
                # so end the generator with a plain return instead
                return
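A minimal usage sketch for the generator above; the checkpoint path and the number of characters taken are illustrative, not values from the original code.

# Drive the generator: pull characters with next(), then send any non-None value to stop it.
gen = generate_sample('checkpoints/rnn_train_1520000000-0')
chars = [next(gen) for _ in range(500)]      # take 500 generated characters
try:
    gen.send('stop')                          # any non-None value ends the generator
except StopIteration:
    pass
print(''.join(chars))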
Example #2
def write_bert_tf_example(simple_similar_source_indices, raw_article_sents, summary_text, corefs_str, doc_indices, article_lcs_paths_list, writer, dataset_name):
    tf_example = example_pb2.Example()
    source_indices_str = ';'.join([' '.join(str(i) for i in source_indices) for source_indices in simple_similar_source_indices])
    tf_example.features.feature['similar_source_indices'].bytes_list.value.extend([util.encode_text(source_indices_str)])
    for sent in raw_article_sents:
        s = sent.strip()
        tf_example.features.feature['raw_article_sents'].bytes_list.value.extend([util.encode_text(s)])
    if dataset_name == 'duc_2004':
        for summ_text in summary_text:
            tf_example.features.feature['summary_text'].bytes_list.value.extend([util.encode_text(summ_text)])
    else:
        tf_example.features.feature['summary_text'].bytes_list.value.extend([util.encode_text(summary_text)])
    if doc_indices is not None:
        tf_example.features.feature['doc_indices'].bytes_list.value.extend([util.encode_text(doc_indices)])
    if corefs_str is not None:
        tf_example.features.feature['corefs'].bytes_list.value.extend([corefs_str])
    if article_lcs_paths_list is not None:
        article_lcs_paths_list_str = '|'.join([';'.join([' '.join(str(i) for i in source_indices) for source_indices in article_lcs_paths]) for article_lcs_paths in article_lcs_paths_list])
        tf_example.features.feature['article_lcs_paths_list'].bytes_list.value.extend([util.encode_text(article_lcs_paths_list_str)])
    # serialize and write a length-prefixed record: an 8-byte length ('q') followed by the payload
    tf_example_str = tf_example.SerializeToString()
    str_len = len(tf_example_str)
    writer.write(struct.pack('q', str_len))
    writer.write(struct.pack('%ds' % str_len, tf_example_str))
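For completeness, a sketch of reading these length-prefixed records back. It assumes example_pb2 is TensorFlow's Example protobuf module (consistent with the writer above); the file path is illustrative.

# Sketch: iterate over records written by write_bert_tf_example.
# Each record is an 8-byte length ('q') followed by a serialized Example.
import struct
from tensorflow.core.example import example_pb2

def read_tf_examples(path):
    with open(path, 'rb') as f:
        while True:
            len_bytes = f.read(8)
            if not len_bytes:
                break                                   # end of file
            str_len = struct.unpack('q', len_bytes)[0]
            example_str = struct.unpack('%ds' % str_len, f.read(str_len))[0]
            yield example_pb2.Example.FromString(example_str)

for ex in read_tf_examples('dataset_train.bin'):
    print(ex.features.feature['summary_text'].bytes_list.value[0])
    break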
Example #3
def make_example(article, abstracts, doc_indices, raw_article_sents, corefs, article_lcs_paths=None):
    tf_example = example_pb2.Example()
    tf_example.features.feature['article'].bytes_list.value.extend([util.encode_text(article)])
    for abstract in abstracts:
        if isinstance(abstract, list):
            tf_example.features.feature['abstract'].bytes_list.value.extend([util.encode_text(process_abstract(abstract))])
        else:
            tf_example.features.feature['abstract'].bytes_list.value.extend([util.encode_text(abstract)])
    if doc_indices is not None:
        if isinstance(doc_indices, list):
            doc_indices = ' '.join(doc_indices)
        tf_example.features.feature['doc_indices'].bytes_list.value.extend([util.encode_text(doc_indices)])
    if raw_article_sents is not None:
        for sent in raw_article_sents:
            tf_example.features.feature['raw_article_sents'].bytes_list.value.extend([util.encode_text(sent)])
    if corefs is not None:
        corefs_str = json.dumps(corefs)
        tf_example.features.feature['corefs'].bytes_list.value.extend([util.encode_text(corefs_str)])
    if article_lcs_paths is not None:
        article_lcs_paths_str = ';'.join([' '.join(str(i) for i in source_indices) for source_indices in article_lcs_paths])
        tf_example.features.feature['article_lcs_paths'].bytes_list.value.extend([util.encode_text(article_lcs_paths_str)])
    return tf_example
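A hedged usage sketch: building one example with make_example and appending it to a file in the same length-prefixed binary format used by write_bert_tf_example above. All input values and the output path are made up for illustration.

# Sketch: construct an example and append it as a length-prefixed record.
import struct

raw_sents = ['The cat sat on the mat .', 'It purred loudly .']
ex = make_example(article=' '.join(raw_sents),
                  abstracts=['<s> a cat purrs . </s>'],       # string form, encoded directly
                  doc_indices='0 0 0 0 0 0 0 0 0 0 0 0',
                  raw_article_sents=raw_sents,
                  corefs=[])
ex_str = ex.SerializeToString()
with open('example.bin', 'ab') as writer:
    writer.write(struct.pack('q', len(ex_str)))
    writer.write(struct.pack('%ds' % len(ex_str), ex_str))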
Example #4
def train(codetext, valitext):
    DISPLAY_FREQ = 50
    _50_BATCHES = DISPLAY_FREQ * BATCHSIZE * SEQLEN

    # init
    istate = np.zeros([BATCHSIZE, INTERNALSIZE*NLAYERS])  # initial zero input state
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    step = 0

    # training loop
    print('=== TRAINING ===')
    for x, y_, epoch in util.rnn_minibatch_sequencer(codetext, BATCHSIZE, SEQLEN, nb_epochs=50):

        # train on one minibatch
        feed_dict = {X: x, Y_: y_, Hin: istate, lr: learning_rate, pkeep: dropout_pkeep, batchsize: BATCHSIZE}
        _, y, ostate = sess.run([train_step, Y, H], feed_dict=feed_dict)

        # log training stats for TensorBoard on the current mini-batch of sequences (every 50 batches)
        if step % _50_BATCHES == 0:
            feed_dict = {X: x, Y_: y_, Hin: istate, pkeep: 1.0, batchsize: BATCHSIZE}  # no dropout when evaluating stats
            y, l, bl, acc, smm = sess.run([Y, seqloss, batchloss, accuracy, summaries], feed_dict=feed_dict)
            print('\n\nstep {} (epoch {}):'.format(step, epoch))
            print('  training:   loss={:.5f}, accuracy={:.5f}'.format(bl, acc))
            summary_writer.add_summary(smm, step)

        # run a validation step every 50 batches
        if step % _50_BATCHES == 0 and len(valitext) > 0:
            l_loss = []
            l_acc = []
            vali_state = np.zeros([BATCHSIZE, INTERNALSIZE*NLAYERS])
            for vali_x, vali_y, _ in util.rnn_minibatch_sequencer(valitext, BATCHSIZE, SEQLEN, 1):
                feed_dict = {X: vali_x, Y_: vali_y, Hin: vali_state, pkeep: 1.0,  # no dropout for validation
                             batchsize: BATCHSIZE}
                ls, acc, ostate = sess.run([batchloss, accuracy, H], feed_dict=feed_dict)
                l_loss.append(ls)
                l_acc.append(acc)
                vali_state = ostate
            # calculate average
            avg_summary = tf.Summary(value=[
                tf.Summary.Value(tag="batch_loss", simple_value=np.mean(l_loss)),
                tf.Summary.Value(tag="batch_accuracy", simple_value=np.mean(l_acc)),
            ])

            print('  validation: loss={:.5f}, accuracy={:.5f}'.format(np.mean(l_loss), np.mean(l_acc)))
            # save validation data for Tensorboard
            validation_writer.add_summary(avg_summary, step)

        # display a short text generated with the current weights and biases (every 150 batches)
        if step % (3 * _50_BATCHES) == 0:
            print('--- generated sample ---')
            ry = np.array([[util.encode_text('<t>')[0]]])
            rh = np.zeros([1, INTERNALSIZE * NLAYERS])
            for k in range(2000):
                ryo, rh = sess.run([Yo, H], feed_dict={X: ry, pkeep: 1.0, Hin: rh, batchsize: 1})
                rc = util.sample_from_probabilities(ryo, topn=3 if epoch <= 1 else 2)
                print(util.decode_character(rc), end="")
                ry = np.array([[rc]])
            print('\n--- end of generated sample ---')

        # save a checkpoint (every 500 batches)
        if step % (10 * _50_BATCHES) == 0:
            saved_file = saver.save(sess, 'checkpoints/rnn_train_' + timestamp, global_step=step)
            print("Saved file: " + saved_file)

        print('.', end='', flush=True)

        # loop state around
        istate = ostate
        step += BATCHSIZE * SEQLEN
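The training loop above references graph tensors and module-level settings that are defined elsewhere (X, Y_, Hin, Yo, Y, H, seqloss, batchloss, accuracy, train_step, summaries, the writers, the saver, and the size constants). Below is a minimal TF 1.x sketch of how they might be defined, assuming a multi-layer GRU character model consistent with the tensor names fed in example #1; the sizes, ALPHASIZE, and the optimizer choice are assumptions, not values from the original code.

# Sketch (TF 1.x) of the graph and settings the train() loop assumes.
import math, time
import tensorflow as tf

SEQLEN, BATCHSIZE, INTERNALSIZE, NLAYERS = 30, 200, 512, 3
ALPHASIZE = 98                                  # assumed vocabulary size
learning_rate, dropout_pkeep = 0.001, 0.8

lr = tf.placeholder(tf.float32, name='lr')
pkeep = tf.placeholder(tf.float32, name='pkeep')
batchsize = tf.placeholder(tf.int32, name='batchsize')
X = tf.placeholder(tf.uint8, [None, None], name='X')      # [BATCHSIZE, SEQLEN] of character codes
Y_ = tf.placeholder(tf.uint8, [None, None], name='Y_')     # expected outputs, shifted by one character
Hin = tf.placeholder(tf.float32, [None, INTERNALSIZE * NLAYERS], name='Hin')
Xo = tf.one_hot(X, ALPHASIZE, 1.0, 0.0)
Yo_ = tf.one_hot(Y_, ALPHASIZE, 1.0, 0.0)

# multi-layer GRU with dropout; the state is flattened so it can be fed back as a single tensor
cells = [tf.nn.rnn_cell.GRUCell(INTERNALSIZE) for _ in range(NLAYERS)]
dropcells = [tf.nn.rnn_cell.DropoutWrapper(c, input_keep_prob=pkeep) for c in cells]
multicell = tf.nn.rnn_cell.MultiRNNCell(dropcells, state_is_tuple=False)
Yr, H = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=Hin)
H = tf.identity(H, name='H')                               # output state, fed back in as Hin

Yflat = tf.reshape(Yr, [-1, INTERNALSIZE])
Ylogits = tf.layers.dense(Yflat, ALPHASIZE)
Yflat_ = tf.reshape(Yo_, [-1, ALPHASIZE])
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Yflat_, logits=Ylogits)
loss = tf.reshape(loss, [batchsize, -1])
Yo = tf.nn.softmax(Ylogits, name='Yo')                      # per-character probabilities
Y = tf.reshape(tf.argmax(Yo, 1), [batchsize, -1], name='Y') # predicted characters
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

seqloss = tf.reduce_mean(loss, 1)
batchloss = tf.reduce_mean(seqloss)
accuracy = tf.reduce_mean(tf.cast(tf.equal(Y_, tf.cast(Y, tf.uint8)), tf.float32))
summaries = tf.summary.merge([tf.summary.scalar('batch_loss', batchloss),
                              tf.summary.scalar('batch_accuracy', accuracy)])

timestamp = str(math.trunc(time.time()))
summary_writer = tf.summary.FileWriter('log/' + timestamp + '-training')
validation_writer = tf.summary.FileWriter('log/' + timestamp + '-validation')
saver = tf.train.Saver(max_to_keep=1000)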