def train(data, opt, fold_idx):
    model = SeqModel(data, opt)
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2)

    if opt.tune_wordemb == False:
        my_utils.freeze_net(model.word_hidden.wordrep.word_embedding)

    best_dev_f = -10
    best_dev_p = -10
    best_dev_r = -10
    bad_counter = 0

    for idx in range(opt.iter):
        epoch_start = time.time()

        if opt.elmo:
            my_utils.shuffle(data.train_texts, data.train_Ids)
        else:
            random.shuffle(data.train_Ids)

        model.train()
        model.zero_grad()

        batch_size = opt.batch_size
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1

        for batch_id in range(total_batch):
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if opt.elmo:
                instance_text = data.train_texts[start:end]
            else:
                instance_text = None
            if not instance:
                continue

            batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, batch_features, batch_text = batchify_with_label(data, instance, instance_text, opt.gpu)

            loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask, batch_features, batch_text)

            loss.backward()
            if opt.gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.gradient_clip)
            optimizer.step()
            model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start))

        if opt.dev_file:
            _, _, p, r, f, _, _ = evaluate(data, opt, model, "dev", True)
            logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
        else:
            f = best_dev_f

        if f > best_dev_f:
            logging.info("Exceed previous best f score on dev: %.4f" % (best_dev_f))
            if fold_idx is None:
                torch.save(model.state_dict(), os.path.join(opt.output, "model.pkl"))
            else:
                torch.save(model.state_dict(), os.path.join(opt.output, "model_{}.pkl".format(fold_idx + 1)))
            best_dev_f = f
            best_dev_p = p
            best_dev_r = r
            # if opt.test_file:
            #     _, _, p, r, f, _, _ = evaluate(data, opt, model, "test", True, opt.nbest)
            #     logging.info("Test: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
            bad_counter = 0
        else:
            bad_counter += 1

        if len(opt.dev_file) != 0 and bad_counter >= opt.patience:
            logging.info('Early Stop!')
            break

    logging.info("train finished")

    if len(opt.dev_file) == 0:
        torch.save(model.state_dict(), os.path.join(opt.output, "model.pkl"))

    return best_dev_p, best_dev_r, best_dev_f
def train(train_data, dev_data, test_data, d, dictionary, dictionary_reverse, opt, fold_idx, isMeddra_dict):
    logging.info("train the vsm-based normalization model ...")

    external_train_data = []
    if d.config.get('norm_ext_corpus') is not None:
        for k, v in d.config['norm_ext_corpus'].items():
            if k == 'tac':
                external_train_data.extend(load_data_fda(v['path'], True, v.get('types'), v.get('types'), False, True))
            else:
                raise RuntimeError("not support external corpus")
    if len(external_train_data) != 0:
        train_data.extend(external_train_data)

    logging.info("build alphabet ...")
    word_alphabet = Alphabet('word')
    norm_utils.build_alphabet_from_dict(word_alphabet, dictionary, isMeddra_dict)
    norm_utils.build_alphabet(word_alphabet, train_data)
    if opt.dev_file:
        norm_utils.build_alphabet(word_alphabet, dev_data)
    if opt.test_file:
        norm_utils.build_alphabet(word_alphabet, test_data)
    norm_utils.fix_alphabet(word_alphabet)
    logging.info("alphabet size {}".format(word_alphabet.size()))

    if d.config.get('norm_emb') is not None:
        logging.info("load pretrained word embedding ...")
        pretrain_word_embedding, word_emb_dim = build_pretrain_embedding(d.config.get('norm_emb'), word_alphabet, opt.word_emb_dim, False)
        word_embedding = nn.Embedding(word_alphabet.size(), word_emb_dim, padding_idx=0)
        word_embedding.weight.data.copy_(torch.from_numpy(pretrain_word_embedding))
        embedding_dim = word_emb_dim
    else:
        logging.info("randomly initialize word embedding ...")
        word_embedding = nn.Embedding(word_alphabet.size(), d.word_emb_dim, padding_idx=0)
        word_embedding.weight.data.copy_(torch.from_numpy(random_embedding(word_alphabet.size(), d.word_emb_dim)))
        embedding_dim = d.word_emb_dim

    dict_alphabet = Alphabet('dict')
    norm_utils.init_dict_alphabet(dict_alphabet, dictionary)
    norm_utils.fix_alphabet(dict_alphabet)

    logging.info("init_vector_for_dict")
    poses, poses_lengths = init_vector_for_dict(word_alphabet, dict_alphabet, dictionary, isMeddra_dict)

    vsm_model = VsmNormer(word_alphabet, word_embedding, embedding_dim, dict_alphabet, poses, poses_lengths)

    logging.info("generate instances for training ...")
    train_X = []
    train_Y = []
    for doc in train_data:
        if isMeddra_dict:
            temp_X, temp_Y = generate_instances(doc.entities, word_alphabet, dict_alphabet)
        else:
            temp_X, temp_Y = generate_instances_ehr(doc.entities, word_alphabet, dict_alphabet, dictionary_reverse)
        train_X.extend(temp_X)
        train_Y.extend(temp_Y)

    train_loader = DataLoader(MyDataset(train_X, train_Y), opt.batch_size, shuffle=True, collate_fn=my_collate)
    optimizer = optim.Adam(vsm_model.parameters(), lr=opt.lr, weight_decay=opt.l2)

    if opt.tune_wordemb == False:
        freeze_net(vsm_model.word_embedding)

    if d.config['norm_vsm_pretrain'] == '1':
        dict_pretrain(dictionary, dictionary_reverse, d, isMeddra_dict, optimizer, vsm_model)

    best_dev_f = -10
    best_dev_p = -10
    best_dev_r = -10
    bad_counter = 0

    logging.info("start training ...")
    for idx in range(opt.iter):
        epoch_start = time.time()

        vsm_model.train()
        train_iter = iter(train_loader)
        num_iter = len(train_loader)
        sum_loss = 0
        correct, total = 0, 0

        for i in range(num_iter):
            x, lengths, y = next(train_iter)

            l, y_pred = vsm_model.forward_train(x, lengths, y)
            sum_loss += l.item()
            l.backward()

            if opt.gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(vsm_model.parameters(), opt.gradient_clip)
            optimizer.step()
            vsm_model.zero_grad()

            total += y.size(0)
            _, pred = torch.max(y_pred, 1)
            correct += (pred == y).sum().item()

        epoch_finish = time.time()
        accuracy = 100.0 * correct / total
        logging.info("epoch: %s training finished. Time: %.2fs. loss: %.4f Accuracy %.2f" % (idx, epoch_finish - epoch_start, sum_loss / num_iter, accuracy))

        if opt.dev_file:
            p, r, f = norm_utils.evaluate(dev_data, dictionary, dictionary_reverse, vsm_model, None, None, d, isMeddra_dict)
            logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f))
        else:
            f = best_dev_f

        if f > best_dev_f:
            logging.info("Exceed previous best f score on dev: %.4f" % (best_dev_f))
            if fold_idx is None:
                torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))
            else:
                torch.save(vsm_model, os.path.join(opt.output, "vsm_{}.pkl".format(fold_idx + 1)))
            best_dev_f = f
            best_dev_p = p
            best_dev_r = r
            bad_counter = 0
        else:
            bad_counter += 1

        if len(opt.dev_file) != 0 and bad_counter >= opt.patience:
            logging.info('Early Stop!')
            break

    logging.info("train finished")

    if len(opt.dev_file) == 0:
        torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl"))

    return best_dev_p, best_dev_r, best_dev_f
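# ---------------------------------------------------------------------------
# Minimal, self-contained sketch (not part of the pipeline above) of the core
# idea behind a vsm-based normalizer: embed a mention as the average of its
# word vectors and rank dictionary concepts by cosine similarity. All names
# here (rank_concepts, toy_embedding, concept_matrix, mention_ids) are
# hypothetical; the real VsmNormer.forward_train also learns a loss on top of
# the similarity and may differ in detail.
import torch
import torch.nn as nn
import torch.nn.functional as F

def rank_concepts(mention_ids, toy_embedding, concept_matrix):
    # mention_ids: LongTensor of word indices for one mention
    # concept_matrix: (num_concepts, emb_dim) pre-computed concept vectors
    mention_vec = toy_embedding(mention_ids).mean(dim=0, keepdim=True)  # (1, emb_dim)
    scores = F.cosine_similarity(mention_vec, concept_matrix, dim=1)    # (num_concepts,)
    return torch.argmax(scores).item()

# toy usage
toy_embedding = nn.Embedding(10, 8, padding_idx=0)
concept_matrix = torch.randn(5, 8)
print(rank_concepts(torch.tensor([1, 3, 4]), toy_embedding, concept_matrix))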
def joint_train(data, old_data, opt): if not os.path.exists(opt.output): os.makedirs(opt.output) if opt.pretrained_model_dir != 'None': seq_model = SeqModel(data) if opt.test_in_cpu: seq_model.load_state_dict( torch.load(os.path.join(opt.pretrained_model_dir, 'ner_model.pkl'), map_location='cpu')) else: seq_model.load_state_dict( torch.load( os.path.join(opt.pretrained_model_dir, 'ner_model.pkl'))) if (data.label_alphabet_size != seq_model.crf.tagset_size)\ or (data.HP_hidden_dim != seq_model.hidden2tag.weight.size(1)): raise RuntimeError("ner_model not compatible") seq_wordseq = WordSequence(data, False, True, True, True) if ((data.word_emb_dim != seq_wordseq.wordrep.word_embedding.embedding_dim)\ or (data.char_emb_dim != seq_wordseq.wordrep.char_feature.char_embeddings.embedding_dim)\ or (data.feature_emb_dims[0] != seq_wordseq.wordrep.feature_embedding_dims[0])\ or (data.feature_emb_dims[1] != seq_wordseq.wordrep.feature_embedding_dims[1])): raise RuntimeError("ner_wordseq not compatible") old_seq_wordseq = WordSequence(old_data, False, True, True, True) if opt.test_in_cpu: old_seq_wordseq.load_state_dict( torch.load(os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl'), map_location='cpu')) else: old_seq_wordseq.load_state_dict( torch.load( os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl'))) # sd = old_seq_wordseq.lstm.state_dict() for old, new in zip(old_seq_wordseq.lstm.parameters(), seq_wordseq.lstm.parameters()): new.data.copy_(old) vocab_size = old_seq_wordseq.wordrep.word_embedding.num_embeddings seq_wordseq.wordrep.word_embedding.weight.data[ 0: vocab_size, :] = old_seq_wordseq.wordrep.word_embedding.weight.data[ 0:vocab_size, :] vocab_size = old_seq_wordseq.wordrep.char_feature.char_embeddings.num_embeddings seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[ 0: vocab_size, :] = old_seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[ 0:vocab_size, :] for i, feature_embedding in enumerate( old_seq_wordseq.wordrep.feature_embeddings): vocab_size = feature_embedding.num_embeddings seq_wordseq.wordrep.feature_embeddings[i].weight.data[ 0:vocab_size, :] = feature_embedding.weight.data[ 0:vocab_size, :] # for word in data.word_alphabet.iteritems(): # # old_seq_wordseq.wordrep.word_embedding.weight.data data.word_alphabet.get_index(word) # classify_wordseq = WordSequence(data, True, False, True, False) if ((data.word_emb_dim != classify_wordseq.wordrep.word_embedding.embedding_dim)\ or (data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']] != classify_wordseq.wordrep.position1_emb.embedding_dim)\ or (data.feature_emb_dims[1] != classify_wordseq.wordrep.feature_embedding_dims[0])): raise RuntimeError("re_wordseq not compatible") old_classify_wordseq = WordSequence(old_data, True, False, True, False) if opt.test_in_cpu: old_classify_wordseq.load_state_dict( torch.load(os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl'), map_location='cpu')) else: old_classify_wordseq.load_state_dict( torch.load( os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl'))) for old, new in zip(old_classify_wordseq.lstm.parameters(), classify_wordseq.lstm.parameters()): new.data.copy_(old) vocab_size = old_classify_wordseq.wordrep.word_embedding.num_embeddings classify_wordseq.wordrep.word_embedding.weight.data[ 0: vocab_size, :] = old_classify_wordseq.wordrep.word_embedding.weight.data[ 0:vocab_size, :] vocab_size = old_classify_wordseq.wordrep.position1_emb.num_embeddings classify_wordseq.wordrep.position1_emb.weight.data[ 0: vocab_size, :] = 
old_classify_wordseq.wordrep.position1_emb.weight.data[ 0:vocab_size, :] vocab_size = old_classify_wordseq.wordrep.position2_emb.num_embeddings classify_wordseq.wordrep.position2_emb.weight.data[ 0: vocab_size, :] = old_classify_wordseq.wordrep.position2_emb.weight.data[ 0:vocab_size, :] vocab_size = old_classify_wordseq.wordrep.feature_embeddings[ 0].num_embeddings classify_wordseq.wordrep.feature_embeddings[0].weight.data[ 0:vocab_size, :] = old_classify_wordseq.wordrep.feature_embeddings[ 0].weight.data[0:vocab_size, :] classify_model = ClassifyModel(data) old_classify_model = ClassifyModel(old_data) if opt.test_in_cpu: old_classify_model.load_state_dict( torch.load(os.path.join(opt.pretrained_model_dir, 're_model.pkl'), map_location='cpu')) else: old_classify_model.load_state_dict( torch.load( os.path.join(opt.pretrained_model_dir, 're_model.pkl'))) if (data.re_feature_alphabet_sizes[data.re_feature_name2id['[RELATION]']] != old_classify_model.linear.weight.size(0) or (data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_TYPE]']] != old_classify_model.entity_type_emb.embedding_dim) \ or (data.re_feature_emb_dims[ data.re_feature_name2id['[ENTITY]']] != old_classify_model.entity_emb.embedding_dim) or (data.re_feature_emb_dims[ data.re_feature_name2id['[TOKEN_NUM]']] != old_classify_model.tok_num_betw_emb.embedding_dim) \ or (data.re_feature_emb_dims[ data.re_feature_name2id['[ENTITY_NUM]']] != old_classify_model.et_num_emb.embedding_dim) \ ): raise RuntimeError("re_model not compatible") vocab_size = old_classify_model.entity_type_emb.num_embeddings classify_model.entity_type_emb.weight.data[ 0:vocab_size, :] = old_classify_model.entity_type_emb.weight.data[ 0:vocab_size, :] vocab_size = old_classify_model.entity_emb.num_embeddings classify_model.entity_emb.weight.data[ 0:vocab_size, :] = old_classify_model.entity_emb.weight.data[ 0:vocab_size, :] vocab_size = old_classify_model.tok_num_betw_emb.num_embeddings classify_model.tok_num_betw_emb.weight.data[ 0:vocab_size, :] = old_classify_model.tok_num_betw_emb.weight.data[ 0:vocab_size, :] vocab_size = old_classify_model.et_num_emb.num_embeddings classify_model.et_num_emb.weight.data[ 0:vocab_size, :] = old_classify_model.et_num_emb.weight.data[ 0:vocab_size, :] else: seq_model = SeqModel(data) seq_wordseq = WordSequence(data, False, True, True, True) classify_wordseq = WordSequence(data, True, False, True, False) classify_model = ClassifyModel(data) iter_parameter = itertools.chain( *map(list, [seq_wordseq.parameters(), seq_model.parameters()])) seq_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) iter_parameter = itertools.chain(*map( list, [classify_wordseq.parameters(), classify_model.parameters()])) classify_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) if data.tune_wordemb == False: my_utils.freeze_net(seq_wordseq.wordrep.word_embedding) my_utils.freeze_net(classify_wordseq.wordrep.word_embedding) re_X_positive = [] re_Y_positive = [] re_X_negative = [] re_Y_negative = [] relation_vocab = data.re_feature_alphabets[ data.re_feature_name2id['[RELATION]']] my_collate = my_utils.sorted_collate1 for i in range(len(data.re_train_X)): x = data.re_train_X[i] y = data.re_train_Y[i] if y != relation_vocab.get_index("</unk>"): re_X_positive.append(x) re_Y_positive.append(y) else: re_X_negative.append(x) re_Y_negative.append(y) re_dev_loader = DataLoader(my_utils.RelationDataset( data.re_dev_X, data.re_dev_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate) # 
re_test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate) best_ner_score = -1 best_re_score = -1 count_performance_not_grow = 0 for idx in range(data.HP_iteration): epoch_start = time.time() seq_wordseq.train() seq_wordseq.zero_grad() seq_model.train() seq_model.zero_grad() classify_wordseq.train() classify_wordseq.zero_grad() classify_model.train() classify_model.zero_grad() batch_size = data.HP_batch_size random.shuffle(data.train_Ids) ner_train_num = len(data.train_Ids) ner_total_batch = ner_train_num // batch_size + 1 re_train_loader, re_train_iter = makeRelationDataset( re_X_positive, re_Y_positive, re_X_negative, re_Y_negative, data.unk_ratio, True, my_collate, data.HP_batch_size) re_total_batch = len(re_train_loader) total_batch = max(ner_total_batch, re_total_batch) min_batch = min(ner_total_batch, re_total_batch) for batch_id in range(total_batch): if batch_id < ner_total_batch: start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > ner_train_num: end = ner_train_num instance = data.train_Ids[start:end] batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \ batch_permute_label = batchify_with_label(instance, data.HP_gpu) hidden = seq_wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) hidden_adv = None loss, tag_seq = seq_model.neg_log_likelihood_loss( hidden, hidden_adv, batch_label, mask) loss.backward() seq_optimizer.step() seq_wordseq.zero_grad() seq_model.zero_grad() if batch_id < re_total_batch: [batch_word, batch_features, batch_wordlen, batch_wordrecover, \ batch_char, batch_charlen, batch_charrecover, \ position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \ tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1( re_train_loader, re_train_iter) hidden = classify_wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor) hidden_adv = None loss, pred = classify_model.neg_log_likelihood_loss( hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets) loss.backward() classify_optimizer.step() classify_wordseq.zero_grad() classify_model.zero_grad() epoch_finish = time.time() logging.info("epoch: %s training finished. 
Time: %.2fs" % (idx, epoch_finish - epoch_start)) _, _, _, _, ner_score, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "dev") logging.info("ner evaluate: f: %.4f" % (ner_score)) if ner_score > best_ner_score: logging.info("new best score: ner: %.4f" % (ner_score)) best_ner_score = ner_score torch.save(seq_wordseq.state_dict(), os.path.join(opt.output, 'ner_wordseq.pkl')) torch.save(seq_model.state_dict(), os.path.join(opt.output, 'ner_model.pkl')) count_performance_not_grow = 0 # _, _, _, _, test_ner_score, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test") # logging.info("ner evaluate on test: f: %.4f" % (test_ner_score)) else: count_performance_not_grow += 1 re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_dev_loader) logging.info("re evaluate: f: %.4f" % (re_score)) if re_score > best_re_score: logging.info("new best score: re: %.4f" % (re_score)) best_re_score = re_score torch.save(classify_wordseq.state_dict(), os.path.join(opt.output, 're_wordseq.pkl')) torch.save(classify_model.state_dict(), os.path.join(opt.output, 're_model.pkl')) count_performance_not_grow = 0 # test_re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_test_loader) # logging.info("re evaluate on test: f: %.4f" % (test_re_score)) else: count_performance_not_grow += 1 if count_performance_not_grow > 2 * data.patience: logging.info("early stop") break logging.info("train finished")
def train(data, dir):
    my_collate = my_utils.sorted_collate1

    train_loader, train_iter = makeDatasetWithoutUnknown(data.re_train_X, data.re_train_Y, data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']], True, my_collate, data.HP_batch_size)
    num_iter = len(train_loader)
    unk_loader, unk_iter = makeDatasetUnknown(data.re_train_X, data.re_train_Y, data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']], my_collate, data.unk_ratio, data.HP_batch_size)
    test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    wordseq = WordSequence(data, True, False, False)
    model = ClassifyModel(data)
    if torch.cuda.is_available():
        model = model.cuda(data.HP_gpu)

    if opt.self_adv == 'grad':
        wordseq_adv = WordSequence(data, True, False, False)
    elif opt.self_adv == 'label':
        wordseq_adv = WordSequence(data, True, False, False)
        model_adv = ClassifyModel(data)
        if torch.cuda.is_available():
            model_adv = model_adv.cuda(data.HP_gpu)
    else:
        wordseq_adv = None

    if opt.self_adv == 'grad':
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    elif opt.self_adv == 'label':
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
        iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()]))
        optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    else:
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(wordseq.wordrep.word_embedding)
        if opt.self_adv != 'no':
            my_utils.freeze_net(wordseq_adv.wordrep.word_embedding)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(data.max_epoch):
        wordseq.train()
        wordseq.zero_grad()
        model.train()
        model.zero_grad()
        if opt.self_adv == 'grad':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
        elif opt.self_adv == 'label':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
            model_adv.train()
            model_adv.zero_grad()

        correct, total = 0, 0

        for i in range(num_iter):
            [batch_word, batch_features, batch_wordlen, batch_wordrecover,
             batch_char, batch_charlen, batch_charrecover,
             position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
             tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(train_loader, train_iter)

            if opt.self_adv == 'grad':
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                my_utils.reverse_grad(wordseq_adv)
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

            elif opt.self_adv == 'label':
                wordseq.unfreeze_net()
                wordseq_adv.freeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

                wordseq.freeze_net()
                wordseq_adv.unfreeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets_permute)
                loss_adv.backward()
                optimizer_adv.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model_adv.zero_grad()

            else:
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
                hidden_adv = None
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                model.zero_grad()

            total += targets.size(0)
            correct += (pred == targets).sum().item()

            [batch_word, batch_features, batch_wordlen, batch_wordrecover,
             batch_char, batch_charlen, batch_charrecover,
             position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type,
             tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(unk_loader, unk_iter)

            hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor)
            hidden_adv = None
            loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets)
            loss.backward()
            optimizer.step()
            wordseq.zero_grad()
            model.zero_grad()

        unk_loader, unk_iter = makeDatasetUnknown(data.re_train_X, data.re_train_Y, data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']], my_collate, data.unk_ratio, data.HP_batch_size)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))

        test_accuracy = evaluate(wordseq, model, test_loader)
        # test_accuracy = evaluate(m, test_loader)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(wordseq.state_dict(), os.path.join(dir, 'wordseq.pkl'))
            torch.save(model.state_dict(), '{}/model.pkl'.format(dir))
            if opt.self_adv == 'grad':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
            elif opt.self_adv == 'label':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
                torch.save(model_adv.state_dict(), os.path.join(dir, 'model_adv.pkl'))
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")
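# ---------------------------------------------------------------------------
# Minimal sketch (assumption): in the self_adv == 'grad' branch above,
# my_utils.reverse_grad presumably negates the gradients accumulated in the
# adversarial encoder before the shared optimizer step, so that branch is
# pushed in the direction that increases the task loss (a simple form of
# adversarial feature learning). The function and demo below are hypothetical
# stand-ins, not the repository's implementation.
import torch
import torch.nn as nn

def reverse_grad_sketch(net: nn.Module) -> None:
    # flip the sign of every populated gradient in `net`
    for p in net.parameters():
        if p.grad is not None:
            p.grad = -p.grad

# tiny demonstration on a toy layer
layer = nn.Linear(4, 2)
loss = layer(torch.randn(3, 4)).sum()
loss.backward()
before = layer.weight.grad.clone()
reverse_grad_sketch(layer)
assert torch.equal(layer.weight.grad, -before)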
def train(train_data, dev_data, test_data, d, dictionary, dictionary_reverse, opt, fold_idx, isMeddra_dict): logging.info("train the ensemble normalization model ...") external_train_data = [] if d.config.get('norm_ext_corpus') is not None: for k, v in d.config['norm_ext_corpus'].items(): if k == 'tac': external_train_data.extend( load_data_fda(v['path'], True, v.get('types'), v.get('types'), False, True)) else: raise RuntimeError("not support external corpus") if len(external_train_data) != 0: train_data.extend(external_train_data) logging.info("build alphabet ...") word_alphabet = Alphabet('word') norm_utils.build_alphabet_from_dict(word_alphabet, dictionary, isMeddra_dict) norm_utils.build_alphabet(word_alphabet, train_data) if opt.dev_file: norm_utils.build_alphabet(word_alphabet, dev_data) if opt.test_file: norm_utils.build_alphabet(word_alphabet, test_data) norm_utils.fix_alphabet(word_alphabet) if d.config.get('norm_emb') is not None: logging.info("load pretrained word embedding ...") pretrain_word_embedding, word_emb_dim = build_pretrain_embedding( d.config.get('norm_emb'), word_alphabet, opt.word_emb_dim, False) word_embedding = nn.Embedding(word_alphabet.size(), word_emb_dim, padding_idx=0) word_embedding.weight.data.copy_( torch.from_numpy(pretrain_word_embedding)) embedding_dim = word_emb_dim else: logging.info("randomly initialize word embedding ...") word_embedding = nn.Embedding(word_alphabet.size(), d.word_emb_dim, padding_idx=0) word_embedding.weight.data.copy_( torch.from_numpy( random_embedding(word_alphabet.size(), d.word_emb_dim))) embedding_dim = d.word_emb_dim dict_alphabet = Alphabet('dict') norm_utils.init_dict_alphabet(dict_alphabet, dictionary) norm_utils.fix_alphabet(dict_alphabet) # rule logging.info("init rule-based normer") multi_sieve.init(opt, train_data, d, dictionary, dictionary_reverse, isMeddra_dict) if opt.ensemble == 'learn': logging.info("init ensemble normer") poses = vsm.init_vector_for_dict(word_alphabet, dict_alphabet, dictionary, isMeddra_dict) ensemble_model = Ensemble(word_alphabet, word_embedding, embedding_dim, dict_alphabet, poses) if pretrain_neural_model is not None: ensemble_model.neural_linear.weight.data.copy_( pretrain_neural_model.linear.weight.data) if pretrain_vsm_model is not None: ensemble_model.vsm_linear.weight.data.copy_( pretrain_vsm_model.linear.weight.data) ensemble_train_X = [] ensemble_train_Y = [] for doc in train_data: temp_X, temp_Y = generate_instances(doc, word_alphabet, dict_alphabet, dictionary, dictionary_reverse, isMeddra_dict) ensemble_train_X.extend(temp_X) ensemble_train_Y.extend(temp_Y) ensemble_train_loader = DataLoader(MyDataset(ensemble_train_X, ensemble_train_Y), opt.batch_size, shuffle=True, collate_fn=my_collate) ensemble_optimizer = optim.Adam(ensemble_model.parameters(), lr=opt.lr, weight_decay=opt.l2) if opt.tune_wordemb == False: freeze_net(ensemble_model.word_embedding) else: # vsm logging.info("init vsm-based normer") poses = vsm.init_vector_for_dict(word_alphabet, dict_alphabet, dictionary, isMeddra_dict) # alphabet can share between vsm and neural since they don't change # but word_embedding cannot vsm_model = vsm.VsmNormer(word_alphabet, copy.deepcopy(word_embedding), embedding_dim, dict_alphabet, poses) vsm_train_X = [] vsm_train_Y = [] for doc in train_data: if isMeddra_dict: temp_X, temp_Y = vsm.generate_instances( doc.entities, word_alphabet, dict_alphabet) else: temp_X, temp_Y = vsm.generate_instances_ehr( doc.entities, word_alphabet, dict_alphabet, dictionary_reverse) 
vsm_train_X.extend(temp_X) vsm_train_Y.extend(temp_Y) vsm_train_loader = DataLoader(vsm.MyDataset(vsm_train_X, vsm_train_Y), opt.batch_size, shuffle=True, collate_fn=vsm.my_collate) vsm_optimizer = optim.Adam(vsm_model.parameters(), lr=opt.lr, weight_decay=opt.l2) if opt.tune_wordemb == False: freeze_net(vsm_model.word_embedding) if d.config['norm_vsm_pretrain'] == '1': vsm.dict_pretrain(dictionary, dictionary_reverse, d, True, vsm_optimizer, vsm_model) # neural logging.info("init neural-based normer") neural_model = norm_neural.NeuralNormer(word_alphabet, copy.deepcopy(word_embedding), embedding_dim, dict_alphabet) neural_train_X = [] neural_train_Y = [] for doc in train_data: if isMeddra_dict: temp_X, temp_Y = norm_neural.generate_instances( doc.entities, word_alphabet, dict_alphabet) else: temp_X, temp_Y = norm_neural.generate_instances_ehr( doc.entities, word_alphabet, dict_alphabet, dictionary_reverse) neural_train_X.extend(temp_X) neural_train_Y.extend(temp_Y) neural_train_loader = DataLoader(norm_neural.MyDataset( neural_train_X, neural_train_Y), opt.batch_size, shuffle=True, collate_fn=norm_neural.my_collate) neural_optimizer = optim.Adam(neural_model.parameters(), lr=opt.lr, weight_decay=opt.l2) if opt.tune_wordemb == False: freeze_net(neural_model.word_embedding) if d.config['norm_neural_pretrain'] == '1': neural_model.dict_pretrain(dictionary, dictionary_reverse, d, True, neural_optimizer, neural_model) best_dev_f = -10 best_dev_p = -10 best_dev_r = -10 bad_counter = 0 logging.info("start training ...") for idx in range(opt.iter): epoch_start = time.time() if opt.ensemble == 'learn': ensemble_model.train() ensemble_train_iter = iter(ensemble_train_loader) ensemble_num_iter = len(ensemble_train_loader) for i in range(ensemble_num_iter): x, rules, lengths, y = next(ensemble_train_iter) y_pred = ensemble_model.forward(x, rules, lengths) l = ensemble_model.loss(y_pred, y) l.backward() if opt.gradient_clip > 0: torch.nn.utils.clip_grad_norm_(ensemble_model.parameters(), opt.gradient_clip) ensemble_optimizer.step() ensemble_model.zero_grad() else: vsm_model.train() vsm_train_iter = iter(vsm_train_loader) vsm_num_iter = len(vsm_train_loader) for i in range(vsm_num_iter): x, lengths, y = next(vsm_train_iter) l, _ = vsm_model.forward_train(x, lengths, y) l.backward() if opt.gradient_clip > 0: torch.nn.utils.clip_grad_norm_(vsm_model.parameters(), opt.gradient_clip) vsm_optimizer.step() vsm_model.zero_grad() neural_model.train() neural_train_iter = iter(neural_train_loader) neural_num_iter = len(neural_train_loader) for i in range(neural_num_iter): x, lengths, y = next(neural_train_iter) y_pred = neural_model.forward(x, lengths) l = neural_model.loss(y_pred, y) l.backward() if opt.gradient_clip > 0: torch.nn.utils.clip_grad_norm_(neural_model.parameters(), opt.gradient_clip) neural_optimizer.step() neural_model.zero_grad() epoch_finish = time.time() logging.info("epoch: %s training finished. 
Time: %.2fs" % (idx, epoch_finish - epoch_start)) if opt.dev_file: if opt.ensemble == 'learn': # logging.info("weight w1: %.4f, w2: %.4f, w3: %.4f" % (ensemble_model.w1.data.item(), ensemble_model.w2.data.item(), ensemble_model.w3.data.item())) p, r, f = norm_utils.evaluate(dev_data, dictionary, dictionary_reverse, None, None, ensemble_model, d, isMeddra_dict) else: p, r, f = norm_utils.evaluate(dev_data, dictionary, dictionary_reverse, vsm_model, neural_model, None, d, isMeddra_dict) logging.info("Dev: p: %.4f, r: %.4f, f: %.4f" % (p, r, f)) else: f = best_dev_f if f > best_dev_f: logging.info("Exceed previous best f score on dev: %.4f" % (best_dev_f)) if opt.ensemble == 'learn': if fold_idx is None: torch.save(ensemble_model, os.path.join(opt.output, "ensemble.pkl")) else: torch.save( ensemble_model, os.path.join(opt.output, "ensemble_{}.pkl".format(fold_idx + 1))) else: if fold_idx is None: torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl")) torch.save(neural_model, os.path.join(opt.output, "norm_neural.pkl")) else: torch.save( vsm_model, os.path.join(opt.output, "vsm_{}.pkl".format(fold_idx + 1))) torch.save( neural_model, os.path.join(opt.output, "norm_neural_{}.pkl".format(fold_idx + 1))) best_dev_f = f best_dev_p = p best_dev_r = r bad_counter = 0 else: bad_counter += 1 if len(opt.dev_file) != 0 and bad_counter >= opt.patience: logging.info('Early Stop!') break logging.info("train finished") if fold_idx is None: multi_sieve.finalize(True) else: if fold_idx == opt.cross_validation - 1: multi_sieve.finalize(True) else: multi_sieve.finalize(False) if len(opt.dev_file) == 0: if opt.ensemble == 'learn': torch.save(ensemble_model, os.path.join(opt.output, "ensemble.pkl")) else: torch.save(vsm_model, os.path.join(opt.output, "vsm.pkl")) torch.save(neural_model, os.path.join(opt.output, "norm_neural.pkl")) return best_dev_p, best_dev_r, best_dev_f
def train(data, ner_dir, re_dir): task_num = 2 init_alpha = 1.0 / task_num if opt.hidden_num <= 1: step_alpha = 0 else: step_alpha = (1.0 - init_alpha) / (opt.hidden_num - 1) ner_wordrep = WordRep(data, False, True, True, data.use_char) ner_hiddenlist = [] for i in range(opt.hidden_num): if i == 0: input_size = data.word_emb_dim+data.HP_char_hidden_dim+data.feature_emb_dims[data.feature_name2id['[Cap]']]+ \ data.feature_emb_dims[data.feature_name2id['[POS]']] output_size = data.HP_hidden_dim else: input_size = data.HP_hidden_dim output_size = data.HP_hidden_dim temp = HiddenLayer(data, input_size, output_size) ner_hiddenlist.append(temp) seq_model = SeqModel(data) re_wordrep = WordRep(data, True, False, True, False) re_hiddenlist = [] for i in range(opt.hidden_num): if i == 0: input_size = data.word_emb_dim + data.feature_emb_dims[data.feature_name2id['[POS]']]+\ 2*data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']] output_size = data.HP_hidden_dim else: input_size = data.HP_hidden_dim output_size = data.HP_hidden_dim temp = HiddenLayer(data, input_size, output_size) re_hiddenlist.append(temp) classify_model = ClassifyModel(data) iter_parameter = itertools.chain( *map(list, [ner_wordrep.parameters(), seq_model.parameters()] + [f.parameters() for f in ner_hiddenlist])) ner_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) iter_parameter = itertools.chain( *map(list, [re_wordrep.parameters(), classify_model.parameters()] + [f.parameters() for f in re_hiddenlist])) re_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) if data.tune_wordemb == False: my_utils.freeze_net(ner_wordrep.word_embedding) my_utils.freeze_net(re_wordrep.word_embedding) re_X_positive = [] re_Y_positive = [] re_X_negative = [] re_Y_negative = [] relation_vocab = data.re_feature_alphabets[ data.re_feature_name2id['[RELATION]']] my_collate = my_utils.sorted_collate1 for i in range(len(data.re_train_X)): x = data.re_train_X[i] y = data.re_train_Y[i] if y != relation_vocab.get_index("</unk>"): re_X_positive.append(x) re_Y_positive.append(y) else: re_X_negative.append(x) re_Y_negative.append(y) re_test_loader = DataLoader(my_utils.RelationDataset( data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate) best_ner_score = -1 best_re_score = -1 for idx in range(data.HP_iteration): epoch_start = time.time() ner_wordrep.train() ner_wordrep.zero_grad() for hidden_layer in ner_hiddenlist: hidden_layer.train() hidden_layer.zero_grad() seq_model.train() seq_model.zero_grad() re_wordrep.train() re_wordrep.zero_grad() for hidden_layer in re_hiddenlist: hidden_layer.train() hidden_layer.zero_grad() classify_model.train() classify_model.zero_grad() batch_size = data.HP_batch_size random.shuffle(data.train_Ids) ner_train_num = len(data.train_Ids) ner_total_batch = ner_train_num // batch_size + 1 re_train_loader, re_train_iter = makeRelationDataset( re_X_positive, re_Y_positive, re_X_negative, re_Y_negative, data.unk_ratio, True, my_collate, data.HP_batch_size) re_total_batch = len(re_train_loader) total_batch = max(ner_total_batch, re_total_batch) min_batch = min(ner_total_batch, re_total_batch) for batch_id in range(total_batch): if batch_id < min_batch: start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > ner_train_num: end = ner_train_num instance = data.train_Ids[start:end] ner_batch_word, ner_batch_features, ner_batch_wordlen, ner_batch_wordrecover, ner_batch_char, ner_batch_charlen, \ ner_batch_charrecover, 
ner_batch_label, ner_mask, ner_batch_permute_label = batchify_with_label(instance, data.HP_gpu) [re_batch_word, re_batch_features, re_batch_wordlen, re_batch_wordrecover, re_batch_char, re_batch_charlen, re_batch_charrecover, re_position1_seq_tensor, re_position2_seq_tensor, re_e1_token, re_e1_length, re_e2_token, re_e2_length, re_e1_type, re_e2_type, re_tok_num_betw, re_et_num], [re_targets, re_targets_permute] = \ my_utils.endless_get_next_batch_without_rebatch1(re_train_loader, re_train_iter) if ner_batch_word.size(0) != re_batch_word.size(0): continue # if batch size is not equal, we ignore such batch ner_word_rep = ner_wordrep.forward( ner_batch_word, ner_batch_features, ner_batch_wordlen, ner_batch_char, ner_batch_charlen, ner_batch_charrecover, None, None) re_word_rep = re_wordrep.forward( re_batch_word, re_batch_features, re_batch_wordlen, re_batch_char, re_batch_charlen, re_batch_charrecover, re_position1_seq_tensor, re_position2_seq_tensor) alpha = init_alpha ner_hidden = ner_word_rep re_hidden = re_word_rep for i in range(opt.hidden_num): if alpha > 0.99: use_attn = False ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward( ner_hidden, ner_batch_wordlen, use_attn) re_lstm_out, re_att_out = re_hiddenlist[i].forward( re_hidden, re_batch_wordlen, use_attn) ner_hidden = ner_lstm_out re_hidden = re_lstm_out else: use_attn = True ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward( ner_hidden, ner_batch_wordlen, use_attn) re_lstm_out, re_att_out = re_hiddenlist[i].forward( re_hidden, re_batch_wordlen, use_attn) ner_hidden = alpha * ner_lstm_out + ( 1 - alpha) * re_att_out.unsqueeze(1) re_hidden = alpha * re_lstm_out + ( 1 - alpha) * ner_att_out.unsqueeze(1) alpha += step_alpha ner_loss, ner_tag_seq = seq_model.neg_log_likelihood_loss( ner_hidden, ner_batch_label, ner_mask) re_loss, re_pred = classify_model.neg_log_likelihood_loss( re_hidden, re_batch_wordlen, re_e1_token, re_e1_length, re_e2_token, re_e2_length, re_e1_type, re_e2_type, re_tok_num_betw, re_et_num, re_targets) ner_loss.backward(retain_graph=True) re_loss.backward() ner_optimizer.step() re_optimizer.step() ner_wordrep.zero_grad() for hidden_layer in ner_hiddenlist: hidden_layer.zero_grad() seq_model.zero_grad() re_wordrep.zero_grad() for hidden_layer in re_hiddenlist: hidden_layer.zero_grad() classify_model.zero_grad() else: if batch_id < ner_total_batch: start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > ner_train_num: end = ner_train_num instance = data.train_Ids[start:end] batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \ batch_permute_label = batchify_with_label(instance, data.HP_gpu) ner_word_rep = ner_wordrep.forward( batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) ner_hidden = ner_word_rep for i in range(opt.hidden_num): ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward( ner_hidden, batch_wordlen, False) ner_hidden = ner_lstm_out loss, tag_seq = seq_model.neg_log_likelihood_loss( ner_hidden, batch_label, mask) loss.backward() ner_optimizer.step() ner_wordrep.zero_grad() for hidden_layer in ner_hiddenlist: hidden_layer.zero_grad() seq_model.zero_grad() if batch_id < re_total_batch: [batch_word, batch_features, batch_wordlen, batch_wordrecover, \ batch_char, batch_charlen, batch_charrecover, \ position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \ tok_num_betw, et_num], [targets, targets_permute] = 
my_utils.endless_get_next_batch_without_rebatch1( re_train_loader, re_train_iter) re_word_rep = re_wordrep.forward( batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor) re_hidden = re_word_rep for i in range(opt.hidden_num): re_lstm_out, re_att_out = re_hiddenlist[i].forward( re_hidden, batch_wordlen, False) re_hidden = re_lstm_out loss, pred = classify_model.neg_log_likelihood_loss( re_hidden, batch_wordlen, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets) loss.backward() re_optimizer.step() re_wordrep.zero_grad() for hidden_layer in re_hiddenlist: hidden_layer.zero_grad() classify_model.zero_grad() epoch_finish = time.time() print("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start)) ner_score = ner_evaluate(data, ner_wordrep, ner_hiddenlist, seq_model, "test") print("ner evaluate: f: %.4f" % (ner_score)) re_score = re_evaluate(re_wordrep, re_hiddenlist, classify_model, re_test_loader) print("re evaluate: f: %.4f" % (re_score)) if ner_score + re_score > best_ner_score + best_re_score: print("new best score: ner: %.4f , re: %.4f" % (ner_score, re_score)) best_ner_score = ner_score best_re_score = re_score torch.save(ner_wordrep.state_dict(), os.path.join(ner_dir, 'wordrep.pkl')) for i, hidden_layer in enumerate(ner_hiddenlist): torch.save(hidden_layer.state_dict(), os.path.join(ner_dir, 'hidden_{}.pkl'.format(i))) torch.save(seq_model.state_dict(), os.path.join(ner_dir, 'model.pkl')) torch.save(re_wordrep.state_dict(), os.path.join(re_dir, 'wordrep.pkl')) for i, hidden_layer in enumerate(re_hiddenlist): torch.save(hidden_layer.state_dict(), os.path.join(re_dir, 'hidden_{}.pkl'.format(i))) torch.save(classify_model.state_dict(), os.path.join(re_dir, 'model.pkl'))
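# ---------------------------------------------------------------------------
# Minimal sketch of the layer-wise sharing schedule used in the multi-task
# train() above: with task_num = 2, alpha starts at 0.5 and grows linearly to
# 1.0 over opt.hidden_num stacked layers, so lower layers mix each task's LSTM
# output with the other task's attention vector and the top layer is purely
# task-specific. The tensors below are toy stand-ins for ner_lstm_out and
# re_att_out; only the schedule and mixing formula mirror the code above.
import torch

hidden_num = 3
task_num = 2
init_alpha = 1.0 / task_num
step_alpha = 0 if hidden_num <= 1 else (1.0 - init_alpha) / (hidden_num - 1)

ner_hidden = torch.randn(4, 10, 8)   # (batch, seq_len, dim), hypothetical
re_att = torch.randn(4, 8)           # other task's attention vector, hypothetical

alpha = init_alpha
for layer in range(hidden_num):
    if alpha > 0.99:
        # top layer: no cross-task mixing
        mixed = ner_hidden
    else:
        mixed = alpha * ner_hidden + (1 - alpha) * re_att.unsqueeze(1)
    ner_hidden = mixed
    print("layer %d alpha %.2f" % (layer, alpha))
    alpha += step_alpha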
def pipeline(data, ner_dir, re_dir): seq_model = SeqModel(data) seq_wordseq = WordSequence(data, False, True, True, data.use_char) classify_wordseq = WordSequence(data, True, False, True, False) classify_model = ClassifyModel(data) if torch.cuda.is_available(): classify_model = classify_model.cuda(data.HP_gpu) iter_parameter = itertools.chain( *map(list, [seq_wordseq.parameters(), seq_model.parameters()])) seq_optimizer = optim.Adam(iter_parameter, lr=opt.ner_lr, weight_decay=data.HP_l2) iter_parameter = itertools.chain(*map( list, [classify_wordseq.parameters(), classify_model.parameters()])) classify_optimizer = optim.Adam(iter_parameter, lr=opt.re_lr, weight_decay=data.HP_l2) if data.tune_wordemb == False: my_utils.freeze_net(seq_wordseq.wordrep.word_embedding) my_utils.freeze_net(classify_wordseq.wordrep.word_embedding) re_X_positive = [] re_Y_positive = [] re_X_negative = [] re_Y_negative = [] relation_vocab = data.re_feature_alphabets[ data.re_feature_name2id['[RELATION]']] my_collate = my_utils.sorted_collate1 for i in range(len(data.re_train_X)): x = data.re_train_X[i] y = data.re_train_Y[i] if y != relation_vocab.get_index("</unk>"): re_X_positive.append(x) re_Y_positive.append(y) else: re_X_negative.append(x) re_Y_negative.append(y) re_test_loader = DataLoader(my_utils.RelationDataset( data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate) best_ner_score = -1 best_re_score = -1 for idx in range(data.HP_iteration): epoch_start = time.time() seq_wordseq.train() seq_wordseq.zero_grad() seq_model.train() seq_model.zero_grad() classify_wordseq.train() classify_wordseq.zero_grad() classify_model.train() classify_model.zero_grad() batch_size = data.HP_batch_size random.shuffle(data.train_Ids) ner_train_num = len(data.train_Ids) ner_total_batch = ner_train_num // batch_size + 1 re_train_loader, re_train_iter = makeRelationDataset( re_X_positive, re_Y_positive, re_X_negative, re_Y_negative, data.unk_ratio, True, my_collate, data.HP_batch_size) re_total_batch = len(re_train_loader) total_batch = max(ner_total_batch, re_total_batch) min_batch = min(ner_total_batch, re_total_batch) for batch_id in range(total_batch): if batch_id < ner_total_batch: start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > ner_train_num: end = ner_train_num instance = data.train_Ids[start:end] batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \ batch_permute_label = batchify_with_label(instance, data.HP_gpu) hidden = seq_wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) hidden_adv = None loss, tag_seq = seq_model.neg_log_likelihood_loss( hidden, hidden_adv, batch_label, mask) loss.backward() seq_optimizer.step() seq_wordseq.zero_grad() seq_model.zero_grad() if batch_id < re_total_batch: [batch_word, batch_features, batch_wordlen, batch_wordrecover, \ batch_char, batch_charlen, batch_charrecover, \ position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \ tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1( re_train_loader, re_train_iter) hidden = classify_wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, position1_seq_tensor, position2_seq_tensor) hidden_adv = None loss, pred = classify_model.neg_log_likelihood_loss( hidden, hidden_adv, batch_wordlen, e1_token, 
e1_length, e2_token, e2_length, e1_type, e2_type, tok_num_betw, et_num, targets) loss.backward() classify_optimizer.step() classify_wordseq.zero_grad() classify_model.zero_grad() epoch_finish = time.time() print("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start)) # _, _, _, _, f, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test") ner_score = ner.evaluate1(data, seq_wordseq, seq_model, "test") print("ner evaluate: f: %.4f" % (ner_score)) re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_test_loader) print("re evaluate: f: %.4f" % (re_score)) if ner_score + re_score > best_ner_score + best_re_score: print("new best score: ner: %.4f , re: %.4f" % (ner_score, re_score)) best_ner_score = ner_score best_re_score = re_score torch.save(seq_wordseq.state_dict(), os.path.join(ner_dir, 'wordseq.pkl')) torch.save(seq_model.state_dict(), os.path.join(ner_dir, 'model.pkl')) torch.save(classify_wordseq.state_dict(), os.path.join(re_dir, 'wordseq.pkl')) torch.save(classify_model.state_dict(), os.path.join(re_dir, 'model.pkl'))
def train(data):
    print "Training model..."
    data.show_data_summary()
    save_data_name = data.model_dir + ".dset"
    data.save(save_data_name)
    model = SeqModel(data)
    loss_function = nn.NLLLoss()
    if data.optimizer.lower() == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adagrad":
        optimizer = optim.Adagrad(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adadelta":
        optimizer = optim.Adadelta(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "rmsprop":
        optimizer = optim.RMSprop(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    elif data.optimizer.lower() == "adam":
        optimizer = optim.Adam(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2)
    else:
        print("Optimizer illegal: %s" % (data.optimizer))
        exit(0)
    best_dev = -10

    if opt.tune_wordemb == False:
        my_utils.freeze_net(model.word_hidden.wordrep.word_embedding)

    # data.HP_iteration = 1
    ## start training
    for idx in range(data.HP_iteration):
        epoch_start = time.time()
        temp_start = epoch_start
        print("Epoch: %s/%s" % (idx, data.HP_iteration))
        if data.optimizer == "SGD":
            optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr)
        instance_count = 0
        sample_id = 0
        sample_loss = 0
        total_loss = 0
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        ## set model in train model
        model.train()
        model.zero_grad()
        batch_size = data.HP_batch_size
        batch_id = 0
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1
        for batch_id in range(total_batch):
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue
            batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu)
            instance_count += 1
            loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
            right, whole = predict_check(tag_seq, batch_label, mask)
            right_token += right
            whole_token += whole
            sample_loss += loss.data.item()
            total_loss += loss.data.item()
            if end % 500 == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
                if sample_loss > 1e8 or str(sample_loss) == "nan":
                    print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
                    exit(0)
                sys.stdout.flush()
                sample_loss = 0
            loss.backward()
            optimizer.step()
            model.zero_grad()
        temp_time = time.time()
        temp_cost = temp_time - temp_start
        print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token))
        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" % (idx, epoch_cost, train_num / epoch_cost, total_loss))
        print "totalloss:", total_loss
        if total_loss > 1e8 or str(total_loss) == "nan":
            print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...."
            exit(0)
        # continue
        speed, acc, p, r, f, _, _ = evaluate(data, model, "dev")
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish
        if data.seg:
            current_score = f
            print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f))
        else:
            current_score = acc
            print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc))
        if current_score > best_dev:
            if data.seg:
                print "Exceed previous best f score:", best_dev
            else:
                print "Exceed previous best acc score:", best_dev
            model_name = data.model_dir + ".model"
            torch.save(model.state_dict(), model_name)
            best_dev = current_score
        gc.collect()
def train(data, model_file): print "Training model..." model = SeqModel(data) wordseq = WordSequence(data, False, True, data.use_char) if opt.self_adv == 'grad': wordseq_adv = WordSequence(data, False, True, data.use_char) elif opt.self_adv == 'label': wordseq_adv = WordSequence(data, False, True, data.use_char) model_adv = SeqModel(data) else: wordseq_adv = None if data.optimizer.lower() == "sgd": optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum, weight_decay=data.HP_l2) elif data.optimizer.lower() == "adagrad": optimizer = optim.Adagrad(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2) elif data.optimizer.lower() == "adadelta": optimizer = optim.Adadelta(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2) elif data.optimizer.lower() == "rmsprop": optimizer = optim.RMSprop(model.parameters(), lr=data.HP_lr, weight_decay=data.HP_l2) elif data.optimizer.lower() == "adam": if opt.self_adv == 'grad': iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()])) optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) elif opt.self_adv == 'label': iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()])) optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()])) optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) else: iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()])) optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2) else: print("Optimizer illegal: %s" % (data.optimizer)) exit(0) best_dev = -10 if data.tune_wordemb == False: my_utils.freeze_net(wordseq.wordrep.word_embedding) if opt.self_adv != 'no': my_utils.freeze_net(wordseq_adv.wordrep.word_embedding) # data.HP_iteration = 1 ## start training for idx in range(data.HP_iteration): epoch_start = time.time() temp_start = epoch_start print("epoch: %s/%s" % (idx, data.HP_iteration)) if data.optimizer == "SGD": optimizer = lr_decay(optimizer, idx, data.HP_lr_decay, data.HP_lr) instance_count = 0 sample_id = 0 sample_loss = 0 total_loss = 0 right_token = 0 whole_token = 0 random.shuffle(data.train_Ids) ## set model in train model wordseq.train() wordseq.zero_grad() if opt.self_adv == 'grad': wordseq_adv.train() wordseq_adv.zero_grad() elif opt.self_adv == 'label': wordseq_adv.train() wordseq_adv.zero_grad() model_adv.train() model_adv.zero_grad() model.train() model.zero_grad() batch_size = data.HP_batch_size batch_id = 0 train_num = len(data.train_Ids) total_batch = train_num // batch_size + 1 for batch_id in range(total_batch): start = batch_id * batch_size end = (batch_id + 1) * batch_size if end > train_num: end = train_num instance = data.train_Ids[start:end] if not instance: continue batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask,\ batch_permute_label = batchify_with_label(instance, data.HP_gpu) instance_count += 1 if opt.self_adv == 'grad': hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None) hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask) loss.backward() 
my_utils.reverse_grad(wordseq_adv) optimizer.step() wordseq.zero_grad() wordseq_adv.zero_grad() model.zero_grad() elif opt.self_adv == 'label' : wordseq.unfreeze_net() wordseq_adv.freeze_net() hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None) hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask) loss.backward() optimizer.step() wordseq.zero_grad() wordseq_adv.zero_grad() model.zero_grad() wordseq.freeze_net() wordseq_adv.unfreeze_net() hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen,batch_charrecover, None, None) hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_permute_label, mask) loss_adv.backward() optimizer_adv.step() wordseq.zero_grad() wordseq_adv.zero_grad() model_adv.zero_grad() else: hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, None, None) hidden_adv = None loss, tag_seq = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask) loss.backward() optimizer.step() wordseq.zero_grad() model.zero_grad() # right, whole = predict_check(tag_seq, batch_label, mask) # right_token += right # whole_token += whole sample_loss += loss.data.item() total_loss += loss.data.item() if end % 500 == 0: # temp_time = time.time() # temp_cost = temp_time - temp_start # temp_start = temp_time # print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % ( # end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)) if sample_loss > 1e8 or str(sample_loss) == "nan": print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...." exit(0) sys.stdout.flush() sample_loss = 0 # temp_time = time.time() # temp_cost = temp_time - temp_start # print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" % ( # end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)) epoch_finish = time.time() epoch_cost = epoch_finish - epoch_start print("epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s" % ( idx, epoch_cost, train_num / epoch_cost, total_loss)) print "totalloss:", total_loss if total_loss > 1e8 or str(total_loss) == "nan": print "ERROR: LOSS EXPLOSION (>1e8) ! PLEASE SET PROPER PARAMETERS AND STRUCTURE! EXIT...." 
exit(0) # continue speed, acc, p, r, f, _, _ = evaluate(data, wordseq, model, "test") dev_finish = time.time() dev_cost = dev_finish - epoch_finish if data.seg: current_score = f print("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % ( dev_cost, speed, acc, p, r, f)) else: current_score = acc print("Dev: time: %.2fs speed: %.2fst/s; acc: %.4f" % (dev_cost, speed, acc)) if current_score > best_dev: if data.seg: print "Exceed previous best f score:", best_dev else: print "Exceed previous best acc score:", best_dev torch.save(wordseq.state_dict(), os.path.join(model_file, 'wordseq.pkl')) if opt.self_adv == 'grad': torch.save(wordseq_adv.state_dict(), os.path.join(model_file, 'wordseq_adv.pkl')) elif opt.self_adv == 'label': torch.save(wordseq_adv.state_dict(), os.path.join(model_file, 'wordseq_adv.pkl')) torch.save(model_adv.state_dict(), os.path.join(model_file, 'model_adv.pkl')) model_name = os.path.join(model_file, 'model.pkl') torch.save(model.state_dict(), model_name) best_dev = current_score gc.collect()
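# Note: the SGD branch above calls an lr_decay helper that is not defined in this
# section. The sketch below is a hypothetical implementation, assuming the common
# init_lr / (1 + decay_rate * epoch) schedule; the project's actual helper may differ.
def lr_decay(optimizer, epoch, decay_rate, init_lr):
    # decay the learning rate and write it back into every parameter group
    lr = init_lr / (1 + decay_rate * epoch)
    print("learning rate is set to: %.6f" % lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer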
# batch size always 1
train_loader = DataLoader(MyDataset(instances_train), 1, shuffle=True, collate_fn=my_collate)
test_loader = DataLoader(MyDataset(instances_test), 1, shuffle=False, collate_fn=my_collate)

optimizer = optim.Adam(vsm_model.parameters(), lr=opt.lr, weight_decay=opt.l2)
if opt.tune_wordemb == False:
    freeze_net(vsm_model.word_embedding)

best_acc = -10
bad_counter = 0

logging.info("start training ...")
for idx in range(opt.iter):
    epoch_start = time.time()
    vsm_model.train()
    train_iter = iter(train_loader)
    num_iter = len(train_loader)
def train(other_dir):
    logging.info("loading ... vocab")
    word_vocab = pickle.load(open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))

    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    # test only on the main domain
    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                             opt.batch_size, shuffle=False, collate_fn=my_collate)  # drop_last=True

    # train on the main as well as other domains
    domains = ['main']
    train_loaders, train_iters, unk_loaders, unk_iters = {}, {}, {}, {}
    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'), 'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'), 'rb'))
    logging.info("total training instance {}".format(len(train_Y)))
    train_loaders['main'], train_iters['main'] = makeDatasetWithoutUnknown(train_X, train_Y, relation_vocab, True, my_collate)
    unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio)

    for other in other_dir:
        domains.append(other)
        other_X = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_X.pkl'.format(other)), 'rb'))
        other_Y = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_Y.pkl'.format(other)), 'rb'))
        logging.info("other {} instance {}".format(other, len(other_Y)))
        train_loaders[other], train_iters[other] = makeDatasetWithoutUnknown(other_X, other_Y, relation_vocab, True, my_collate)
        unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio)

    opt.domains = domains

    # shared feature extractor F_s and one private extractor F_d per domain
    F_s = None
    F_d = {}
    if opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                   opt.F_layers, opt.shared_hidden_size, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                               opt.F_layers, opt.shared_hidden_size, opt.dropout)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                  opt.F_layers, opt.shared_hidden_size, opt.kernel_num, opt.kernel_sizes, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                              opt.F_layers, opt.shared_hidden_size, opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))

    if opt.model_high == 'capsule':
        C = capsule.CapsuleNet(2 * opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                               entity_vocab, tok_num_betw_vocab, et_num_vocab)
    elif opt.model_high == 'mlp':
        C = baseline.MLP(2 * opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                         entity_vocab, tok_num_betw_vocab, et_num_vocab)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))

    if opt.adv:
        D = DomainClassifier(1, opt.shared_hidden_size, opt.shared_hidden_size,
                             len(opt.domains), opt.loss, opt.dropout, True)

    if torch.cuda.is_available():
        F_s, C = F_s.cuda(opt.gpu), C.cuda(opt.gpu)
        for f_d in F_d.values():
            f_d.cuda(opt.gpu)
        if opt.adv:
            D = D.cuda(opt.gpu)

    # materialize the parameter list so it can be reused for gradient clipping below
    # (a bare itertools.chain would be exhausted after building the optimizer)
    iter_parameter = list(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] + [f.parameters() for f in F_d.values()])))
    optimizer = optim.Adam(iter_parameter, lr=opt.learning_rate)
    if opt.adv:
        optimizerD = optim.Adam(D.parameters(), lr=opt.learning_rate)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        for f in F_d.values():
            f.train()
        if opt.adv:
            D.train()

        # domain accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        if opt.adv:
            d_correct, d_total = 0, 0

        # conceptually view 1 epoch as 1 epoch of the main domain
        num_iter = len(train_loaders['main'])
        for i in tqdm(range(num_iter)):
            if opt.adv:
                # D iterations
                my_utils.freeze_net(F_s)
                for f_d in F_d.values():  # explicit loop: map() is lazy in Python 3
                    my_utils.freeze_net(f_d)
                my_utils.freeze_net(C)
                my_utils.unfreeze_net(D)
                if opt.tune_wordemb == False:
                    my_utils.freeze_net(F_s.word_emb)
                    for f_d in F_d.values():
                        my_utils.freeze_net(f_d.word_emb)
                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                        n_critic = 100
                for _ in range(n_critic):
                    D.zero_grad()
                    # train on both labeled and unlabeled domains
                    for domain in opt.domains:
                        # targets not used
                        x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain], train_iters[domain])
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        shared_feat = F_s(x2, x1)
                        d_outputs = D(shared_feat)
                        # D accuracy
                        _, pred = torch.max(d_outputs, 1)
                        d_total += len(x2[1])
                        if opt.loss.lower() == 'l2':
                            _, tgt_indices = torch.max(d_targets, 1)
                            d_correct += (pred == tgt_indices).sum().data.item()
                            l_d = functional.mse_loss(d_outputs, d_targets)
                            l_d.backward()
                        else:
                            d_correct += (pred == d_targets).sum().data.item()
                            l_d = functional.nll_loss(d_outputs, d_targets)
                            l_d.backward()
                    optimizerD.step()

            # F&C iteration
            my_utils.unfreeze_net(F_s)
            for f_d in F_d.values():  # explicit loop: map() is lazy in Python 3
                my_utils.unfreeze_net(f_d)
            my_utils.unfreeze_net(C)
            if opt.adv:
                my_utils.freeze_net(D)
            if opt.tune_wordemb == False:
                my_utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()

            for domain in opt.domains:
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain], train_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

                # training with unknown
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(unk_loaders[domain], unk_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

            if opt.adv:
                # update F with D gradients on all domains
                for domain in opt.domains:
                    x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(train_loaders[domain], train_iters[domain])
                    shared_feat = F_s(x2, x1)
                    d_outputs = D(shared_feat)
                    if opt.loss.lower() == 'gr':
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= -opt.lambd
                    elif opt.loss.lower() == 'bs':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    elif opt.loss.lower() == 'l2':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    l_d.backward()

            torch.nn.utils.clip_grad_norm_(iter_parameter, opt.grad_clip)
            optimizer.step()

        # regenerate the "unknown" dataset after each epoch
        unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio)
        for other in other_dir:
            unk_loaders[other], unk_iters[other] = makeDatasetUnknown(other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio)

        logging.info('epoch {} end'.format(epoch))
        if opt.adv and d_total > 0:
            logging.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        logging.info('Training accuracy:')
        logging.info('\t'.join(opt.domains))
        logging.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))

        test_accuracy = evaluate(F_s, F_d['main'], C, test_loader, test_Other)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(F_s.state_dict(), '{}/F_s.pth'.format(opt.output))
            for d in opt.domains:
                torch.save(F_d[d].state_dict(), '{}/F_d_{}.pth'.format(opt.output, d))
            torch.save(C.state_dict(), '{}/C.pth'.format(opt.output))
            if opt.adv:
                torch.save(D.state_dict(), '{}/D.pth'.format(opt.output))
            pickle.dump(test_Other, open(os.path.join(opt.output, 'results.pkl'), "wb"), True)
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")
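# Note: freeze_net / unfreeze_net are used throughout the functions above but are not
# shown in this section. A minimal sketch, assuming they simply toggle requires_grad on
# a module's parameters (the project's actual helpers may differ):
def freeze_net(net):
    # stop gradients from flowing into `net`
    if net is None:
        return
    for p in net.parameters():
        p.requires_grad = False


def unfreeze_net(net):
    # re-enable gradient updates for `net`
    if net is None:
        return
    for p in net.parameters():
        p.requires_grad = True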