Example #1
0
def main():
    """Train the VRAE seq2seq model on IWSLT, checkpointing as it goes.

    Overrides several fields of the module-level ``args`` namespace, builds
    the ``IWSLT`` data loader and the ``VRAE`` model, restores the latest
    checkpoint found under ``./saved/iwslt_seq2seq/`` (if any) and then runs
    ``args.num_epoch`` passes over the batches yielded by
    ``dataloader.load_data()``.

    Summaries are written for every training step.  A reconstruction sample
    and a checkpoint are produced every ``args.display_info_step`` steps and
    once more at the end of each epoch.  The session and the summary writer
    are closed when training finishes or fails.
    """
    # Expose only CUDA device 1 to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    # Hard-coded overrides of the command-line arguments.
    args.max_len = 100
    args.batch_size = 64
    args.max_dec_len = 100
    args.display_info_step = 1000
    args.isPointer = False
    args.vocab_limit = 35000
    # Hard-coded example count, used only to derive the steps per epoch below.
    train_data_len = 209940
    args.diff_input = True
    print(args)

    dataloader = IWSLT(
        batch_size=args.batch_size,
        vocab_limit=args.vocab_limit,
        max_input_len=args.max_len,
        max_output_len=args.max_dec_len,
    )
    params = {
        'vocab_size_encoder': len(dataloader.idx2token),
        'vocab_size': len(dataloader.word2idx),
        'word2idx': dataloader.word2idx,
        'idx2word': dataloader.idx2word,
        'idx2token': dataloader.idx2token,
    }
    print('Vocab Size:', params['vocab_size'])

    model = VRAE(params)
    saver = tf.train.Saver()
    exp_path = "./saved/iwslt_seq2seq/"
    model_name = "seq2seq.ckpt"

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    summary_writer = None
    try:
        sess.run(tf.global_variables_initializer())

        # Ceiling division: number of batches needed to cover the data once.
        EPOCH_STEPS = (train_data_len - 1) // args.batch_size + 1

        summary_writer = tf.summary.FileWriter(exp_path, sess.graph)
        restore_path = tf.train.latest_checkpoint(exp_path)
        if restore_path:
            saver.restore(sess, restore_path)
            # The checkpoint file name ends in "-<global_step>".
            last_train_step = int(restore_path.split("-")[-1]) % EPOCH_STEPS
            print("Model restore from file: %s, last train step: %d" % (restore_path, last_train_step))

        # Remember the most recent batch / step so the end-of-epoch sample and
        # checkpoint can be skipped when nothing has been trained at all.
        enc_inp = dec_out = train_step = None

        for epoch in range(args.num_epoch):
            batcher = dataloader.load_data()

            step = -1
            while True:
                # Only the fetch itself is expected to signal exhaustion.
                try:
                    batch = next(batcher)
                except StopIteration:
                    print("there are no more examples")
                    break

                # Encoder input comes from the x side of the batch, decoder
                # input/target from the y side; the rest is unused here.
                (enc_inp, _, _, _, dec_inp_full, dec_out), _, _, _ = batch
                dec_inp = dataloader.update_word_dropout(dec_inp_full)
                step += 1

                log = model.train_session(sess, enc_inp, dec_inp, dec_out)

                # Write this step's summaries for TensorBoard.
                summaries, train_step = log['summaries'], log['step']
                summary_writer.add_summary(summaries, train_step)
                if train_step % 100 == 0:  # flush the summary writer every so often
                    summary_writer.flush()

                if step % args.display_loss_step == 0:
                    print("Step %d | [%d/%d] | [%d/%d] | loss:%.3f" % (
                        log['step'], epoch + 1, args.num_epoch,
                        step, train_data_len // args.batch_size, log['loss']))

                if step % args.display_info_step == 0:
                    model.reconstruct(sess, enc_inp[-1], dec_out[-1])
                    save_path = saver.save(sess, exp_path + model_name, global_step=train_step)
                    print("Model saved in file: %s" % save_path)

            if train_step is None:
                # No batch has ever been trained: nothing to show or save.
                continue

            # End-of-epoch sample and checkpoint, using the last trained batch.
            model.reconstruct(sess, enc_inp[-1], dec_out[-1])
            save_path = saver.save(sess, exp_path + model_name, global_step=train_step)
            print("Model saved in file: %s" % save_path)
    finally:
        # Release the event file and the session even if training fails.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
Example #2
0
def main():
    """Train the VAESEQ model on IWSLT, resuming from the latest checkpoint.

    Reads and overrides fields of the module-level ``args`` namespace, creates
    the experiment directory ``./saved/<args.exp>/``, builds the ``IWSLT``
    data loader and the ``VAESEQ`` model, and restores the most recent
    checkpoint from the experiment directory when one exists.

    Each training step runs ``train_encoder`` and ``train_decoder``, plus
    ``merged_seq_train`` on every step whose index is not a multiple of 3.
    Progress is printed and appended to ``log.txt`` in the experiment
    directory.  Every ``args.display_info_step`` steps (except step 0) a
    checkpoint is saved and three samples are shown.

    The log file, the summary writer and the session are closed when training
    finishes or fails.
    """
    ## CUDA
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

    ## Parameters
    if args.exp == "NONE":
        args.exp = args.graph_type
    args.enc_max_len = 100
    args.dec_max_len = 100
    args.vocab_limit = 35000
    exp_path = "./saved/" + args.exp + "/"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(exp_path, exist_ok=True)
    model_name = args.model_name
    # Hard-coded example count, used to derive the steps per epoch.
    train_data_len = 209940
    args.data_len = train_data_len
    # Ceiling division: number of batches needed to cover the data once.
    EPOCH_STEPS = (train_data_len - 1) // args.batch_size + 1
    args.diff_input = True
    print(args)

    ## DataLoader
    dataloader = IWSLT(
        batch_size=args.batch_size,
        vocab_limit=args.vocab_limit,
        max_input_len=args.enc_max_len,
        max_output_len=args.dec_max_len,
    )
    params = {
        'vocab_size': len(dataloader.word2idx),
        'vocab_size_encoder': len(dataloader.token2idx),
        'word2idx': dataloader.word2idx,
        'idx2word': dataloader.idx2word,
        'idx2token': dataloader.idx2token,
        'token2id': dataloader.token2idx,
        'loss_type': args.loss_type,
        'graph_type': args.graph_type,
    }
    print('Vocab Size:', params['vocab_size'])

    ## ModelInit
    model = VAESEQ(params)
    log_path = exp_path + "log.txt"
    LOGGER = open(log_path, "a")
    sess = None
    summary_writer = None
    try:
        ## Session
        # The first saver is restricted to the variables reported as
        # restorable; it is only used to load an existing checkpoint.
        variables = tf.contrib.framework.get_variables_to_restore()
        print(len(variables))
        saver = tf.train.Saver(variables)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(exp_path, sess.graph)
        keep_on_train_flag = False
        last_train_step = 0
        restore_path = tf.train.latest_checkpoint(exp_path)
        if restore_path:
            keep_on_train_flag = True
            saver.restore(sess, restore_path)
            # Later checkpoints are written by a saver with default arguments.
            saver = tf.train.Saver()
            # The checkpoint file name ends in "-<global_step>".
            last_train_step = int(restore_path.split("-")[-1]) % EPOCH_STEPS
            resume_msg = "Model restore from file: %s, last train step: %d" % (restore_path, last_train_step)
            print(resume_msg)
            LOGGER.write(resume_msg + "\n")

        # Train Mode
        x_log, y_log, t_log, log = None, None, None, None
        for epoch in range(args.num_epoch):
            batcher = dataloader.load_data()
            for step in tqdm(range(EPOCH_STEPS)):
                # When resuming, fast-forward the step counter to where the
                # restored checkpoint left off.
                # NOTE(review): skipped steps do not pull batches from
                # `batcher`, so training resumes on the first batches of the
                # epoch -- confirm this is intended.
                if keep_on_train_flag and step < last_train_step:
                    continue
                if keep_on_train_flag and step == last_train_step:
                    keep_on_train_flag = False

                # get batch data; only the fetch is expected to signal exhaustion
                try:
                    batch = next(batcher)
                except StopIteration:
                    print("there are no more examples")
                    break
                (x_enc_inp, x_dec_inp_full, x_dec_out, y_enc_inp, y_dec_inp_full, y_dec_out), x_enc_inp_oovs, data_oovs, _ = batch
                x_dec_inp = dataloader.update_word_dropout(x_dec_inp_full)
                y_dec_inp = dataloader.update_word_dropout(y_dec_inp_full)
                # Longest per-example OOV list in the batch (0 if there is none).
                max_oovs_len = max((len(oov) for oov in data_oovs), default=0)

                x_log = model.train_encoder(sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out)
                y_log = model.train_decoder(sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out)
                # The merged update is skipped on every third step; `log` then
                # keeps its value from the previous step (None on step 0).
                if step % 3 != 0:
                    log = model.merged_seq_train(sess, x_enc_inp, x_dec_inp, x_dec_out, y_enc_inp, y_dec_inp, y_dec_out, x_enc_inp_oovs, max_oovs_len)

                # summary_flush returns the step number used for checkpoints
                # and progress output.
                train_step = summary_flush(x_log, y_log, t_log, log, summary_writer)

                if step % args.display_loss_step == 0:
                    progress = "Step %d | [%d/%d] | [%d/%d]" % (
                        train_step, epoch + 1, args.num_epoch,
                        step, train_data_len // args.batch_size)
                    # No newline here: show_loss is called right after with
                    # the same logger (presumably to finish the line).
                    print(progress, end='')
                    LOGGER.write(progress)
                    show_loss(x_log, y_log, t_log, log, LOGGER)

                if step % args.display_info_step == 0 and step != 0:
                    # args.training is switched off while checkpointing and
                    # sampling, then switched back on.
                    args.training = False
                    save_path = saver.save(sess, exp_path + model_name, global_step=train_step)
                    print("Model saved in file: %s" % save_path)
                    for i in range(3):
                        model.show_sample(sess, x_enc_inp[i], y_dec_out[i], LOGGER)
                    LOGGER.flush()
                    args.training = True
    finally:
        # Release the event file, the session and the log file even if
        # training fails part-way through.
        if summary_writer is not None:
            summary_writer.close()
        if sess is not None:
            sess.close()
        LOGGER.close()