Example #1
    def get_dataset(self, data_path, n_workers=4, dataset_args={}):
        """ Load data and return Dataset objects for training and validating.

        Args:
            data_path (str): Path to the data.
            valid_ratio (float): Ratio of the data to used as valid data.
        """
        self.logging.info('loading dataset...')
        with open(data_path) as f:
            dataset = json.load(f)

        self.logging.info('preprocessing data...')

        # split the dataset into n_workers chunks and preprocess them in parallel
        results = [None] * n_workers
        with Pool(processes=n_workers) as pool:
            for i in range(n_workers):
                batch_start = (len(dataset) // n_workers) * i
                if i == n_workers - 1:
                    # the last worker takes any remainder of the dataset
                    batch_end = len(dataset)
                else:
                    batch_end = (len(dataset) // n_workers) * (i + 1)

                batch = dataset[batch_start: batch_end]
                results[i] = pool.apply_async(self.preprocess_samples, [batch])

            pool.close()
            pool.join()

        # collect the preprocessed samples from each worker, preserving order
        processed = []
        for result in results:
            processed += result.get()

        padding = self.embedding.to_index('</s>')
        return DialogDataset(processed, padding=padding, **dataset_args)
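
A minimal usage sketch for Example #1, assuming the method belongs to a preprocessor class that supplies the logging, embedding, and preprocess_samples attributes used above (the names below are hypothetical):

# hypothetical usage; Preprocessor stands in for the class that owns get_dataset
preprocessor = Preprocessor(embedding=embedding)
train_dataset = preprocessor.get_dataset('data/train.json', n_workers=4)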
Example #2
def main():
    args = get_args()
    setup_seed(args["seed"])
    assert args["max_steps"] > 0 or args["num_epochs"] > 0
    output_dir = os.path.join("output", args["name"])
    data_dir = os.path.join("data", args["dataset"])
    args["output_dir"] = output_dir
    args["data_dir"] = data_dir
    # only the master process creates the output dir; other ranks spin until it exists
    while not os.path.exists(output_dir):
        if args["local_rank"] in [-1, 0]:
            os.mkdir(output_dir)
    logger = create_logger(os.path.join(output_dir, 'train.log'), local_rank=args["local_rank"])
    if args["local_rank"] in [-1, 0]:
        logger.info(args)
        with open(os.path.join(output_dir, "args.json"), mode="w") as f:
            json.dump(args, f)
    # code for distributed training (world_size is hard-coded here for a 4-GPU setup)
    if args["local_rank"] != -1:
        device = torch.device("cuda:{}".format(args["local_rank"]))
        torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=4)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() and not args["no_cuda"] else "cpu")
    
    # slot & gate
    with open(os.path.join(data_dir, "slot_map.json")) as f:
        slot_map = json.load(f)
    with open(os.path.join(data_dir, "gate_label.txt")) as f:
        gate_label = {line.strip(): i for i, line in enumerate(f)}
    
    if args["encoder"] == "bert":
        if len(args["special_tokens"]) > 0 and os.path.exists(args["special_tokens"]):
            with open(args["special_tokens"]) as f:
                special_tokens = f.read().strip().split("\n")
            tokenizer = BertTokenizer.from_pretrained(args["pre_model"], additional_special_tokens=special_tokens)
        else:
            tokenizer = BertTokenizer.from_pretrained(args["pre_model"])
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["pre_model"], 0, 0, len(slot_map), len(gate_label), args["n_layer"],
            args["n_head"], args["dropout"], args["pre_layer_norm"], device, sp_ids["pad_id"])
    else:
        tokenizer = Tokenizer(os.path.join(data_dir, "vocab.txt"), True)
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["encoder"], len(tokenizer), args["hidden_size"], len(slot_map), len(gate_label), 
            args["n_layer"], args["n_head"], args["dropout"], args["pre_layer_norm"], device, sp_ids["pad_id"])

    # train_dataset
    train_pkl = os.path.join(data_dir, "train_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(train_pkl):
        train_data = train_pkl
        logger.info("load training cache from {}".format(train_pkl))
    else:
        train_data, domain_counter = get_data(os.path.join(data_dir, "train_dials.json"))
        logger.info("Traning domain_counter: {}".format(domain_counter))
    train_dataset = DialogDataset(
        train_data, tokenizer, slot_map, gate_label, args["max_seq_len"], args["max_resp_len"], True,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(train_pkl) and args["local_rank"] in [-1, 0]:
        with open(train_pkl, mode="wb") as f:
            pickle.dump(train_dataset.data, f)
        logger.info("save training cache to {}".format(train_pkl))
    # test_dataset
    test_pkl = os.path.join(data_dir, "test_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(test_pkl):
        test_data = test_pkl
        logger.info("load test cache from {}".format(test_pkl))
    else:
        test_data, domain_counter = get_data(os.path.join(data_dir, "test_dials.json"))
        logger.info("Test domain_counter: {}".format(domain_counter))
    test_dataset = DialogDataset(
        test_data, tokenizer, slot_map, gate_label, args["max_seq_len"], args["max_resp_len"], False,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(test_pkl) and args["local_rank"] in [-1, 0]:
        with open(test_pkl, mode="wb") as f:
            pickle.dump(test_dataset.data, f)
        logger.info("save test cache to {}".format(test_pkl))

    # build the trainer and run the training loop
    trainer = Trainer(model, tokenizer, sp_ids, slot_map, gate_label, train_dataset, test_dataset, args, logger, device)
    if args["local_rank"] in [-1, 0]:
        logger.info("Start training")

    for epoch in range(1, args["num_epochs"] + 1):
        logger.info("Epoch {} start, Cur step: {}".format(epoch, trainer.total_step))
        total_step = trainer.train(args["max_steps"])
        if total_step > args["max_steps"]:
            logger.info("Reach the max steps")
            break
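
Example #2 accesses the parsed options as a dict (e.g. args["local_rank"]), so it assumes a get_args() helper along these lines. This is only a sketch covering a few of the keys used above, not the script's actual argument list:

import argparse

def get_args():
    # sketch only: the real script defines many more options
    # (encoder, pre_model, hidden_size, n_layer, n_head, dropout, ...)
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", required=True, help="run name; outputs go to output/<name>")
    parser.add_argument("--dataset", required=True, help="dataset folder under data/")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_epochs", type=int, default=0)
    parser.add_argument("--max_steps", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true")
    # torch.distributed.launch passes --local_rank to each worker process
    parser.add_argument("--local_rank", type=int, default=-1)
    return vars(parser.parse_args())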