def main(args):
    ontology = json.load(open(os.path.join(args.data_root, args.ontology_data)))
    slot_meta, _ = make_slot_meta(ontology)
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")

    out_file = '/opt/ml/code/p3-dst-chatting-day/SomDST/pickles/test_data_raw.pkl'
    if os.path.exists(out_file):
        print("Pickles exist!")
        with open(out_file, 'rb') as f:
            test_data_raw = pickle.load(f)
        # with open(out_path + '/test_data.pkl', 'rb') as f:
        #     test_data = pickle.load(f)
        print("Pickles loaded!")
    else:
        print("Pickles do not exist!")
        test_data_raw = prepare_dataset_eval(data_path=args.test_data_path,
                                             tokenizer=tokenizer,
                                             slot_meta=slot_meta,
                                             n_history=args.n_history,
                                             max_seq_length=args.max_seq_length,
                                             op_code=args.op_code)
        # test_data = WosDataset(train_data_raw,
        #                        tokenizer,
        #                        slot_meta,
        #                        args.max_seq_length,
        #                        rng,
        #                        ontology,
        #                        args.word_dropout,
        #                        args.shuffle_state,
        #                        args.shuffle_p)
        with open(out_file, 'wb') as f:
            pickle.dump(test_data_raw, f)
        # with open(out_path + '/test_data.pkl', 'wb') as f:
        #     pickle.dump(test_data, f)
        print("Pickles saved!")
    print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)
    model.eval()
    model.to(device)
    print("Model is loaded")

    inference_model(model, test_data_raw, tokenizer, slot_meta, args.op_code)
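# A minimal driver sketch for the inference script above (not part of the
# original file). The argument names mirror the attributes main() reads;
# the defaults and the argparse setup itself are assumptions.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="data")
    parser.add_argument("--ontology_data", type=str, default="ontology.json")
    parser.add_argument("--test_data_path", type=str, default="data/eval_dials.json")
    parser.add_argument("--n_history", type=int, default=1)
    parser.add_argument("--max_seq_length", type=int, default=256)
    parser.add_argument("--op_code", type=str, default="4")
    parser.add_argument("--bert_config_path", type=str, default="assets/bert_config.json")
    parser.add_argument("--model_ckpt_path", type=str, default="outputs/model_best.bin")
    args = parser.parse_args()
    main(args)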
def main(args):
    ontology = json.load(open(os.path.join(args.data_root, args.ontology_data)))
    slot_meta, _ = make_slot_meta(ontology)
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")
    data = prepare_dataset(os.path.join(args.data_root, args.test_data),
                           tokenizer, slot_meta, args.n_history,
                           args.max_seq_length, args.op_code)

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)
    model.eval()
    model.to(device)

    if args.eval_all:
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         False, False, False)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         False, False, True)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         False, True, False)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         False, True, True)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         True, False, False)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         True, True, False)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         True, False, True)
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         True, True, True)
    else:
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         args.gt_op, args.gt_p_state, args.gt_gen)
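# The eight model_evaluation calls above under --eval_all enumerate every
# combination of the three ground-truth flags (is_gt_op, is_gt_p_state,
# is_gt_gen). A more compact equivalent, sketched here as a suggestion
# (itertools import assumed), differs from the original only in call order:
from itertools import product

if args.eval_all:
    for gt_op, gt_p_state, gt_gen in product((False, True), repeat=3):
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         gt_op, gt_p_state, gt_gen)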
def main(args):
    def worker_init_fn(worker_id):
        np.random.seed(args.random_seed + worker_id)

    n_gpu = 0
    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    ontology = json.load(open(args.ontology_data))
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)
    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)

    train_data_raw = prepare_dataset(data_path=args.train_data_path,
                                     tokenizer=tokenizer,
                                     slot_meta=slot_meta,
                                     n_history=args.n_history,
                                     max_seq_length=args.max_seq_length,
                                     op_code=args.op_code)
    train_data = MultiWozDataset(train_data_raw,
                                 tokenizer,
                                 slot_meta,
                                 args.max_seq_length,
                                 rng,
                                 ontology,
                                 args.word_dropout,
                                 args.shuffle_state,
                                 args.shuffle_p)
    print("# train examples %d" % len(train_data_raw))

    dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
                                   tokenizer=tokenizer,
                                   slot_meta=slot_meta,
                                   n_history=args.n_history,
                                   max_seq_length=args.max_seq_length,
                                   op_code=args.op_code)
    print("# dev examples %d" % len(dev_data_raw))

    test_data_raw = prepare_dataset(data_path=args.test_data_path,
                                    tokenizer=tokenizer,
                                    slot_meta=slot_meta,
                                    n_history=args.n_history,
                                    max_seq_length=args.max_seq_length,
                                    op_code=args.op_code)
    print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob
    model = SomDST(model_config, len(op2id), len(domain2id),
                   op2id['update'], args.exclude_domain)

    if not os.path.exists(args.bert_ckpt_path):
        args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path,
                                            args.bert_config_path, 'assets')
    ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
    model.encoder.bert.load_state_dict(ckpt)

    # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)

    num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    enc_param_optimizer = list(model.encoder.named_parameters())
    enc_optimizer_grouped_parameters = [
        {'params': [p for n, p in enc_param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in enc_param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
    enc_scheduler = WarmupLinearSchedule(enc_optimizer,
                                         int(num_train_steps * args.enc_warmup),
                                         t_total=num_train_steps)
    dec_param_optimizer = list(model.decoder.parameters())
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
    dec_scheduler = WarmupLinearSchedule(dec_optimizer,
                                         int(num_train_steps * args.dec_warmup),
                                         t_total=num_train_steps)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
            input_ids, input_mask, segment_ids, state_position_ids, op_ids, \
                domain_ids, gen_ids, max_value, max_update = batch
            if rng.random() < args.decoder_teacher_forcing:  # teacher forcing
                teacher = gen_ids
            else:
                teacher = None
            domain_scores, state_scores, gen_scores = model(input_ids=input_ids,
                                                            token_type_ids=segment_ids,
                                                            state_positions=state_position_ids,
                                                            attention_mask=input_mask,
                                                            max_value=max_value,
                                                            op_ids=op_ids,
                                                            max_update=max_update,
                                                            teacher=teacher)
            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
            loss_g = masked_cross_entropy_for_value(gen_scores.contiguous(),
                                                    gen_ids.contiguous(),
                                                    tokenizer.vocab['[PAD]'])
            loss = loss_s + loss_g
            if args.exclude_domain is not True:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)),
                                  domain_ids.view(-1))
                loss = loss + loss_d
            batch_loss.append(loss.item())

            loss.backward()
            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()
            model.zero_grad()

            if step % 100 == 0:
                if args.exclude_domain is not True:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f"
                          % (epoch + 1, args.n_epochs, step, len(train_dataloader),
                             np.mean(batch_loss), loss_s.item(), loss_g.item(), loss_d.item()))
                else:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f"
                          % (epoch + 1, args.n_epochs, step, len(train_dataloader),
                             np.mean(batch_loss), loss_s.item(), loss_g.item()))
                batch_loss = []

        if (epoch + 1) % args.eval_epoch == 0:
            eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta,
                                        epoch + 1, args.op_code)
            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                model_to_save = model.module if hasattr(model, 'module') else model
                save_path = os.path.join(args.save_dir, 'model_best.bin')
                torch.save(model_to_save.state_dict(), save_path)
            print("Best Score : ", best_score)
            print("\n")

    print("Test using best model...")
    best_epoch = best_score['epoch']
    ckpt_path = os.path.join(args.save_dir, 'model_best.bin')
    model = SomDST(model_config, len(op2id), len(domain2id),
                   op2id['update'], args.exclude_domain)
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)
    model.to(device)

    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=False, is_gt_p_state=False, is_gt_gen=False)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=False, is_gt_p_state=False, is_gt_gen=True)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=False, is_gt_p_state=True, is_gt_gen=False)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=False, is_gt_p_state=True, is_gt_gen=True)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=True, is_gt_p_state=False, is_gt_gen=False)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=True, is_gt_p_state=True, is_gt_gen=False)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=True, is_gt_p_state=False, is_gt_gen=True)
    model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch,
                     args.op_code, is_gt_op=True, is_gt_p_state=True, is_gt_gen=True)
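# For reference: WarmupLinearSchedule (pytorch-transformers) scales each
# optimizer's base learning rate by a factor that ramps linearly from 0 over
# the warmup steps, then decays linearly back to 0 at t_total. A standalone
# sketch of that multiplier, written from my understanding of the library:
def linear_warmup_multiplier(step, warmup_steps, t_total):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))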
def main(args):
    def worker_init_fn(worker_id):
        np.random.seed(args.random_seed + worker_id)

    n_gpu = 0
    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    ontology = json.load(open(args.ontology_data))
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    # print(op2id)
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")

    out_path = '/opt/ml/code/new-som-dst/pickles'
    # note: this only checks that the directory exists, not the pickle files themselves
    if os.path.exists(out_path):
        print("Pickles exist!")
        with open(out_path + '/train_data_raw.pkl', 'rb') as f:
            train_data_raw = pickle.load(f)
        with open(out_path + '/train_data.pkl', 'rb') as f:
            train_data = pickle.load(f)
        with open(out_path + '/dev_data_raw.pkl', 'rb') as f:
            dev_data_raw = pickle.load(f)
        print("Pickles loaded!")
    else:
        print("Pickles do not exist!")
        train_dials, dev_dials = load_dataset(args.train_data_path)
        print(f"t_d_len : {len(train_dials)}, d_d_len : {len(dev_dials)}")
        train_data_raw = prepare_dataset(dials=train_dials,
                                         tokenizer=tokenizer,
                                         slot_meta=slot_meta,
                                         n_history=args.n_history,
                                         max_seq_length=args.max_seq_length,
                                         op_code=args.op_code)
        # print("train_data_raw is ready")
        train_data = WosDataset(train_data_raw,
                                tokenizer,
                                slot_meta,
                                args.max_seq_length,
                                rng,
                                ontology,
                                args.word_dropout,
                                args.shuffle_state,
                                args.shuffle_p)
        dev_data_raw = prepare_dataset(dials=dev_dials,
                                       tokenizer=tokenizer,
                                       slot_meta=slot_meta,
                                       n_history=args.n_history,
                                       max_seq_length=args.max_seq_length,
                                       op_code=args.op_code)
        # print(len(dev_data_raw))
        os.makedirs(out_path, exist_ok=True)
        with open(out_path + '/train_data_raw.pkl', 'wb') as f:
            pickle.dump(train_data_raw, f)
        with open(out_path + '/train_data.pkl', 'wb') as f:
            pickle.dump(train_data, f)
        with open(out_path + '/dev_data_raw.pkl', 'wb') as f:
            pickle.dump(dev_data_raw, f)
        print("Pickles saved!")

    print("# train examples %d" % len(train_data_raw))
    print("# dev examples %d" % len(dev_data_raw))
    # test_data_raw = prepare_dataset(data_path=args.test_data_path,
    #                                 tokenizer=tokenizer,
    #                                 slot_meta=slot_meta,
    #                                 n_history=args.n_history,
    #                                 max_seq_length=args.max_seq_length,
    #                                 op_code=args.op_code)
    # print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob
    model = SomDST(model_config, len(op2id), len(domain2id),
                   op2id['update'], args.exclude_domain)
    ckpt = torch.load('/opt/ml/outputs/model_20.bin', map_location='cpu')
    model.load_state_dict(ckpt)
    print("Model is loaded!")

    # if not os.path.exists(args.bert_ckpt_path):
    #     args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path,
    #                                         '/opt/ml/code/new-som-dst/assets')
    # ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
    # model.encoder.bert.load_state_dict(ckpt, strict=False)
    # # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    # model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    # model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    # model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)
    print()
    wandb.watch(model)

    num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    enc_param_optimizer = list(model.encoder.named_parameters())
    enc_optimizer_grouped_parameters = [
        {'params': [p for n, p in enc_param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in enc_param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
    enc_scheduler = get_linear_schedule_with_warmup(
        enc_optimizer,
        num_warmup_steps=int(num_train_steps * args.enc_warmup),
        num_training_steps=num_train_steps)
    dec_param_optimizer = list(model.decoder.parameters())
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
    dec_scheduler = get_linear_schedule_with_warmup(
        dec_optimizer,
        num_warmup_steps=int(num_train_steps * args.dec_warmup),
        num_training_steps=num_train_steps)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
            input_ids, input_mask, segment_ids, state_position_ids, op_ids, \
                domain_ids, gen_ids, max_value, max_update = batch
            if rng.random() < args.decoder_teacher_forcing:  # teacher forcing
                teacher = gen_ids
            else:
                teacher = None
            domain_scores, state_scores, gen_scores = model(
                input_ids=input_ids,
                token_type_ids=segment_ids,
                state_positions=state_position_ids,
                attention_mask=input_mask,
                max_value=max_value,
                op_ids=op_ids,
                max_update=max_update,
                teacher=teacher)
            # print(f"input_id : {input_ids[0].shape} {input_ids[0]}")
            # print(f"segment_id : {segment_ids[0].shape} {segment_ids[0]}")
            # print(f"slot_position : {state_position_ids[0].shape} {state_position_ids[0]}")
            # print(f"input_mask : {input_mask[0].shape} {input_mask[0]}")
            # print(f"state_scores : {state_scores[0].shape} {state_scores[0]}")
            # print(f"gen_scores : {gen_scores[0].shape} {gen_scores[0]}")
            # print(f"op_ids : {op_ids.shape, op_ids}")
            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
            # print("loss_s", loss_s.shape, loss_s)
            loss_g = masked_cross_entropy_for_value(
                gen_scores.contiguous(),  # B, J', K, V
                gen_ids.contiguous(),     # B, J', K
                tokenizer.vocab['[PAD]'])
            # print("loss_g", loss_g)
            # print(f"gen_scores : {gen_scores.shape, torch.argmax(gen_scores[0][0], -1)}")
            # print(f"gen_ids : {gen_ids.shape, gen_ids[0][0], tokenizer.decode(gen_ids[0][0])}")
            loss = loss_s + loss_g
            if args.exclude_domain is not True:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)),
                                  domain_ids.view(-1))
                loss = loss + loss_d
            batch_loss.append(loss.item())

            loss.backward()
            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()
            model.zero_grad()

            if (step + 1) % 100 == 0:
                if args.exclude_domain is not True:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f"
                          % (epoch + 1, args.n_epochs, step + 1, len(train_dataloader),
                             np.mean(batch_loss), loss_s.item(), loss_g.item(), loss_d.item()))
                else:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f"
                          % (epoch + 1, args.n_epochs, step + 1, len(train_dataloader),
                             np.mean(batch_loss), loss_s.item(), loss_g.item()))
                batch_loss = []

        if (epoch + 1) % args.eval_epoch == 0:
            eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta,
                                        epoch + 1, args.op_code)
            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                model_to_save = model.module if hasattr(model, 'module') else model
                save_path = os.path.join(args.save_dir, 'model_best.bin')
                torch.save(model_to_save.state_dict(), save_path)
            print("Best Score : ", best_score)
            print("\n")
            wandb.log({
                'joint_acc': eval_res['joint_acc'],
                'slot_acc': eval_res['slot_acc'],
                'slot_f1': eval_res['slot_f1'],
                'op_acc': eval_res['op_acc'],
                'op_f1': eval_res['op_f1'],
                'final_slot_f1': eval_res['final_slot_f1']
            })

        # save a checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            save_path = os.path.join(args.save_dir, f'model_{epoch+1}.bin')
            torch.save(model_to_save.state_dict(), save_path)
            print(f"model_{epoch+1}.bin is saved!")
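# masked_cross_entropy_for_value is called above but defined elsewhere in the
# repo. A sketch consistent with the reference SOM-DST implementation, as I
# understand it: gen_scores holds per-token probabilities of shape
# (B, J', K, V), gen_ids holds target token ids of shape (B, J', K), and
# [PAD] positions are excluded from the average.
def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size()) * mask.float()
    return losses.sum() / mask.sum().float()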
def train(args):
    # operation & domain label encoders
    op2id = OP_SET[args.op_code]
    domain2id = DOMAIN2ID
    args.n_op = len(op2id)
    args.n_domain = get_domain_nums(domain2id)
    args.update_id = op2id["update"]

    # initialize tokenizer
    tokenizer_module = getattr(import_module("transformers"),
                               f"{args.model_name}Tokenizer")
    tokenizer = tokenizer_module.from_pretrained(args.pretrained_name_or_path)
    tokenizer.add_special_tokens({
        "additional_special_tokens": [NULL_TOKEN, SLOT_TOKEN],
        'eos_token': EOS_TOKEN
    })
    args.vocab_size = len(tokenizer)

    # load data
    slot_meta = json.load(open(os.path.join(args.data_dir, "slot_meta.json")))
    train_data, dev_data, dev_labels = load_somdst_dataset(
        os.path.join(args.data_dir, "train_dials.json"))
    train_examples = get_somdst_examples_from_dialogues(
        data=train_data, n_history=args.n_history)
    dev_examples = get_somdst_examples_from_dialogues(
        data=dev_data, n_history=args.n_history)

    # preprocessing
    preprocessor = SomDSTPreprocessor(slot_meta=slot_meta,
                                      src_tokenizer=tokenizer,
                                      max_seq_length=args.max_seq_length,
                                      word_dropout=0.1,
                                      domain2id=domain2id)
    train_features = preprocessor.convert_examples_to_features(
        train_examples, word_dropout=args.word_dropout)
    dev_features = preprocessor.convert_examples_to_features(
        dev_examples, word_dropout=0.0)

    train_dataset = WOSDataset(features=train_features)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=preprocessor.collate_fn,
        num_workers=args.num_workers,
    )

    # initialize model: update embedding size & initialize weights
    model_config = BertConfig.from_pretrained(args.pretrained_name_or_path)
    model_config.dropout = 0.1
    model_config.vocab_size = len(tokenizer)
    model = SomDST(model_config,
                   n_op=args.n_op,
                   n_domain=args.n_domain,
                   update_id=args.update_id)
    # model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    # model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    # model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)
    print("Model is initialized")

    # num_train_steps = int(len(train_features) / args.train_batch_size * args.epochs)
    num_train_steps = len(train_dataloader) * args.epochs

    # no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    # enc_param_optimizer = list(model.encoder.named_parameters())
    # enc_optimizer_grouped_parameters = [
    #     {
    #         "params": [
    #             p for n, p in enc_param_optimizer
    #             if not any(nd in n for nd in no_decay)
    #         ],
    #         "weight_decay": 0.01,
    #     },
    #     {
    #         "params": [
    #             p for n, p in enc_param_optimizer
    #             if any(nd in n for nd in no_decay)
    #         ],
    #         "weight_decay": 0.0,
    #     },
    # ]

    # initialize optimizer & scheduler
    # enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr, eps=args.adam_epsilon)
    enc_param_optimizer = list(model.encoder.parameters())
    enc_optimizer = AdamW(enc_param_optimizer, lr=args.enc_lr, eps=args.adam_epsilon)
    enc_scheduler = get_linear_schedule_with_warmup(
        optimizer=enc_optimizer,
        # the original passed 0.1 here, presumably intended as a warmup ratio;
        # get_linear_schedule_with_warmup expects an absolute step count
        num_warmup_steps=int(num_train_steps * 0.1),
        num_training_steps=num_train_steps)
    dec_param_optimizer = list(model.decoder.parameters())
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr, eps=args.adam_epsilon)
    dec_scheduler = get_linear_schedule_with_warmup(
        optimizer=dec_optimizer,
        num_warmup_steps=int(num_train_steps * 0.1),
        num_training_steps=num_train_steps)

    criterion = nn.CrossEntropyLoss()
    rng = random.Random(args.seed)

    # save experiment settings
    json.dump(
        vars(args),
        open(f"{args.model_dir}/{args.model_fold}/exp_config.json", "w"),
        indent=2,
        ensure_ascii=False,
    )
    json.dump(
        slot_meta,
        open(f"{args.model_dir}/{args.model_fold}/slot_meta.json", "w"),
        indent=2,
        ensure_ascii=False,
    )

    best_score = {
        "epoch": 0,
        "joint_acc": 0,
        # "op_acc": 0,
        "slot_acc": 0,
        "slot_f1": 0,
        "op_acc": 0,
        "op_f1": 0,
    }
    for epoch in range(args.epochs):
        print(f'Epoch #{epoch}')
        batch_loss = []
        model.train()
        for step, batch in tqdm(enumerate(train_dataloader), desc='[Step]'):
            batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
            (
                input_ids,
                input_mask,
                segment_ids,
                state_position_ids,
                op_ids,
                domain_ids,
                gen_ids,
                max_value,
                max_update,
            ) = batch

            # teacher forcing for the generation decoder
            teacher = gen_ids if rng.random() < args.teacher_forcing_ratio else None

            domain_scores, state_scores, gen_scores = model(
                input_ids=input_ids,
                token_type_ids=segment_ids,
                state_positions=state_position_ids,
                attention_mask=input_mask,
                max_value=max_value,
                op_ids=op_ids,
                max_update=max_update,
                teacher=teacher,
            )
            loss_s = criterion(state_scores.view(-1, len(op2id)), op_ids.view(-1))
            loss_g = masked_cross_entropy_for_value(
                logits=gen_scores.contiguous(),
                target=gen_ids.contiguous(),
                pad_idx=tokenizer.pad_token_id,
            )
            loss = loss_s + loss_g
            if args.exclude_domain is not True:
                loss_d = criterion(domain_scores.view(-1, args.n_domain),
                                   domain_ids.view(-1))
                loss = loss + loss_d
            batch_loss.append(loss.item())

            loss.backward()
            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()
            model.zero_grad()

            for learning_rate in enc_scheduler.get_lr():
                wandb.log({"encoder_learning_rate": learning_rate})
            for learning_rate in dec_scheduler.get_lr():
                wandb.log({"decoder_learning_rate": learning_rate})

            if step % 100 == 0:
                if args.exclude_domain is not True:
                    print(
                        f"[{epoch+1}/{args.epochs}] [{step}/{len(train_dataloader)}] "
                        f"mean_loss : {np.mean(batch_loss):.3f}, state_loss : {loss_s.item():.3f}, "
                        f"gen_loss : {loss_g.item():.3f}, dom_loss : {loss_d.item():.3f}"
                    )
                    wandb.log({
                        "epoch": epoch,
                        "Train epoch loss": np.mean(batch_loss),
                        "Train epoch state loss": loss_s.item(),
                        "Train epoch generation loss": loss_g.item(),
                        "Train epoch domain loss": loss_d.item(),
                    })
                else:
                    print(
                        f"[{epoch+1}/{args.epochs}] [{step}/{len(train_dataloader)}] "
                        f"mean_loss : {np.mean(batch_loss):.3f}, state_loss : {loss_s.item():.3f}, "
                        f"gen_loss : {loss_g.item():.3f}"
                    )
                    wandb.log({
                        "epoch": epoch,
                        "Train epoch loss": np.mean(batch_loss),
                        "Train epoch state loss": loss_s.item(),
                        "Train epoch generation loss": loss_g.item(),
                    })
                batch_loss = []

        # evaluation for each epoch
        eval_res = model_evaluation(
            model,
            dev_features,
            tokenizer,
            slot_meta,
            domain2id,
            epoch + 1,
            args.op_code,
        )
        if eval_res["joint_acc"] > best_score["joint_acc"]:
            print("Update Best checkpoint!")
            best_score = eval_res
            best_checkpoint = best_score["epoch"]
            wandb.log({
                "epoch": best_score["epoch"],
                "Best joint goal accuracy": best_score["joint_acc"],
                "Best turn slot accuracy": best_score["slot_acc"],
                "Best turn slot f1": best_score["slot_f1"],
                "Best operation accuracy": best_score["op_acc"],
                "Best operation f1": best_score["op_f1"],
            })
            # save phase
            model_to_save = model.module if hasattr(model, "module") else model
            torch.save(
                model_to_save.state_dict(),
                f"{args.model_dir}/{args.model_fold}/model-best.bin",
            )
        print("Best Score : ", best_score)
        print("\n")

    print(f"Best checkpoint: {args.model_dir}/{args.model_fold}/model-best.bin")