Code Example #1
File: cvae_run.py  Project: ArponKundu/UIU_MojiTalk
                               start_i,
                               end_i,
                               batch_size,
                               permutate=False)

# let TensorFlow allocate GPU memory on demand instead of reserving it all up front
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    """init params"""
    sess.run(tf.global_variables_initializer())
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    saver = tf.train.Saver(var_list=train_vars, max_to_keep=100)
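    # when training the full CVAE (rather than the plain seq2seq baseline),
    # warm-start its parameters from a pretrained seq2seq checkpoint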
    if not is_seq2seq:
        saver.restore(sess,
                      "seq2seq/07-17_05-49-50/breakpoints/at_step_18000.ckpt")
    cvae.set_sess(sess)

    total_step = num_epoch * len(train_data[0]) // batch_size  # total number of training steps
    global_step = 1
    start_epoch = 1

    for epoch in range(start_epoch, num_epoch + 1):
        train_batches = batch_generator(train_data, start_i, end_i, batch_size)

        # per-batch loss accumulators: reconstruction, KL, and bag-of-words losses
        recon_l = []
        kl_l = []
        bow_l = []
        for batch in train_batches:
            """ TRAIN """
            if is_seq2seq:
                kl_weight = 1.
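
The loop above pins the KL weight at 1 in plain seq2seq mode; CVAE training typically anneals it toward the kl_ceiling instead. A minimal, self-contained sketch of such a linear annealing schedule (illustrative only; the actual schedule in cvae_run.py may differ):

def kl_anneal(step, total_steps, ceiling=0.48):
    """Linearly ramp the KL weight from 0 up to `ceiling` over training."""
    return min(step / float(total_steps), 1.0) * ceiling

# example: weight at the halfway point of training
print(kl_anneal(step=5000, total_steps=10000))  # 0.24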
Code Example #2
p = Printer(log_f, index2word)

"""build graphs and init params"""
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
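
# keep each model in its own tf.Graph with its own tf.Session so their
# variables, collections, and savers stay isolated from one another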
graph0 = tf.Graph()
with graph0.as_default():
    kl_ceiling = 0.48
    seq2seq = CVAE(vocab_size, embed_size, num_unit, latent_dim, emoji_dim, batch_size,
                   kl_ceiling, 1, decoder_layer,
                   start_i, end_i, beam_width, maximum_iterations, max_gradient_norm, lr, dropout, num_gpu,
                   cell_type,
                   is_seq2seq=False)
    sess0 = tf.Session(graph=graph0, config=config)
    sess0.run(tf.global_variables_initializer())
    seq2seq.set_sess(sess0)
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    saver0 = tf.train.Saver(var_list=train_vars, max_to_keep=25)
    saver0.restore(sess0, "cvae/07-17_15-51-04/breakpoints/at_step_36500.ckpt")
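# finalize() freezes the graph, so any accidental op creation later raises an error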
graph0.finalize()

graph1 = tf.Graph()
with graph1.as_default():
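    # the emoji classifier is only restored from its checkpoint; no fresh initialization is run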
    classifier = EmojiClassifier(batch_size, vocab_size, emoji_num, embed_size, num_unit, num_gpu,
                                 lr=0.001, dropout=0.2, cell_type=tf.nn.rnn_cell.GRUCell)
    saver1 = tf.train.Saver()
    sess1 = tf.Session(graph=graph1, config=config)
    classifier.set_sess(sess1)
    classifier.set_emoji_index(emoji_b2s)
    saver1.restore(sess1, "classifier/07-16_14-33-58/breakpoints/best_test_loss.ckpt")
graph1.finalize()
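
For reference, a minimal, self-contained sketch of the same two-graph / two-session pattern with toy variables (the names and values below are illustrative, not taken from the MojiTalk code):

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# first model lives in its own graph and session
graph_a = tf.Graph()
with graph_a.as_default():
    a = tf.get_variable("a", initializer=1.0)
    sess_a = tf.Session(graph=graph_a, config=config)
    sess_a.run(tf.global_variables_initializer())
graph_a.finalize()  # graph_a is now read-only

# second model is completely isolated from the first
graph_b = tf.Graph()
with graph_b.as_default():
    b = tf.get_variable("b", initializer=2.0)
    sess_b = tf.Session(graph=graph_b, config=config)
    sess_b.run(tf.global_variables_initializer())
graph_b.finalize()

print(sess_a.run(a), sess_b.run(b))  # -> 1.0 2.0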