def get_dataset(self, data_path, n_workers=4, dataset_args={}):
    """Load data and return a DialogDataset.

    Args:
        data_path (str): Path to the data.
        n_workers (int): Number of worker processes used for preprocessing.
        dataset_args (dict): Extra keyword arguments passed to DialogDataset.
    """
    self.logging.info('loading dataset...')
    with open(data_path) as f:
        dataset = json.load(f)

    self.logging.info('preprocessing data...')
    results = [None] * n_workers
    with Pool(processes=n_workers) as pool:
        # Split the samples into n_workers contiguous chunks; the last
        # chunk absorbs the remainder.
        for i in range(n_workers):
            batch_start = (len(dataset) // n_workers) * i
            if i == n_workers - 1:
                batch_end = len(dataset)
            else:
                batch_end = (len(dataset) // n_workers) * (i + 1)

            batch = dataset[batch_start:batch_end]
            results[i] = pool.apply_async(self.preprocess_samples, [batch])

        pool.close()
        pool.join()

    # Concatenate the preprocessed chunks in their original order.
    processed = []
    for result in results:
        processed += result.get()

    padding = self.embedding.to_index('</s>')
    return DialogDataset(processed, padding=padding, **dataset_args)
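# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the project code): the chunk-and-apply_async
# preprocessing pattern used by get_dataset above, shown with a toy
# preprocessing function. _toy_preprocess and _demo_chunked_preprocessing are
# hypothetical names used only for this example.
# ---------------------------------------------------------------------------
from multiprocessing import Pool


def _toy_preprocess(batch):
    # Stand-in for self.preprocess_samples: uppercase each sample.
    return [sample.upper() for sample in batch]


def _demo_chunked_preprocessing(samples, n_workers=2):
    results = [None] * n_workers
    with Pool(processes=n_workers) as pool:
        for i in range(n_workers):
            batch_start = (len(samples) // n_workers) * i
            # The last worker absorbs the remainder of the data.
            batch_end = len(samples) if i == n_workers - 1 \
                else (len(samples) // n_workers) * (i + 1)
            results[i] = pool.apply_async(_toy_preprocess,
                                          [samples[batch_start:batch_end]])
        pool.close()
        pool.join()

    # Collecting results in worker order preserves the original sample order.
    processed = []
    for result in results:
        processed += result.get()
    return processed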
def main():
    args = get_args()
    setup_seed(args["seed"])
    assert args["max_steps"] > 0 or args["num_epochs"] > 0

    output_dir = os.path.join("output", args["name"])
    data_dir = os.path.join("data", args["dataset"])
    args["output_dir"] = output_dir
    args["data_dir"] = data_dir

    # Only the main process creates the output directory; other ranks spin
    # until it exists.
    while not os.path.exists(output_dir):
        if args["local_rank"] in [-1, 0]:
            os.mkdir(output_dir)

    logger = create_logger(os.path.join(output_dir, 'train.log'),
                           local_rank=args["local_rank"])
    if args["local_rank"] in [-1, 0]:
        logger.info(args)
        with open(os.path.join(output_dir, "args.json"), mode="w") as f:
            json.dump(args, f)

    # Distributed training setup.
    if args["local_rank"] != -1:
        device = torch.device("cuda:{}".format(args["local_rank"]))
        # NOTE: world_size is hard-coded to 4 here.
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://',
                                             world_size=4)
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args["no_cuda"] else "cpu")

    # Slot map & gate labels.
    with open(os.path.join(data_dir, "slot_map.json")) as f:
        slot_map = json.load(f)
    with open(os.path.join(data_dir, "gate_label.txt")) as f:
        gate_label = {line.strip(): i for i, line in enumerate(f)}

    if args["encoder"] == "bert":
        if len(args["special_tokens"]) > 0 and os.path.exists(args["special_tokens"]):
            with open(args["special_tokens"]) as f:
                special_tokens = f.read().strip().split("\n")
            tokenizer = BertTokenizer.from_pretrained(
                args["pre_model"], additional_special_tokens=special_tokens)
        else:
            tokenizer = BertTokenizer.from_pretrained(args["pre_model"])
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["pre_model"], 0, 0, len(slot_map), len(gate_label),
            args["n_layer"], args["n_head"], args["dropout"],
            args["pre_layer_norm"], device, sp_ids["pad_id"])
    else:
        tokenizer = Tokenizer(os.path.join(data_dir, "vocab.txt"), True)
        sp_ids = get_special_ids(tokenizer)
        model = DialogueModel(
            args["encoder"], len(tokenizer), args["hidden_size"],
            len(slot_map), len(gate_label), args["n_layer"], args["n_head"],
            args["dropout"], args["pre_layer_norm"], device, sp_ids["pad_id"])

    # Training dataset (cached as a pickle keyed by vocabulary size).
    train_pkl = os.path.join(data_dir, "train_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(train_pkl):
        train_data = train_pkl
        logger.info("load training cache from {}".format(train_pkl))
    else:
        train_data, domain_counter = get_data(os.path.join(data_dir, "train_dials.json"))
        logger.info("Training domain_counter: {}".format(domain_counter))
    train_dataset = DialogDataset(
        train_data, tokenizer, slot_map, gate_label,
        args["max_seq_len"], args["max_resp_len"], True,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(train_pkl) and args["local_rank"] in [-1, 0]:
        with open(train_pkl, mode="wb") as f:
            pickle.dump(train_dataset.data, f)
        logger.info("save training cache to {}".format(train_pkl))

    # Test dataset (same caching scheme).
    test_pkl = os.path.join(data_dir, "test_dials_{}.pkl".format(len(tokenizer)))
    if os.path.exists(test_pkl):
        test_data = test_pkl
        logger.info("load test cache from {}".format(test_pkl))
    else:
        test_data, domain_counter = get_data(os.path.join(data_dir, "test_dials.json"))
        logger.info("Test domain_counter: {}".format(domain_counter))
    test_dataset = DialogDataset(
        test_data, tokenizer, slot_map, gate_label,
        args["max_seq_len"], args["max_resp_len"], False,
        sp_ids["user_type_id"], sp_ids["sys_type_id"], sp_ids["belief_type_id"],
        sp_ids["pad_id"], sp_ids["eos_id"], sp_ids["cls_id"], sp_ids["belief_sep_id"]
    )
    if not os.path.exists(test_pkl) and args["local_rank"] in [-1, 0]:
        with open(test_pkl, mode="wb") as f:
            pickle.dump(test_dataset.data, f)
        logger.info("save test cache to {}".format(test_pkl))

    trainer = Trainer(model, tokenizer, sp_ids, slot_map, gate_label,
                      train_dataset, test_dataset, args, logger, device)

    if args["local_rank"] in [-1, 0]:
        logger.info("Start training")
    for epoch in range(1, args["num_epochs"] + 1):
        logger.info("Epoch {} start, Cur step: {}".format(epoch, trainer.total_step))
        total_step = trainer.train(args["max_steps"])
        if total_step >= args["max_steps"]:
            logger.info("Reach the max steps")
            break
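# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the project code): the load-or-build cache
# pattern used for the train/test data above, distilled into a helper. In
# main() the cache path is keyed by vocabulary size so caches built with
# different tokenizers do not collide, and only the main process writes the
# cache. _load_or_build_cache is a hypothetical name used only here.
# ---------------------------------------------------------------------------
import os
import pickle


def _load_or_build_cache(cache_path, build_fn, is_main_process=True):
    """Return cached data if present; otherwise build it and, on the main
    process only, write it back to cache_path."""
    if os.path.exists(cache_path):
        with open(cache_path, mode="rb") as f:
            return pickle.load(f)
    data = build_fn()
    if is_main_process:
        with open(cache_path, mode="wb") as f:
            pickle.dump(data, f)
    return data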