def save_checkpoint(self, val_loss, model, model_dir):
    '''Saves the model when the validation loss decreases.'''
    if self.verbose:
        self.logger.info(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
    model.save_pretrained(model_dir)
    TOKENIZER.save_pretrained(model_dir)
    self.val_loss_min = val_loss

def test_one_to_many(task_load):
    score_dicts = []
    for ep in range(args.n_train_epochs[task_load]):
        model_dir = get_model_dir([task_load])
        model_path = os.path.join(model_dir, 'model-{}'.format(ep + 1))
        config_path = os.path.join(model_dir, CONFIG_NAME)

        gen_token = get_gen_token(task_load)
        TOKENIZER.add_tokens([gen_token])
        SPECIAL_TOKENS[task_load] = gen_token
        SPECIAL_TOKEN_IDS[task_load] = TOKENIZER.convert_tokens_to_ids(gen_token)
        model_config = CONFIG_CLASS.from_json_file(config_path)
        model = MODEL_CLASS(model_config).cuda().eval()
        state_dict = torch.load(model_path, map_location='cuda:0')
        model.load_state_dict(state_dict)
        if not args.fp32:
            model = FP16_Module(model)
        model.ep = ep
        model.model_dir = model_dir
        logger.info("task: {}, epoch: {}".format(task_load, ep + 1))
        score_dict = {k: None for k in args.tasks}
        with torch.no_grad():
            for task_eval in args.tasks:
                test_one_to_one(task_load, task_eval, model, score_dict)
        logger.info("score: {}".format(score_dict))
        score_dicts.append(score_dict)

    with open(os.path.join(model_dir, "metrics.json"), "w") as f:
        json.dump(score_dicts, f)

def read_extra_data(gen_path, train_extra_data):
    with open(gen_path, "r") as lm_file:
        reader = csv.reader(lm_file, delimiter=',')
        next(reader)  # skip the header row
        for row in reader:
            row = TOKENIZER.encode(row[0].strip())
            train_extra_data.append(row)

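# Hedged sketch (not defined in this section): read_extra_data, get_real_data and
# create_extra_data all rely on a write_extra_data(dump_path, rows) helper that is
# only referenced here. Assuming it mirrors the CSV format that read_extra_data
# expects (one header row, then one example per row), a minimal implementation
# could look like the following; the column name "gen" and the reuse of the
# module-level csv import and logger are assumptions, not confirmed details.
def write_extra_data(dump_path, qa_results):
    logger.info(f"writing extra data in {dump_path} ...")
    with open(dump_path, "w", newline="", encoding="utf-8") as f:
        lm_writer = csv.writer(f, delimiter=',')
        lm_writer.writerow(["gen"])  # header row that read_extra_data skips
        for line in qa_results:
            lm_writer.writerow([line])
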
def get_real_data(task, train_extra_data, accum=True, encode=True):
    task_idx = args.tasks.index(task)
    gen_size = DATA_ATTRS[task]["train"]["data_size"]
    if accum:
        # sample from all previous tasks, splitting the budget evenly among them
        prev_tasks = args.tasks[:task_idx]
        gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage)) // len(prev_tasks)
    else:
        # sample only from the immediately preceding task
        prev_tasks = [args.tasks[task_idx - 1]]
        gen_size = int(gen_size * args.gen_lm_sample_percentage)

    datum = []
    for prev_task in prev_tasks:
        with open(TASK_DICT[prev_task]["train"], "r") as f:
            data = data_expand(json.load(f)["data"])
        indices = np.random.choice(range(len(data)), gen_size)
        for i in indices:
            d = parse_single_real_data(data[i], prev_task)
            datum.append(d)
            if encode:
                train_extra_data.append(TOKENIZER.encode(d))

    model_dir = get_model_dir([prev_task])
    dump_path = os.path.join(model_dir, "real.csv")
    write_extra_data(dump_path, datum)
    return dump_path

def parallel_tokenization(self, d):
    examples = []
    context = TOKENIZER.encode(d["context"])
    max_a_len = 0
    for i3, qa in enumerate(d["qas"]):
        question = TOKENIZER.encode(qa["question"])

        raw_answers = qa["answers"]
        if len(raw_answers) == 0:
            assert qa["is_impossible"]
            raw_answers.append({"text": ""})

        answer = []
        for i, raw_answer in enumerate(raw_answers):
            answer.extend(TOKENIZER.encode(raw_answer["text"]))
            if i != len(raw_answers) - 1:
                answer.append(self.pad_token)  # multiple gold answers are joined with the pad token as separator
        max_a_len = max(max_a_len, len(answer))

        examples.append(self.parse_example(self.gen_token, context, question, answer,
                                           qa.get("id", 0 if not args.test_training_set else d["pid"] + "_%d" % i3)))
    return examples, max_a_len

def create_extra_data(task, prev_task, model, train_extra_data):
    if args.real_sample:
        logger.info(f"using real data as extra data")
        return get_real_data(task, train_extra_data)

    task_cnt = args.tasks.index(task)
    model_dir = get_model_dir([prev_task])
    gen_path = os.path.join(model_dir, "lm.csv")
    if os.path.exists(gen_path):
        logger.info(f"extra data exists in {gen_path}, read it!")
        return read_extra_data(gen_path, train_extra_data)

    gen_size = DATA_ATTRS[task]["train"]["data_size"]
    gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))
    gen_size -= (gen_size % task_cnt)  # round down to a multiple of the number of previous tasks

    if args.debug:
        gen_size = task_cnt

    model.eval()

    need_process = OrderedDict()
    qa_results = []
    for task_name in args.tasks[:task_cnt]:
        # seed each pseudo-sample with the generation token of a previous task
        qa_results.extend([torch.tensor([SPECIAL_TOKEN_IDS[task_name]]) for _ in range(gen_size // task_cnt)])
    all_pasts = [[
        torch.empty(2, MODEL_CONFIG.n_head, 0, MODEL_CONFIG.n_embd // MODEL_CONFIG.n_head,
                    dtype=torch.float if args.fp32 else torch.half).cuda()
        for _ in range(gen_size)
    ] for __ in range(MODEL_CONFIG.n_layer)]
    max_tot_lens = [args.max_len for _ in range(gen_size)]

    for i in range(gen_size):
        need_process.update([[i, None]])
        if len(need_process) > int(args.memory_sizes[0] * 0.12):  # generate in chunks bounded by GPU memory
            sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)
    sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

    model.train()

    qa_results = [res.tolist() for res in qa_results]
    train_extra_data.extend(qa_results)
    qa_results = [TOKENIZER.decode(res) for res in qa_results]

    write_extra_data(gen_path, qa_results)

def test_one_to_one(task_load, task_eval, model, score_dict):

    logger.info("start to test { task: %s (load) %s (eval), seq train type: %s }" % (task_load, task_eval, args.seq_train_type))

    test_qadata = QADataset(TASK_DICT[task_eval]["test"], "test", SPECIAL_TOKEN_IDS[task_load]).sort()
    max_a_len = test_qadata.max_a_len
    test_dataloader = create_dataloader(test_qadata, "test")
    n_examples = len(test_qadata)
    logger.info("len of test dataset: {}".format(n_examples))

    need_process = OrderedDict()
    qa_results = [0 for _ in range(n_examples)]
    all_pasts = [[0 for _ in range(n_examples)] for __ in range(MODEL_CONFIG.n_layer)]
    max_tot_lens = [0 for _ in range(n_examples)]

    cnt = 0
    for n_steps, (cqs, len_cqs, _, _, _, _, _) in enumerate(test_dataloader):
        # assume n_gpus == 1
        cqs = cqs[0]
        len_cqs = len_cqs[0]
        n_inputs = cqs.shape[0]
        all_outputs = model(input_ids=cqs.cuda())
        outputs = all_outputs[0]
        if args.model_name == "gpt2":
            pasts = all_outputs[1]
        next_logits = outputs[range(n_inputs), len_cqs - 1, :] / args.temperature_qa
        next_tokens = logits_to_tokens(next_logits).cpu()

        for i in range(n_inputs):
            max_tot_lens[cnt] = max_a_len + test_qadata[cnt][1]
            qa_results[cnt] = cqs[i][:len_cqs[i]]
            if next_tokens[i] != SPECIAL_TOKEN_IDS["eos_token"]:
                qa_results[cnt] = torch.cat((cqs[i][:len_cqs[i]], next_tokens[i]))
                if len(qa_results[cnt]) not in [max_tot_lens[cnt], args.max_len]:
                    need_process.update([[cnt, None]])
                    if args.model_name == "gpt2":
                        for layer_id in range(MODEL_CONFIG.n_layer):
                            all_pasts[layer_id][cnt] = pasts[layer_id][:, i, ..., :len_cqs[i], :].type(
                                torch.float32 if args.fp32 else torch.half)
            cnt += 1

        if len(need_process) > int(12 * args.memory_sizes[0] / cqs.shape[1]):  # dynamic threshold to avoid out of memory
            sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)
    sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

    if task_eval in ['wikisql', 'woz.en', 'multinli.in.out']:
        ids = test_qadata.get_indices()
        test_qadata.sort_by_index()
        qa_results = [x[1] for x in sorted([(i, g) for i, g in zip(ids, qa_results)])]

    for i in range(len(test_qadata)):
        _, len_cq, _, _, Y, _, _, _ = test_qadata[i]
        if task_eval in ['wikisql', 'woz.en']:
            Y = test_qadata.answers[i]
        else:
            Y = list(filter(lambda x: x != -1, Y))[:-1]  # remove eos
            Y = ' '.join([str(y) for y in Y]).split(str(SPECIAL_TOKEN_IDS["pad_token"]))
            Y = [TOKENIZER.decode(list(map(int, y.split()))) for y in Y]
        qa_results[i] = [TOKENIZER.decode(qa_results[i].tolist()[len_cq:]), Y]
    get_test_score(task_eval, qa_results, score_dict)

    model_dir = model.model_dir
    ep = model.ep
    results_path = os.path.join(model_dir, "qa_{}_{}.csv".format(task_eval, ep + 1))
    if not args.debug:
        with open(results_path, "w", encoding="utf-8") as f:
            qa_writer = csv.writer(f, delimiter=',')
            qa_writer.writerow(["y", "pred"])
            for pred, y in qa_results:
                if task_eval == 'wikisql':
                    y = y["answer"]
                elif task_eval == 'woz.en':
                    y = y[1]
                qa_writer.writerow([y, pred])

    return model, score_dict

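# Hedged sketch (assumption, not defined in this section): test_one_to_one calls
# logits_to_tokens to pick one next-token id per row of a batch of logits. A
# plausible implementation samples from a top-k/top-p filtered distribution; the
# helper top_k_top_p_filtering and the arguments args.top_k_qa / args.top_p_qa
# are assumed names, not confirmed by this section.
def logits_to_tokens(next_logits):
    filtered_logits = top_k_top_p_filtering(next_logits, top_k=args.top_k_qa, top_p=args.top_p_qa)
    probs = torch.softmax(filtered_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)  # one sampled token id per input row
    return next_tokens
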
def train(task_ids, model):
    tasks = [args.tasks[task_id] for task_id in task_ids]

    logger.info("start to train { task: %s, seq train type: %s }" % (tasks, args.seq_train_type))
    model_dir = get_model_dir(tasks)
    make_dir(model_dir)

    #train_dataset = [(TASK_DICT[t]["train"] if not args.seq_distil else TASK_DICT[t]["train"].replace("train", "distil")) for t in tasks]
    train_dataset = [swap_name(TASK_DICT[t]["train"], args.seq_distil, args.ref1) for t in tasks]

    train_extra_data = []
    if "lll" in args.seq_train_type and task_ids[0] > 0 and not args.skip_tasks:
        prev_task = args.tasks[task_ids[0] - 1]
        with torch.no_grad():
            create_extra_data(tasks[0], prev_task, model, train_extra_data)
    elif "gem" in args.seq_train_type and task_ids[0] > 0:
        get_real_data(tasks[0], train_extra_data, accum=False, encode=True)
        args.memory_data.append(train_extra_data)
        train_extra_data = []
    logger.info('extra training data size: {}'.format(len(train_extra_data)))

    if not model:
        # which_model_to_load = model_dir if os.path.isfile(os.path.join(model_dir, FINAL_SAVE_NAME)) else args.model_name
        model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        model.resize_token_embeddings(len(TOKENIZER))
        if not args.fp32:
            model = FP16_Module(model)

    gen_token = get_gen_token(tasks[0])
    TOKENIZER.add_tokens([gen_token])
    TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    logger.info('gen token = {} , gen token id = {}'.format(gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    MODEL_CONFIG.vocab_size = len(TOKENIZER)
    MODEL_CONFIG.to_json_file(os.path.join(model_dir, CONFIG_NAME))
    global TOKENS_WEIGHT
    if len(TOKENIZER) != TOKENS_WEIGHT.shape[0]:
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))

    if args.skip_tasks and len(tasks) == 1:
        logger.info("*********** skip task: {} ***********".format(tasks[0]))
        if tasks[0] in args.skip_tasks:
            if len(args.skip_tasks) == 1:
                model_dir = get_model_dir(tasks)
                model_path = os.path.join(model_dir, FINAL_SAVE_NAME)
                config_path = os.path.join(model_dir, CONFIG_NAME)
                model_config = CONFIG_CLASS.from_json_file(config_path)
                model = MODEL_CLASS(model_config).cuda()
                state_dict = torch.load(model_path)
                model.load_state_dict(state_dict)
                if not args.fp32:
                    model = FP16_Module(model)
                if args.seq_train_type in REG_TYPE_KEYS:
                    logger.info("calculating reg_params ...")
                    train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
                    max_train_batch_size = max(len(train_qadata) // args.min_n_steps, args.min_batch_size)
                    train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)
                    parallel_model = DataParallelModel(WrapModel(model), args.device_ids)
                    regularizer = REG_TYPES[args.seq_train_type](model, parallel_model, [train_dataloader], tasks[0])
                    regularizer.task_start_do()
                    regularizer.task_end_do()
                    torch.save(model.state_dict(), os.path.join(model_dir, FINAL_SAVE_NAME))
                    logger.info("done reg_params!")
            args.skip_tasks.remove(tasks[0])
            return model

    model.resize_token_embeddings(len(TOKENIZER) if not args.multitask_specific else len(TOKENIZER) + 4)
    if args.multitask_specific:
        for i in range(4):
            TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))

    if args.distil:
        teacher_model = MODEL_CLASS.from_pretrained(args.model_name).cuda()
        teacher_vocab_size = json.load(open("models/gpt2/lll/{task}_0.2/{task}/config.json".format(task=tasks[0])))['vocab_size']
        teacher_model.resize_token_embeddings(teacher_vocab_size)
        print("load teacher model from {}".format("models/gpt2/lll/{task}_0.2/{task}/model-finish".format(task=tasks[0])))
        teacher_model.load_state_dict(torch.load("models/gpt2/lll/{task}_0.2/{task}/model-finish".format(task=tasks[0])))
        if not args.fp32:
            teacher_model = FP16_Module(teacher_model)
        teacher_model.eval()
        teacher_model = DataParallelModel(WrapModel(teacher_model), args.device_ids)

    if not args.fp32:  # again because resize_token_embeddings makes the embedding layer fp32
        model = FP16_Module(model)

    parallel_model = DataParallelModel(WrapModel(model), args.device_ids)

    train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
    max_train_batch_size = max(len(train_qadata) // args.min_n_steps, args.min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)

    if not args.unbound and args.seq_train_type not in ["multitask", "multilm"]:
        #n_train_epochs = TASK_DICT[tasks[0]]["n_train_epochs"]
        n_train_epochs = args.n_train_epochs[tasks[0]]
    else:
        n_train_epochs = args.n_train_epochs['_'.join(tasks)]
    n_train_optimization_steps = len(train_qadata) * n_train_epochs
    logger.info('len of train dataset: {} , max train batch size {} , num of opt steps: {}'.format(
        len(train_qadata), max_train_batch_size, n_train_optimization_steps))

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    if "gem" in args.seq_train_type:
        model.task_id = task_ids[0]
        if not hasattr(model, "grad_dims"):
            model.grad_dims = []
            for param in model.parameters():
                model.grad_dims.append(param.data.numel())
        if not hasattr(model, "grads"):
            model.grads = torch.zeros(sum(model.grad_dims), len(args.tasks))
            model.grads = model.grads.cuda()

    if args.seq_train_type in REG_TYPE_KEYS:
        optimizer = Weight_Regularized_AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if not args.fp32:
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                                   dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})

    scheduler = AnnealingLR(optimizer, start_lr=args.learning_rate,
                            warmup_iter=int(args.n_warmup_ratio * len(train_qadata)),
                            num_iters=int(n_train_optimization_steps), decay_style=args.decay_style)
    train_loss_fct = DataParallelCriterion(CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT), args.device_ids)
    if args.distil:
        kd_loss_fct = DataParallelCriterion(nn.KLDivLoss(reduction="batchmean"), args.device_ids)

    if args.seq_train_type in REG_TYPE_KEYS:
        copy_train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)
        prev_task = args.tasks[task_ids[0] - 1]
        regularizer = REG_TYPES[args.seq_train_type](model, parallel_model, [copy_train_dataloader], tasks[0], prev_task)
        regularizer.task_start_do()

    tot_n_steps = 0
    train_once = TrainStep(model, optimizer, scheduler)
    if "gem" in args.seq_train_type and task_ids[0] != 0:
        gem_step = GEMStep(model, parallel_model, train_loss_fct, optimizer)
    model.train()
    for ep in range(n_train_epochs):
        cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
        for n_steps, (_, _, cqa, _, Y, gen_X, gen_Y, is_extra) in enumerate(train_dataloader):

            n_inputs = sum(_cqa.shape[0] for _cqa in cqa)

            if args.multitask_specific:
                for i in range(len(is_extra)):
                    gen_X[i][:, 0] += is_extra[i]
                    is_extra[i] = is_extra[i] * 0

            for i in range(len(cqa)):
                cqa[i] = (cqa[i].to(args.device_ids[i]),)
                Y[i] = Y[i].to(args.device_ids[i])
                gen_X[i] = (gen_X[i].to(args.device_ids[i]),)
                gen_Y[i] = gen_Y[i].to(args.device_ids[i])
                is_extra[i] = is_extra[i].to(args.device_ids[i])

            if args.distil:
                losses = get_distil_losses(teacher_model, parallel_model, cqa, Y, gen_X, gen_Y, is_extra,
                                           kd_loss_fct, train_loss_fct, args.temperature_kd, pad_idx=FILL_VAL)
            else:
                losses = get_losses(parallel_model, cqa, Y, gen_X, gen_Y, train_loss_fct)
            loss = sum(losses)
            if "gem" in args.seq_train_type and task_ids[0] != 0:
                gem_step(task_ids[0])
            train_once(loss, n_inputs)

            qa_loss = losses[0].item() * n_inputs
            lm_loss = losses[1].item() * n_inputs
            cum_loss += (qa_loss + lm_loss)
            cum_qa_loss += qa_loss
            cum_lm_loss += lm_loss
            cur_n_inputs += n_inputs

            if (n_steps + 1) % args.logging_steps == 0:
                logger.info('progress {:.3f} , lr {:.1E} , loss {:.3f} , qa loss {:.3f} , lm loss {:.3f} , avg batch size {:.1f}'.format(
                    ep + cur_n_inputs / len(train_qadata), scheduler.get_lr(), cum_loss / cur_n_inputs,
                    cum_qa_loss / cur_n_inputs, cum_lm_loss / cur_n_inputs, cur_n_inputs / (n_steps + 1)))

        torch.save(model.state_dict(), os.path.join(model_dir, SAVE_NAME + str(ep + 1)))
        tot_n_steps += (n_steps + 1)
        logger.info('epoch {}/{} done , tot steps {} , lr {:.1E} , loss {:.2f} , qa loss {:.2f} , lm loss {:.2f} , avg batch size {:.1f}'.format(
            ep + 1, n_train_epochs, tot_n_steps, scheduler.get_lr(), cum_loss / cur_n_inputs,
            cum_qa_loss / cur_n_inputs, cum_lm_loss / cur_n_inputs, cur_n_inputs / (n_steps + 1)))

    # task end do for reg
    if args.seq_train_type in REG_TYPE_KEYS:
        regularizer.task_end_do()
    torch.save(model.state_dict(), os.path.join(model_dir, FINAL_SAVE_NAME))

    return model
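
# Hedged usage sketch (assumption, not part of the original file): a typical
# sequential-training driver for the functions above might train on each task in
# order, carrying the model forward, and then evaluate every saved checkpoint
# with test_one_to_many. The __main__ guard and the exact argument plumbing are
# assumptions about how this module is invoked.
if __name__ == "__main__":
    trained_model = None
    for task_id in range(len(args.tasks)):
        trained_model = train([task_id], trained_model)  # continue from the previous task's model
    for task in args.tasks:
        test_one_to_many(task)  # evaluate each epoch's checkpoint on all tasks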