Code Example #1
def main():

	parser = argparse.ArgumentParser()
	parser.add_argument('--batch_size', type=int, default=4)
	parser.add_argument('--emb_dim', type=int, default=200)
	parser.add_argument('--enc_hid_dim', type=int, default=128)
	parser.add_argument('--dec_hid_dim', type=int, default=256)
	parser.add_argument('--attn_size', type=int, default=200)
	parser.add_argument('--epochs', type=int, default=12)
	parser.add_argument('--learning_rate', type=float, default=2.5e-4)
	parser.add_argument('--dataset_path', type=str, default='../data/Maluuba/')
	parser.add_argument('--glove_path', type=str, default='../data/')
	parser.add_argument('--checkpoint', type=str, default="./trainDir/")
	config = parser.parse_args()

	DEVICE = "/gpu:0"

	logging.info("Loading Data")

	handler = DataHandler(
				emb_dim = config.emb_dim,
				batch_size = config.batch_size,
				train_path = config.dataset_path + "train.json",
				val_path = config.dataset_path + "val.json",
				test_path = config.dataset_path + "test.json",
				vocab_path = "./vocab.json",
				entities_path = config.dataset_path + "entities.json",
				glove_path = config.glove_path)

	logging.info("Loading Architecture")

	model = DialogueModel(
				device = DEVICE,
				batch_size = config.batch_size,
				inp_vocab_size = handler.input_vocab_size,
				out_vocab_size = handler.output_vocab_size,
				generate_size = handler.generate_vocab_size,
				emb_init = handler.emb_init,
				result_keys_vector = handler.result_keys_vector,
				emb_dim = config.emb_dim,
				enc_hid_dim = config.enc_hid_dim,
				dec_hid_dim = config.dec_hid_dim,
				attn_size = config.attn_size)

	logging.info("Loading Trainer")

	trainer = Trainer(
				model=model,
				handler=handler,
				ckpt_path="./trainDir/",
				num_epochs=config.epochs,
				learning_rate = config.learning_rate)

	trainer.trainData()
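
Only main() is shown; the imports, the logging configuration and the entry-point guard live outside this excerpt. A minimal sketch of that scaffolding, assuming module names data_handler, dialogue_model and trainer (the real module layout is not shown); Example #2 below needs the same kind of block:

import argparse
import logging

from data_handler import DataHandler
from dialogue_model import DialogueModel
from trainer import Trainer

if __name__ == "__main__":
	# logging.info() is silent at the default WARNING level, so raise the verbosity first
	logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
	main()
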
Code Example #2
def main():

	BATCH_SIZE = 32
	EMB_DIM = 200
	ENC_HID_DIM = 128
	DEC_HID_DIM = 256
	ATTN_SIZE = 200
	NUM_EPOCHS = 25
	DROPOUT = 0.75
	LR = 2.5e-4
	DEVICE = "/gpu:0"

	logging.info("Loading Data")

	handler = DataHandler(
				emb_dim = EMB_DIM,
				batch_size = BATCH_SIZE,
				train_path = "../data/Camrest/train.json",
				val_path = "../data/Camrest/val.json",
				test_path = "../data/Camrest/test.json",
				vocab_path = "./vocab.json")

	logging.info("Loading Architecture")

	model = DialogueModel(
				device = DEVICE,
				batch_size = BATCH_SIZE,
				inp_vocab_size = handler.input_vocab_size,
				out_vocab_size = handler.output_vocab_size,
				generate_size = handler.generate_vocab_size,
				emb_init = handler.emb_init,
				emb_dim = EMB_DIM,
				enc_hid_dim = ENC_HID_DIM,
				dec_hid_dim = DEC_HID_DIM,
				attn_size = ATTN_SIZE,
				dropout_keep_prob = DROPOUT)

	logging.info("Loading Trainer")

	trainer = Trainer(
				model=model,
				handler=handler,
				ckpt_path="./trainDir/",
				num_epochs=NUM_EPOCHS,
				learning_rate = LR)

	trainer.trainData()
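
Here the hyperparameters are hard-coded constants; note that DROPOUT = 0.75 is passed as dropout_keep_prob, i.e. it is a keep probability rather than a drop rate. To sweep these values they can be lifted into argparse flags exactly as in Example #1. A short sketch (the flag names are assumptions that mirror the constants):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=25)
parser.add_argument('--dropout_keep_prob', type=float, default=0.75)
parser.add_argument('--learning_rate', type=float, default=2.5e-4)
# EMB_DIM, ENC_HID_DIM, DEC_HID_DIM and ATTN_SIZE follow the same pattern
config = parser.parse_args()
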
Code Example #3
File: infer.py  Project: vunb/chatbot-1
def main(_):
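    # hyperparams.pkl is the FLAGS dictionary that train.py (Example #5) dumps at startup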
    config = cPickle.load(open(FLAGS.logdir + "/hyperparams.pkl", 'rb'))
    pp.pprint(config)

    try:
        # pre-trained chars embedding
        emb = np.load("./data/emb.npy")
        chars = cPickle.load(open("./data/vocab.pkl", 'rb'))
        vocab_size, emb_size = np.shape(emb)
        data_loader = TextLoader('./data', 1, chars)
    except Exception:
        data_loader = TextLoader('./data', 1)
        emb_size = config["emb_size"]
        vocab_size = data_loader.vocab_size

    checkpoint = FLAGS.checkpoint + '/model.ckpt'

    model = DialogueModel(batch_size=1,
                          max_seq_length=data_loader.seq_length,
                          vocab_size=vocab_size,
                          pad_token_id=0,
                          unk_token_id=UNK_ID,
                          emb_size=emb_size,
                          memory_size=config["memory_size"],
                          keep_prob=config["keep_prob"],
                          learning_rate=config["learning_rate"],
                          grad_clip=config["grad_clip"],
                          temperature=config["temperature"],
                          infer=True)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)

        if len(glob(checkpoint + "*")) > 0:
            saver.restore(sess, checkpoint)
        else:
            print("No model found!")
            return

        ## -- debug --
        #np.set_printoptions(threshold=np.inf)
        #for v in tf.trainable_variables():
        #  print(v.name)
        #  print(sess.run(v))
        #  print()
        #return

        while True:
            try:
                input_ = input('in> ')
            except EOFError:
                print("\nBye!")
                break

            input_ids, input_len = data_loader.parse_input(input_)

            feed = {
                model.input_data: np.expand_dims(input_ids, 0),
                model.input_lengths: [input_len]
            }

            output_ids, state = sess.run([model.output_ids, model.final_state],
                                         feed_dict=feed)

            print(data_loader.compose_output(output_ids[0]))
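
infer.py reads FLAGS.logdir and FLAGS.checkpoint, and UNK_ID is a module-level constant; none of these are defined in the excerpt. A hedged TF 1.x reconstruction of the flag definitions and the entry point (the default paths and the UNK_ID value are assumptions):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("logdir", "./logdir", "directory containing hyperparams.pkl")
flags.DEFINE_string("checkpoint", "./checkpoint", "directory containing model.ckpt")
FLAGS = flags.FLAGS

UNK_ID = 1  # assumed id of the unknown token

if __name__ == "__main__":
    tf.app.run()  # parses the flags and calls main(_)
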
Code Example #4
def main():
    args = get_args()
    setup_seed(args["seed"])
    assert args["max_steps"] > 0 or args["num_epochs"] > 0
    output_dir = os.path.join("output", args["name"])
    data_dir = os.path.join("data", args["dataset"])
    args["output_dir"] = output_dir
    args["data_dir"] = data_dir
    # rank 0 (or single-process mode) creates the output directory; the other
    # ranks spin here until it appears on the shared filesystem
    while not os.path.exists(output_dir):
        if args["local_rank"] in [-1, 0]:
            os.mkdir(output_dir)
    logger = create_logger(os.path.join(output_dir, 'train.log'), local_rank=args["local_rank"])
    if args["local_rank"] in [-1, 0]:
        logger.info(args)
        with open(os.path.join(output_dir, "args.json"), mode="w") as f:
            json.dump(args, f)
    # code for distributed training
    if args["local_rank"] != -1:
        device = torch.device("cuda:{}".format(args["local_rank"]))
        torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=4)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() and not args["no_cuda"] else "cpu")
    
    # slot & gate
    with open(os.path.join(data_dir, "slot_map.json")) as f:
        slot_map = json.load(f)
    with open(os.path.join(data_dir, "gate_label.txt")) as f:
        gate_label = {line.strip(): i for i, line in enumerate(f)}
    
    if args["encoder"] == "bert":
        if len(args["special_tokens"]) > 0 and os.path.exists(args["special_tokens"]):
            with open(args["special_tokens"]) as f:
                special_tokens = f.read().strip().split("\n")
            tokenizer = BertTokenizer.from_pretrained(args["pre_model"], additional_special_tokens=special_tokens)
        else:
            tokenizer = BertTokenizer.from_pretrained(args["pre_model"])
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["pre_model"], 0, 0, len(slot_map), len(gate_label), args["n_layer"],
            args["n_head"], args["dropout"], args["pre_layer_norm"], device, sp_ids["pad_id"])
    else:
        tokenizer = Tokenizer(os.path.join(data_dir, "vocab.txt"), True)
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["encoder"], len(tokenizer), args["hidden_size"], len(slot_map), len(gate_label), 
            args["n_layer"], args["n_head"], args["dropout"], args["pre_layer_norm"], device, sp_ids["pad_id"])

    # train_dataset
    train_pkl = os.path.join(data_dir, "train_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(train_pkl):
        train_data = train_pkl
        logger.info("load training cache from {}".format(train_pkl))
    else:
        train_data, domain_counter = get_data(os.path.join(data_dir, "train_dials.json"))
        logger.info("Traning domain_counter: {}".format(domain_counter))
    train_dataset = DialogDataset(
        train_data, tokenizer, slot_map, gate_label, args["max_seq_len"], args["max_resp_len"], True,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(train_pkl) and args["local_rank"] in [-1, 0]:
        with open(train_pkl, mode="wb") as f:
            pickle.dump(train_dataset.data, f)
        logger.info("save training cache to {}".format(train_pkl))
    # test_dataset
    test_pkl = os.path.join(data_dir, "test_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(test_pkl):
        test_data = test_pkl
        logger.info("load test cache from {}".format(test_pkl))
    else:
        test_data, domain_counter = get_data(os.path.join(data_dir, "test_dials.json"))
        logger.info("Test domain_counter: {}".format(domain_counter))
    test_dataset = DialogDataset(
        test_data, tokenizer, slot_map, gate_label, args["max_seq_len"], args["max_resp_len"], False,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(test_pkl) and args["local_rank"] in [-1, 0]:
        with open(test_pkl, mode="wb") as f:
            pickle.dump(test_dataset.data, f)
        logger.info("save test cache to {}".format(test_pkl))

    trainer = Trainer(model, tokenizer, sp_ids, slot_map, gate_label, train_dataset, test_dataset, args, logger, device)
    if args["local_rank"] in [-1, 0]:
        logger.info("Start training")

    for epoch in range(1, args["num_epochs"] + 1):
        logger.info("Epoch {} start, Cur step: {}".format(epoch, trainer.total_step))
        total_step = trainer.train(args["max_steps"])
        if total_step > args["max_steps"]:
            logger.info("Reach the max steps")
            break
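
get_args() is not shown; the keys read inside main() imply roughly the argparse setup below, returned as a dict so that the args[...] subscripting works. The flag names and defaults are assumptions, and the Trainer almost certainly consumes further keys (optimizer settings, batch size) that never appear in this excerpt. When args["local_rank"] is not -1, the script additionally expects to be started by a distributed launcher that supplies --local_rank and the env:// rendezvous variables for the four-process nccl group initialised above.

import argparse

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", required=True)
    parser.add_argument("--dataset", default="multiwoz")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_epochs", type=int, default=0)
    parser.add_argument("--max_steps", type=int, default=0)
    parser.add_argument("--encoder", default="bert")
    parser.add_argument("--pre_model", default="bert-base-uncased")
    parser.add_argument("--special_tokens", default="")
    parser.add_argument("--hidden_size", type=int, default=400)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--pre_layer_norm", action="store_true")
    parser.add_argument("--max_seq_len", type=int, default=512)
    parser.add_argument("--max_resp_len", type=int, default=128)
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--local_rank", type=int, default=-1)
    return vars(parser.parse_args())
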
Code Example #5
File: train.py  Project: vunb/chatbot-1
def main(_):
    pp.pprint(FLAGS.__flags)
    emb = None

    try:
        # pre-trained chars embedding
        emb = np.load("./data/emb.npy")
        chars = cPickle.load(open("./data/vocab.pkl", 'rb'))
        vocab_size, emb_size = np.shape(emb)
        data_loader = TextLoader('./data', FLAGS.batch_size, chars)
    except Exception:
        # no pre-trained embedding found; build the vocabulary from the data instead
        data_loader = TextLoader('./data', FLAGS.batch_size)
        emb_size = FLAGS.emb_size
        vocab_size = data_loader.vocab_size

    model = DialogueModel(batch_size=FLAGS.batch_size,
                          max_seq_length=data_loader.seq_length,
                          vocab_size=vocab_size,
                          pad_token_id=0,
                          unk_token_id=UNK_ID,
                          emb_size=emb_size,
                          memory_size=FLAGS.memory_size,
                          keep_prob=FLAGS.keep_prob,
                          learning_rate=FLAGS.learning_rate,
                          grad_clip=FLAGS.grad_clip,
                          temperature=FLAGS.temperature,
                          infer=False)

    summaries = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    # save hyper-parameters
    cPickle.dump(FLAGS.__flags, open(FLAGS.logdir + "/hyperparams.pkl", 'wb'))

    checkpoint = FLAGS.checkpoint + '/model.ckpt'
    count = 0

    saver = tf.train.Saver()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        sess.run(init)

        if len(glob(checkpoint + "*")) > 0:
            saver.restore(sess, checkpoint)
            print("Model restored!")
        else:
            # load the pre-trained embedding into the graph
            # (assumes model.embedding is a tf.Variable)
            if emb is not None:
                sess.run(tf.assign(model.embedding, emb))
            print("Fresh variables!")

        current_step = 0
        count = 0

        for e in range(FLAGS.num_epochs):
            data_loader.reset_batch_pointer()
            state = None

            # iterate by batch
            for _ in range(data_loader.num_batches):
                x, y, input_lengths, output_lengths = data_loader.next_batch()

                if (current_step + 1) % 10 != 0:
                    res = model.step(sess, x, y, input_lengths, output_lengths,
                                     state)
                else:
                    res = model.step(sess, x, y, input_lengths, output_lengths,
                                     state, summaries)
                    summary_writer.add_summary(res["summary_out"],
                                               current_step)
                    loss = res["loss"]
                    perplexity = np.exp(loss)
                    count += 1
                    print("{0}/{1}({2}), perplexity {3}".format(
                        current_step + 1,
                        FLAGS.num_epochs * data_loader.num_batches, e,
                        perplexity))
                state = res["final_state"]

                if (current_step + 1) % 2000 == 0:
                    count = 0
                    summary_writer.flush()
                    save_path = saver.save(sess, checkpoint)
                    print("Model saved in file:", save_path)

                current_step = tf.train.global_step(sess, model.global_step)

        summary_writer.close()
        save_path = saver.save(sess, checkpoint)
        print("Model saved in file:", save_path)