def main():
    """Train or evaluate a ``KgRnnCVAE`` model inside a TensorFlow session.

    Three model instances are built under the same variable scope (the second
    and third with ``reuse=True``): a training model, a validation model and a
    forward-only test model. When ``FLAGS.forward_only`` is false the function
    trains for up to ``config.max_epoch`` epochs, evaluating the validation
    and test feeds after every epoch, saving checkpoints when the validation
    loss improves and stopping early once ``patience`` is exhausted.
    Otherwise it evaluates both feeds once and writes the multi-reference
    test output to ``test.txt`` under ``log_dir``.

    NOTE(review): ``log_dir`` is never assigned in this function, so it has to
    exist as a module-level name -- confirm against the rest of the file.
    """
    # config for training
    config = Config()
    config.use_bow = False

    # config for validation: dropout disabled, fixed batch size
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60
    valid_config.use_bow = False

    # configuration for testing: dropout disabled, one sample per batch
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1
    test_config.use_bow = False

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta = meta_corpus.get("train")
    valid_meta = meta_corpus.get("valid")
    test_meta = meta_corpus.get("test")
    train_dial = dial_corpus.get("train")
    valid_dial = dial_corpus.get("valid")
    test_dial = dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    # (all three loaders are constructed with the training config)
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # begin training
    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    # sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    with tf.Session(config=sess_config) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w,
                                                    config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess,
                              config,
                              api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False,
                              scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess,
                                    valid_config,
                                    api,
                                    log_dir=None,
                                    forward=False,
                                    scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess,
                                   test_config,
                                   api,
                                   log_dir=None,
                                   forward=True,
                                   scope=scope)

        test_model.prepare_mul_ref()

        logger.info("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        logger.info("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        # The pretrained embeddings have to be assigned AFTER the global
        # initializer: running the initializer re-assigns every variable
        # (the embedding included) and would discard an earlier assignment.
        # A checkpoint restore below still takes precedence over word2vec.
        if api.word2vec is not None and not FLAGS.forward_only:
            logger.info("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        if ckpt:
            logger.info("Reading dm models parameters from %s" %
                        ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epoch before stop
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                logger.info(">> Epoch %d with lr %f" %
                            (epoch,
                             sess.run(model.learning_rate_cyc,
                                      {model.global_t: global_t})))

                # begin training; the feed is only re-initialised once the
                # previous pass over it has been fully consumed
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size,
                                          config.backward_size,
                                          config.step_size,
                                          shuffle=True)
                global_t, train_loss = model.train(
                    global_t,
                    sess,
                    train_feed,
                    update_limit=config.update_limit)

                # begin validation
                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Val")
                valid_feed.epoch_init(valid_config.batch_size,
                                      valid_config.backward_size,
                                      valid_config.step_size,
                                      shuffle=False,
                                      intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                # the test feed is scored with the validation model/config
                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Test")
                test_feed.epoch_init(valid_config.batch_size,
                                     valid_config.backward_size,
                                     valid_config.step_size,
                                     shuffle=False,
                                     intra_shuffle=False)
                valid_model.valid("ELBO_TEST", sess, test_feed)

                # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                #                      test_config.step_size, shuffle=True, intra_shuffle=False)
                # test_model.test_mul_ref(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # With the "sgd" optimiser the learning rate is decayed after
                # every epoch once more than config.lr_hold epochs are done.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                # only save a model if the dev loss is smaller
                if valid_loss < best_dev_loss:
                    # a sufficiently large improvement extends the patience
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience,
                                       done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        logger.info("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path)
                    best_dev_loss = valid_loss

                    # extra per-epoch snapshot on every third epoch, written
                    # regardless of FLAGS.save_model (improving epochs only)
                    if (epoch % 3) == 2:
                        tmp_model_path = os.path.join(
                            ckp_dir,
                            model.__class__.__name__ + str(epoch) + ".ckpt")
                        model.saver.save(sess, tmp_model_path)

                if config.early_stop and patience <= done_epoch:
                    logger.info("!!Early stop due to run out of patience!!")
                    break
            logger.info("Best validation loss %f" % best_dev_loss)
            logger.info("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # context manager guarantees the result file is closed even if
            # the test pass raises
            with open(os.path.join(log_dir, "test.txt"), "wb") as dest_f:
                test_feed.epoch_init(test_config.batch_size,
                                     test_config.backward_size,
                                     test_config.step_size,
                                     shuffle=False,
                                     intra_shuffle=False)
                test_model.test_mul_ref(sess,
                                        test_feed,
                                        num_batch=None,
                                        repeat=5,
                                        dest=dest_f)
Exemplo n.º 2
0
def main():
    """Train or evaluate a PyTorch ``KgRnnCVAE`` model on the SWDA corpus.

    A single model instance is used for training, validation and testing;
    ``model.eval()`` / ``model.train()`` switch its mode around evaluation.
    When ``FLAGS.forward_only`` is false the function trains for up to
    ``config.max_epoch`` epochs with per-epoch validation, optional
    checkpointing on improvement and patience-based early stopping.
    Otherwise it evaluates the validation and test feeds once and writes the
    test output to ``test.txt`` under the log directory.
    """
    # config for training
    config = Config()

    # config for validation: dropout disabled, fixed batch size
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing: dropout disabled, one sample per batch
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta = meta_corpus.get("train")
    valid_meta = meta_corpus.get("valid")
    test_meta = meta_corpus.get("test")
    train_dial = dial_corpus.get("train")
    valid_dial = dial_corpus.get("valid")
    test_dial = dial_corpus.get("test")

    # convert to numeric input outputs that fits into the model
    # (all three loaders are constructed with the training config)
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # evaluation / resume reuse an existing run directory; a fresh training
    # run gets a new timestamped one
    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    scope = "model"
    model = KgRnnCVAE(config,
                      api,
                      log_dir=None if FLAGS.forward_only else log_dir,
                      scope=scope)

    print("Created computation graphs")
    # write config to a file for logging
    if not FLAGS.forward_only:
        with open(os.path.join(log_dir, "run.log"), "wb") as f:
            f.write(pp(config, output=False))

    # create a folder by force
    ckp_dir = os.path.join(log_dir, "checkpoints")
    if not os.path.exists(ckp_dir):
        os.mkdir(ckp_dir)

    ckpt = get_checkpoint_state(ckp_dir)
    print("Created models with fresh parameters.")
    # Uniform init of every parameter exactly once. ``parameters()`` already
    # recurses into sub-modules, so wrapping this in ``model.apply`` would
    # re-initialise each tensor once per enclosing module.
    for param in model.parameters():
        torch.nn.init.uniform(param.data, -1.0 * config.init_w, config.init_w)

    # Load word2vec weight
    if api.word2vec is not None and not FLAGS.forward_only:
        print("Loaded word2vec")
        model.embedding.weight.data.copy_(
            torch.from_numpy(np.array(api.word2vec)))
    # row 0 of the embedding table is zeroed
    # (presumably the padding index -- TODO confirm)
    model.embedding.weight.data[0].fill_(0)

    # a checkpoint, when present, overrides everything initialised above
    if ckpt:
        print("Reading dm models parameters from %s" % ckpt)
        model.load_state_dict(torch.load(ckpt))

    # turn to cuda
    model.cuda()

    if not FLAGS.forward_only:
        # "%d" is filled with the epoch number when saving
        dm_checkpoint_path = os.path.join(
            ckp_dir, model.__class__.__name__ + "-%d.pth")
        global_t = 1
        patience = 10  # wait for at least 10 epoch before stop
        dev_loss_threshold = np.inf
        best_dev_loss = np.inf
        for epoch in range(config.max_epoch):
            print(">> Epoch %d with lr %f" % (epoch, model.learning_rate))

            # begin training; the feed is only re-initialised once the
            # previous pass over it has been fully consumed
            if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                train_feed.epoch_init(config.batch_size,
                                      config.backward_size,
                                      config.step_size,
                                      shuffle=True)
            global_t, train_loss = model.train_model(
                global_t, train_feed, update_limit=config.update_limit)

            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            model.eval()
            valid_loss = model.valid_model("ELBO_VALID", valid_feed)

            # quick look at 5 shuffled test batches
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=True,
                                 intra_shuffle=False)
            model.test_model(test_feed, num_batch=5)
            model.train()

            done_epoch = epoch + 1
            # With the "sgd" optimiser the learning rate is decayed after
            # every epoch once more than config.lr_hold epochs are done.
            if config.op == "sgd" and done_epoch > config.lr_hold:
                model.learning_rate_decay()

            # only save a model if the dev loss is smaller
            if valid_loss < best_dev_loss:
                # a sufficiently large improvement extends the patience
                if valid_loss <= dev_loss_threshold * config.improve_threshold:
                    patience = max(patience,
                                   done_epoch * config.patient_increase)
                    dev_loss_threshold = valid_loss

                # still save the best train model
                if FLAGS.save_model:
                    print("Save model!!")
                    torch.save(model.state_dict(),
                               dm_checkpoint_path % (epoch))
                best_dev_loss = valid_loss

            if config.early_stop and patience <= done_epoch:
                print("!!Early stop due to run out of patience!!")
                break
        print("Best validation loss %f" % best_dev_loss)
        print("Done training")
    else:
        # begin validation
        valid_feed.epoch_init(valid_config.batch_size,
                              valid_config.backward_size,
                              valid_config.step_size,
                              shuffle=False,
                              intra_shuffle=False)
        model.eval()
        model.valid_model("ELBO_VALID", valid_feed)

        test_feed.epoch_init(valid_config.batch_size,
                             valid_config.backward_size,
                             valid_config.step_size,
                             shuffle=False,
                             intra_shuffle=False)
        model.valid_model("ELBO_TEST", test_feed)

        # context manager guarantees the result file is closed even if the
        # test pass raises
        with open(os.path.join(log_dir, "test.txt"), "wb") as dest_f:
            test_feed.epoch_init(test_config.batch_size,
                                 test_config.backward_size,
                                 test_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            model.test_model(test_feed, num_batch=None, repeat=10, dest=dest_f)
        model.train()
Exemplo n.º 3
0
def main():
    """Train and/or evaluate ``KgRnnCVAE`` with a ``PIDControl`` object.

    Seeds NumPy from ``FLAGS.seed``, builds training / validation / test
    models under one shared variable scope (each receives the same
    ``pid_control`` instance) and then either

    * trains for ``config.max_epoch`` epochs, logging the metrics returned by
      ``valid_model.valid`` after every epoch, optionally saves a checkpoint,
      and finishes with a final evaluation plus a test pass written to
      ``FLAGS.test_res``; or
    * with ``FLAGS.forward_only`` set, runs only the final evaluation.

    Metrics are appended to ``train.log`` or ``test.log`` in the log
    directory.

    NOTE(review): per-epoch early stopping / best-checkpoint selection is not
    active in this variant; a single checkpoint is written after the last
    epoch. If ``config.max_epoch`` is 0 the final save would fail because
    ``epoch`` is never bound.
    """
    ## random seeds
    seed = FLAGS.seed
    # tf.random.set_seed(seed)
    np.random.seed(seed)

    ## config for training
    config = Config()
    pid = PIDControl(FLAGS.exp_KL)

    # config for validation: dropout disabled, fixed batch size
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing: dropout disabled, one sample per batch
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    # (all three loaders are constructed with the training config)
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # Training, resume and forward-only runs all use the same directory
    # (the former if/else had two identical branches).
    log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)

    ## begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api, log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, pid_control=pid, scope=scope)

        print("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "configure.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        print("*******checkpoint path: ", ckp_dir)
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        # The pretrained embeddings have to be assigned AFTER the global
        # initializer: running the initializer re-assigns every variable
        # (the embedding included) and would discard an earlier assignment.
        # A checkpoint restore below still takes precedence over word2vec.
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        ### save log when running
        if not FLAGS.forward_only:
            logfileName = "train.log"
        else:
            logfileName = "test.log"
        log_path = os.path.join(log_dir, logfileName)
        # report the file that is actually written
        print("log directory >>> : ", log_path)

        # context manager: the metrics log is closed even if a run fails
        with open(log_path, "w") as fw_log:
            if not FLAGS.forward_only:
                print('--start training now---')
                dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
                global_t = 1
                pbar = tqdm(total=config.max_epoch)
                ## epoch start training
                for epoch in range(config.max_epoch):
                    pbar.update(1)
                    print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                    ## begin training; the feed is only re-initialised once
                    ## the previous pass over it has been fully consumed
                    FLAGS.mode = 'train'
                    if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                        train_feed.epoch_init(config.batch_size, config.backward_size,
                                              config.step_size, shuffle=True)
                    global_t, train_loss = model.train(global_t, sess, train_feed,
                                                       update_limit=config.update_limit)

                    ## per-epoch evaluation; both feeds are handed to valid()
                    FLAGS.mode = 'valid'
                    valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                          valid_config.step_size, shuffle=False, intra_shuffle=False)
                    test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                         valid_config.step_size, shuffle=False, intra_shuffle=False)
                    elbo, nll, ppl, au_count, kl_loss = valid_model.valid(
                        "ELBO_TEST", sess, valid_feed, test_feed)
                    print('middle test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(
                        nll, ppl, au_count, kl_loss))
                    fw_log.write('epoch:{} testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                        epoch, nll, ppl, au_count, kl_loss, elbo))
                    # flush so progress is visible while training is running
                    fw_log.flush()
                pbar.close()
                print("Done training and save checkpoint")

                if FLAGS.save_model:
                    print("Save model!!")
                    model.saver.save(sess, dm_checkpoint_path, global_step=epoch)

                # final evaluation after training
                print('--------after training to testing now-----')
                FLAGS.mode = 'test'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                elbo, nll, ppl, au_count, kl_loss = valid_model.valid(
                    "ELBO_TEST", sess, valid_feed, test_feed)

                print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(
                    nll, ppl, au_count, kl_loss))
                fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                    nll, ppl, au_count, kl_loss, elbo))

                # context manager: the result file is closed even if the
                # test pass raises
                with open(os.path.join(log_dir, FLAGS.test_res), "wb") as dest_f:
                    test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                         test_config.step_size, shuffle=False, intra_shuffle=False)
                    test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
                print("****testing done****")
            else:
                # forward-only: metrics only, no test output file is written
                print('*' * 89)
                print('--------testing now-----')
                print('*' * 89)
                FLAGS.mode = 'test'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                elbo, nll, ppl, au_count, kl_loss = valid_model.valid(
                    "ELBO_TEST", sess, valid_feed, test_feed)

                print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(
                    nll, ppl, au_count, kl_loss))
                fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                    nll, ppl, au_count, kl_loss, elbo))
                print("****testing done****")
Exemplo n.º 4
0
def main():
    """Train, evaluate, or interactively demo a ``KgRnnCVAE`` model.

    Three model instances share one variable scope (training, validation and
    forward-only test). Behaviour is selected by flags:

    * ``FLAGS.forward_only`` false: train for up to ``config.max_epoch``
      epochs with per-epoch validation, checkpointing on improvement and
      patience-based early stopping.
    * ``FLAGS.forward_only`` true and ``FLAGS.demo`` false: evaluate the
      validation and test feeds and write test output to ``test.txt``.
    * ``FLAGS.forward_only`` and ``FLAGS.demo`` both true: read a sentence,
      speaker and topic from stdin in a loop and print the model's
      predictions until an empty / whitespace-only sentence is entered.
    """
    # config for training
    config = Config()

    # config for validation: dropout disabled, fixed batch size
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing: dropout disabled, one sample per batch
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    # item in meta like:
    #   ([0.46, 0.6666666666666666, 1, 0], [0.48, 0.6666666666666666, 1, 0], 29)
    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    # item in dial like:
    #   ([utt_id,], speaker_id, [dialogact_id, [,,,]])
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    # (all three loaders are constructed with the training config)
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # evaluation / resume reuse an existing run directory; a fresh training
    # run gets a new timestamped one
    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api, log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None, forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None, forward=True, scope=scope)

        print("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        # The pretrained embeddings have to be assigned AFTER the global
        # initializer: running the initializer re-assigns every variable
        # (the embedding included) and would discard an earlier assignment.
        # A checkpoint restore below still takes precedence over word2vec.
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epoch before stop
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training; the feed is only re-initialised once the
                # previous pass over it has been fully consumed
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                # quick look at 5 shuffled test batches
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # With the "sgd" optimiser the learning rate is decayed after
                # every epoch once more than config.lr_hold epochs are done.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                # only save a model if the dev loss is smaller
                if valid_loss < best_dev_loss:
                    # a sufficiently large improvement extends the patience
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        elif not FLAGS.demo:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # context manager guarantees the result file is closed even if
            # the test pass raises
            with open(os.path.join(log_dir, "test.txt"), "wb") as dest_f:
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=False, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
        else:
            # interactive demo: one single-utterance feed per user input
            infer_api = Corpus(FLAGS.vocab, FLAGS.da_vocab, FLAGS.topic_vocab)
            while True:
                demo_sent = raw_input('Please input your sentence:\t')
                # an empty or whitespace-only sentence ends the session
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                speaker = raw_input('Are you speaker A?: 1 or 0\t')
                topic = raw_input('what is your topic:\t')
                # demo_sent = 'Hello'
                # speaker = '1'
                # topic = 'MUSIC'
                meta, dial = infer_api.format_input(demo_sent, speaker, topic)
                infer_feed = SWDADataLoader("Infer", dial, meta, config)
                infer_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                      test_config.step_size, shuffle=False, intra_shuffle=False)
                pred_strs, pred_das = test_model.inference(sess, infer_feed, num_batch=None, repeat=3)
                if pred_strs and pred_das:
                    print('Here is your answer')
                    # NOTE(review): output order is "<pred_str> >> <pred_da>".
                    # The loop variables used to be named (da, answer), which
                    # suggests the reverse order may have been intended --
                    # confirm before changing what is printed.
                    for pred_str, pred_da in zip(pred_strs, pred_das):
                        print('{} >> {}'.format(pred_str, pred_da))
Exemplo n.º 5
0
def main(model_type):
    """Build a dialog model on the SWDA corpus, then train or evaluate it.

    Three copies of the model are built under one shared variable scope: a
    training model, a validation model and a test model (the latter two reuse
    the training variables).  If a checkpoint exists under
    ``<log_dir>/checkpoints`` its parameters are restored.

    * When ``FLAGS.forward_only`` is false, the model is trained for at most
      ``config.max_epoch`` epochs; after every epoch the validation set is
      scored and a single test batch is decoded.
    * When ``FLAGS.forward_only`` is true, the validation and test sets are
      scored and the test-set decoding is written to ``<log_dir>/test.txt``.

    Args:
        model_type: ``"kgcvae"`` or ``"cvae"`` (both use ``KgRnnCVAE``, with
            ``use_hcf`` set to True / False respectively), or
            ``"hierbaseline"`` (uses ``HierBaseline``; ``use_hcf`` is left at
            its ``Config`` default).

    Raises:
        ValueError: if ``model_type`` is not one of the values listed above.
    """
    # config for training
    config = Config()

    # config for validation: keep probabilities forced to 1.0
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing: keep probabilities forced to 1.0, one
    # example per batch
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # which model to run
    if model_type == "kgcvae":
        model_class = KgRnnCVAE
        # Chained assignment: unpacking a single bool into three targets
        # (`a, b, c = True`) raises TypeError.
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = True
    elif model_type == "cvae":
        model_class = KgRnnCVAE
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = False
    elif model_type == 'hierbaseline':
        model_class = HierBaseline
    else:
        raise ValueError("This shouldn't happen.")

    # every supported model type trains with the same context window
    backward_size = config.backward_size

    # LDA Model
    ldamodel = LDAModel(config,
                        trained_model_path=FLAGS.lda_model_path,
                        id2word_path=FLAGS.id2word_path)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir,
                           word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size,
                           vocab_dict_path=FLAGS.vocab_dict_path,
                           lda_model=ldamodel,
                           imdb=FLAGS.use_imdb)

    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta = meta_corpus.get("train")
    valid_meta = meta_corpus.get("valid")
    test_meta = meta_corpus.get("test")
    train_dial = dial_corpus.get("train")
    valid_dial = dial_corpus.get("valid")
    test_dial = dial_corpus.get("test")

    # convert to numeric input outputs that fits into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # Testing an existing model or resuming training reuses the run directory
    # named by FLAGS.test_path; a fresh run gets a timestamped directory.
    if FLAGS.forward_only or FLAGS.resume:
        run_name = FLAGS.test_path
    else:
        run_name = "run" + str(int(time.time()))
    log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type, run_name)

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w,
                                                    config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = model_class(
                sess,
                config,
                api,
                log_dir=None if FLAGS.forward_only else log_dir,
                forward=False,
                scope=scope)
        # validation and test graphs share the training model's variables
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = model_class(sess,
                                      valid_config,
                                      api,
                                      log_dir=None,
                                      forward=False,
                                      scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = model_class(sess,
                                     test_config,
                                     api,
                                     log_dir=None,
                                     forward=True,
                                     scope=scope)

        print("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        # Load the pre-trained embeddings only AFTER the global initializer
        # has run; assigning them beforehand lets the initializer overwrite
        # them with random values.  A checkpoint restore (below) still takes
        # precedence over word2vec.
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        if ckpt:
            print("Reading dm models parameters from %s" %
                  ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # if you're training a model
        if not FLAGS.forward_only:

            dm_checkpoint_path = os.path.join(
                ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epoch before stop
            # Kept for the TODOs below: validation-based gating is currently
            # disabled, so these two are not read anywhere.
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf

            # train for a max of max_epoch's; the model is currently saved
            # after every epoch (see TODO below)
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" %
                      (epoch, model.learning_rate.eval()))

                # begin training; only re-initialise the feed once the
                # previous pass over the training data is exhausted
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size,
                                          backward_size,
                                          config.step_size,
                                          shuffle=True)

                global_t, train_loss = model.train(
                    global_t,
                    sess,
                    train_feed,
                    update_limit=config.update_limit)

                # begin validation and testing
                valid_feed.epoch_init(valid_config.batch_size,
                                      valid_config.backward_size,
                                      valid_config.step_size,
                                      shuffle=False,
                                      intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size,
                                     test_config.backward_size,
                                     test_config.step_size,
                                     shuffle=True,
                                     intra_shuffle=False)
                test_model.test(
                    sess, test_feed, num_batch=1
                )  #TODO change this batch size back to a reasonably large number

                done_epoch = epoch + 1
                # Decay the learning rate once past the hold-off period
                # (SGD only).
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                # TODO the checks `valid_loss < best_dev_loss` and
                # `valid_loss <= dev_loss_threshold * config.improve_threshold`
                # are disabled, so patience is always extended and the model
                # is always saved. Change this back when corpus not trivial
                # (and update dev_loss_threshold / best_dev_loss again).
                patience = max(patience,
                               done_epoch * config.patient_increase)

                if FLAGS.save_model:
                    print("Save model!!")
                    model.saver.save(sess,
                                     dm_checkpoint_path,
                                     global_step=epoch)

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            # print("Best validation loss %f" % best_dev_loss)
            print("Done training")

        # else if you're just testing an existing model
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size,
                                  valid_config.backward_size,
                                  valid_config.step_size,
                                  shuffle=False,
                                  intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            # score the test set with the validation graph / batch size
            test_feed.epoch_init(valid_config.batch_size,
                                 valid_config.backward_size,
                                 valid_config.step_size,
                                 shuffle=False,
                                 intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # begin testing; the context manager closes the output file even
            # if decoding raises
            with open(os.path.join(log_dir, "test.txt"), "wb") as dest_f:
                test_feed.epoch_init(test_config.batch_size,
                                     test_config.backward_size,
                                     test_config.step_size,
                                     shuffle=False,
                                     intra_shuffle=False)
                test_model.test(sess,
                                test_feed,
                                num_batch=None,
                                repeat=10,
                                dest=dest_f)