def makeRelationDataset(re_X_positive, re_Y_positive, re_X_negative, re_Y_negative, ratio, b_shuffle, my_collate, batch_size):
    """Build a DataLoader over all positive relation instances plus a random
    subsample of the negative instances.

    Args:
        re_X_positive, re_Y_positive: positive instances and their labels.
        re_X_negative, re_Y_negative: negative instances and their labels.
        ratio: fraction of the negatives to keep (0.0-1.0).
        b_shuffle: whether the DataLoader shuffles each epoch.
        my_collate: collate function for the DataLoader.
        batch_size: DataLoader batch size.

    Returns:
        (data_loader, iterator over data_loader).
    """
    # Sample int(len * ratio) negative indices without replacement.
    n_keep = int(len(re_X_negative) * ratio)
    indices = random.sample(range(len(re_X_negative)), n_keep)

    # All positives followed by the sampled negatives.
    temp_X = list(re_X_positive) + [re_X_negative[i] for i in indices]
    temp_Y = list(re_Y_positive) + [re_Y_negative[i] for i in indices]

    data_set = my_utils.RelationDataset(temp_X, temp_Y)
    data_loader = DataLoader(data_set, batch_size, shuffle=b_shuffle, collate_fn=my_collate)
    return data_loader, iter(data_loader)
def makeDatasetWithoutUnknown(test_X, test_Y, relation_vocab, b_shuffle, my_collate, batch_size):
    """Build a DataLoader over the instances whose label is NOT the </unk> relation.

    Args:
        test_X, test_Y: instances and their relation-label indices.
        relation_vocab: vocabulary used to resolve the </unk> label index.
        b_shuffle: whether the DataLoader shuffles each epoch.
        my_collate: collate function for the DataLoader.
        batch_size: DataLoader batch size.

    Returns:
        (test_loader, iterator over test_loader).
    """
    # Hoist the loop-invariant vocabulary lookup out of the filter.
    unk_idx = relation_vocab.get_index("</unk>")
    test_X_remove_unk = [x for x, y in zip(test_X, test_Y) if y != unk_idx]
    test_Y_remove_unk = [y for y in test_Y if y != unk_idx]

    test_set = my_utils.RelationDataset(test_X_remove_unk, test_Y_remove_unk)
    test_loader = DataLoader(test_set, batch_size, shuffle=b_shuffle, collate_fn=my_collate)
    logging.info("instance after removing unknown, {}".format(len(test_Y_remove_unk)))
    return test_loader, iter(test_loader)
def makeDatasetUnknown(test_X, test_Y, relation_vocab, my_collate, ratio, batch_size):
    """Build a DataLoader over ONLY the </unk>-labeled instances, sampled via
    `randomSampler` at the given ratio.

    Args:
        test_X, test_Y: instances and their relation-label indices.
        relation_vocab: vocabulary used to resolve the </unk> label index.
        my_collate: collate function for the DataLoader.
        ratio: sampling ratio forwarded to randomSampler.
        batch_size: DataLoader batch size.

    Returns:
        (test_loader, iterator over test_loader).
    """
    # Hoist the loop-invariant vocabulary lookup out of the filter.
    unk_idx = relation_vocab.get_index("</unk>")
    test_X_remove_unk = [x for x, y in zip(test_X, test_Y) if y == unk_idx]
    test_Y_remove_unk = [y for y in test_Y if y == unk_idx]

    test_set = my_utils.RelationDataset(test_X_remove_unk, test_Y_remove_unk)
    # shuffle must be False when a custom sampler is supplied.
    test_loader = DataLoader(test_set, batch_size, shuffle=False,
                             sampler=randomSampler(test_Y_remove_unk, ratio),
                             collate_fn=my_collate)
    return test_loader, iter(test_loader)
def makeDatasetForEachClass(train_X, train_Y, relation_vocab, my_collate):
    """Group training instances by relation class and build one randomly-sampled,
    batch-size-1 DataLoader per class.

    Args:
        train_X, train_Y: instances and their relation-label indices.
        relation_vocab: vocabulary mapping label index -> class-name string.
        my_collate: collate function for each per-class DataLoader.

    Returns:
        (train_loaders, train_iters, train_numbers) where train_numbers is a
        list of (class_name, instance_count) tuples, aligned with the loaders.
    """
    # Bucket instances by class name.
    train_X_classified = {}
    train_Y_classified = {}
    for x, y in zip(train_X, train_Y):
        class_name = relation_vocab.lookup_id2str(y)
        train_X_classified.setdefault(class_name, []).append(x)
        train_Y_classified.setdefault(class_name, []).append(y)

    train_loaders = []
    train_iters = []
    train_numbers = []
    for class_name, ys in train_Y_classified.items():
        xs = train_X_classified[class_name]
        train_numbers.append((class_name, len(ys)))
        train_set = my_utils.RelationDataset(xs, ys)
        # RandomSampler with shuffle=False: sampling order is randomized by the
        # sampler itself (DataLoader forbids shuffle=True with a sampler).
        train_sampler = torch.utils.data.sampler.RandomSampler(train_set)
        train_loader = DataLoader(train_set, 1, shuffle=False,
                                  sampler=train_sampler, collate_fn=my_collate)
        train_loaders.append(train_loader)
        train_iters.append(iter(train_loader))

    return train_loaders, train_iters, train_numbers
def train(data, dir):
    """Train the relation classifier, optionally with a self-adversarial branch,
    and save the best checkpoint (by test accuracy) into *dir*.

    Args:
        data: project data/config holder (re_train_X/Y, re_test_X/Y, alphabets,
              and hyper-parameters such as HP_batch_size, HP_lr, max_epoch).
        dir: output directory for the saved .pkl state dicts.

    Side effects: writes checkpoints into *dir* and logs progress.
    NOTE(review): depends on the module-level `opt` (opt.self_adv expected in
    {'no', 'grad', 'label'}) — confirm it is configured before calling.
    NOTE(review): parameter name `dir` shadows the builtin.
    """
    my_collate = my_utils.sorted_collate1
    # Main loader: training instances whose relation label is not </unk>.
    train_loader, train_iter = makeDatasetWithoutUnknown(
        data.re_train_X, data.re_train_Y,
        data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
        True, my_collate, data.HP_batch_size)
    num_iter = len(train_loader)
    # Secondary loader: a random subsample of the </unk> (negative) instances.
    unk_loader, unk_iter = makeDatasetUnknown(
        data.re_train_X, data.re_train_Y,
        data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
        my_collate, data.unk_ratio, data.HP_batch_size)
    test_loader = DataLoader(my_utils.RelationDataset(data.re_test_X, data.re_test_Y),
                             data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    wordseq = WordSequence(data, True, False, False)
    model = ClassifyModel(data)
    if torch.cuda.is_available():
        model = model.cuda(data.HP_gpu)

    # Optional adversarial twin of the sequence encoder (and, for 'label',
    # an adversarial classifier trained on permuted labels).
    if opt.self_adv == 'grad':
        wordseq_adv = WordSequence(data, True, False, False)
    elif opt.self_adv == 'label':
        wordseq_adv = WordSequence(data, True, False, False)
        model_adv = ClassifyModel(data)
        if torch.cuda.is_available():
            model_adv = model_adv.cuda(data.HP_gpu)
    else:
        wordseq_adv = None

    if opt.self_adv == 'grad':
        # Single optimizer over both encoders and the classifier.
        iter_parameter = itertools.chain(
            *map(list, [wordseq.parameters(), wordseq_adv.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    elif opt.self_adv == 'label':
        # Separate optimizers for the main task and the adversarial task.
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
        iter_parameter = itertools.chain(*map(list, [wordseq_adv.parameters(), model_adv.parameters()]))
        optimizer_adv = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    else:
        iter_parameter = itertools.chain(*map(list, [wordseq.parameters(), model.parameters()]))
        optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        # Keep pretrained word embeddings fixed.
        my_utils.freeze_net(wordseq.wordrep.word_embedding)
        if opt.self_adv != 'no':
            my_utils.freeze_net(wordseq_adv.wordrep.word_embedding)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(data.max_epoch):
        wordseq.train()
        wordseq.zero_grad()
        model.train()
        model.zero_grad()
        if opt.self_adv == 'grad':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
        elif opt.self_adv == 'label':
            wordseq_adv.train()
            wordseq_adv.zero_grad()
            model_adv.train()
            model_adv.zero_grad()

        correct, total = 0, 0
        for i in range(num_iter):
            # Cycle one batch from the non-unknown training loader.
            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
             batch_char, batch_charlen, batch_charrecover, \
             position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
             tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(train_loader, train_iter)

            if opt.self_adv == 'grad':
                # Gradient-reversal style: run both encoders, backprop once,
                # then flip the adversarial encoder's gradients before stepping.
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                         batch_charlen, batch_charrecover,
                                         position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                                 batch_charlen, batch_charrecover,
                                                 position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length,
                                                           e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                my_utils.reverse_grad(wordseq_adv)
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()
            elif opt.self_adv == 'label':
                # Phase 1: update the main encoder/classifier with the
                # adversarial encoder frozen.
                wordseq.unfreeze_net()
                wordseq_adv.freeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                         batch_charlen, batch_charrecover,
                                         position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                                 batch_charlen, batch_charrecover,
                                                 position1_seq_tensor, position2_seq_tensor)
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length,
                                                           e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model.zero_grad()

                # Phase 2: update the adversarial encoder/classifier on
                # permuted labels with the main encoder frozen.
                wordseq.freeze_net()
                wordseq_adv.unfreeze_net()
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                         batch_charlen, batch_charrecover,
                                         position1_seq_tensor, position2_seq_tensor)
                hidden_adv = wordseq_adv.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                                 batch_charlen, batch_charrecover,
                                                 position1_seq_tensor, position2_seq_tensor)
                loss_adv, _ = model_adv.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                                e1_token, e1_length, e2_token, e2_length,
                                                                e1_type, e2_type, tok_num_betw, et_num,
                                                                targets_permute)
                loss_adv.backward()
                optimizer_adv.step()
                wordseq.zero_grad()
                wordseq_adv.zero_grad()
                model_adv.zero_grad()
            else:
                # Plain (non-adversarial) training step.
                hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                         batch_charlen, batch_charrecover,
                                         position1_seq_tensor, position2_seq_tensor)
                hidden_adv = None
                loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                           e1_token, e1_length, e2_token, e2_length,
                                                           e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                optimizer.step()
                wordseq.zero_grad()
                model.zero_grad()

            # Accuracy is accumulated over the non-unknown batches only.
            total += targets.size(0)
            correct += (pred == targets).sum().item()

            # Extra step on one batch of </unk> instances so the model also
            # learns to recognize the "no relation" class.
            [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
             batch_char, batch_charlen, batch_charrecover, \
             position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
             tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(unk_loader, unk_iter)
            hidden = wordseq.forward(batch_word, batch_features, batch_wordlen, batch_char,
                                     batch_charlen, batch_charrecover,
                                     position1_seq_tensor, position2_seq_tensor)
            hidden_adv = None
            loss, pred = model.neg_log_likelihood_loss(hidden, hidden_adv, batch_wordlen,
                                                       e1_token, e1_length, e2_token, e2_length,
                                                       e1_type, e2_type, tok_num_betw, et_num, targets)
            loss.backward()
            optimizer.step()
            wordseq.zero_grad()
            model.zero_grad()

        # Re-sample the </unk> subset each epoch so different negatives are seen.
        unk_loader, unk_iter = makeDatasetUnknown(
            data.re_train_X, data.re_train_Y,
            data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']],
            my_collate, data.unk_ratio, data.HP_batch_size)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))
        test_accuracy = evaluate(wordseq, model, test_loader)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        # Checkpoint only when test accuracy improves.
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(wordseq.state_dict(), os.path.join(dir, 'wordseq.pkl'))
            torch.save(model.state_dict(), '{}/model.pkl'.format(dir))
            if opt.self_adv == 'grad':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
            elif opt.self_adv == 'label':
                torch.save(wordseq_adv.state_dict(), os.path.join(dir, 'wordseq_adv.pkl'))
                torch.save(model_adv.state_dict(), os.path.join(dir, 'model_adv.pkl'))
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")
def joint_train(data, old_data, opt):
    """Jointly train the NER tagger and the relation classifier, optionally
    warm-starting both from a pretrained model directory.

    Args:
        data: current data/config holder (alphabets, train/dev sets, hyper-params).
        old_data: data/config holder matching the pretrained models' vocabularies.
        opt: options object (output dir, pretrained_model_dir, test_in_cpu, ...).

    Side effects: writes the best NER and RE checkpoints into opt.output and
    logs progress. Early-stops after 2 * data.patience epochs without
    improvement on either task.
    """
    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    if opt.pretrained_model_dir != 'None':
        # --- Warm start: load pretrained NER model and check compatibility. ---
        seq_model = SeqModel(data)
        if opt.test_in_cpu:
            seq_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 'ner_model.pkl'), map_location='cpu'))
        else:
            seq_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 'ner_model.pkl')))
        if (data.label_alphabet_size != seq_model.crf.tagset_size)\
                or (data.HP_hidden_dim != seq_model.hidden2tag.weight.size(1)):
            raise RuntimeError("ner_model not compatible")

        seq_wordseq = WordSequence(data, False, True, True, True)
        if ((data.word_emb_dim != seq_wordseq.wordrep.word_embedding.embedding_dim)\
                or (data.char_emb_dim != seq_wordseq.wordrep.char_feature.char_embeddings.embedding_dim)\
                or (data.feature_emb_dims[0] != seq_wordseq.wordrep.feature_embedding_dims[0])\
                or (data.feature_emb_dims[1] != seq_wordseq.wordrep.feature_embedding_dims[1])):
            raise RuntimeError("ner_wordseq not compatible")
        old_seq_wordseq = WordSequence(old_data, False, True, True, True)
        if opt.test_in_cpu:
            old_seq_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl'), map_location='cpu'))
        else:
            old_seq_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 'ner_wordseq.pkl')))

        # Copy LSTM weights wholesale from the pretrained NER encoder.
        for old, new in zip(old_seq_wordseq.lstm.parameters(), seq_wordseq.lstm.parameters()):
            new.data.copy_(old)
        # Transplant embedding rows for the old vocabulary; rows beyond the old
        # vocab size keep their fresh initialization.
        vocab_size = old_seq_wordseq.wordrep.word_embedding.num_embeddings
        seq_wordseq.wordrep.word_embedding.weight.data[
            0:vocab_size, :] = old_seq_wordseq.wordrep.word_embedding.weight.data[0:vocab_size, :]
        vocab_size = old_seq_wordseq.wordrep.char_feature.char_embeddings.num_embeddings
        seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[
            0:vocab_size, :] = old_seq_wordseq.wordrep.char_feature.char_embeddings.weight.data[0:vocab_size, :]
        for i, feature_embedding in enumerate(old_seq_wordseq.wordrep.feature_embeddings):
            vocab_size = feature_embedding.num_embeddings
            seq_wordseq.wordrep.feature_embeddings[i].weight.data[
                0:vocab_size, :] = feature_embedding.weight.data[0:vocab_size, :]

        # --- Warm start: RE encoder. ---
        classify_wordseq = WordSequence(data, True, False, True, False)
        if ((data.word_emb_dim != classify_wordseq.wordrep.word_embedding.embedding_dim)\
                or (data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']] != classify_wordseq.wordrep.position1_emb.embedding_dim)\
                or (data.feature_emb_dims[1] != classify_wordseq.wordrep.feature_embedding_dims[0])):
            raise RuntimeError("re_wordseq not compatible")
        old_classify_wordseq = WordSequence(old_data, True, False, True, False)
        if opt.test_in_cpu:
            old_classify_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl'), map_location='cpu'))
        else:
            old_classify_wordseq.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 're_wordseq.pkl')))
        for old, new in zip(old_classify_wordseq.lstm.parameters(), classify_wordseq.lstm.parameters()):
            new.data.copy_(old)
        vocab_size = old_classify_wordseq.wordrep.word_embedding.num_embeddings
        classify_wordseq.wordrep.word_embedding.weight.data[
            0:vocab_size, :] = old_classify_wordseq.wordrep.word_embedding.weight.data[0:vocab_size, :]
        vocab_size = old_classify_wordseq.wordrep.position1_emb.num_embeddings
        classify_wordseq.wordrep.position1_emb.weight.data[
            0:vocab_size, :] = old_classify_wordseq.wordrep.position1_emb.weight.data[0:vocab_size, :]
        vocab_size = old_classify_wordseq.wordrep.position2_emb.num_embeddings
        classify_wordseq.wordrep.position2_emb.weight.data[
            0:vocab_size, :] = old_classify_wordseq.wordrep.position2_emb.weight.data[0:vocab_size, :]
        vocab_size = old_classify_wordseq.wordrep.feature_embeddings[0].num_embeddings
        classify_wordseq.wordrep.feature_embeddings[0].weight.data[
            0:vocab_size, :] = old_classify_wordseq.wordrep.feature_embeddings[0].weight.data[0:vocab_size, :]

        # --- Warm start: RE classifier head. ---
        classify_model = ClassifyModel(data)
        old_classify_model = ClassifyModel(old_data)
        if opt.test_in_cpu:
            old_classify_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 're_model.pkl'), map_location='cpu'))
        else:
            old_classify_model.load_state_dict(
                torch.load(os.path.join(opt.pretrained_model_dir, 're_model.pkl')))
        if (data.re_feature_alphabet_sizes[data.re_feature_name2id['[RELATION]']] != old_classify_model.linear.weight.size(0)
                or (data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_TYPE]']] != old_classify_model.entity_type_emb.embedding_dim) \
                or (data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY]']] != old_classify_model.entity_emb.embedding_dim)
                or (data.re_feature_emb_dims[data.re_feature_name2id['[TOKEN_NUM]']] != old_classify_model.tok_num_betw_emb.embedding_dim) \
                or (data.re_feature_emb_dims[data.re_feature_name2id['[ENTITY_NUM]']] != old_classify_model.et_num_emb.embedding_dim) \
                ):
            raise RuntimeError("re_model not compatible")
        vocab_size = old_classify_model.entity_type_emb.num_embeddings
        classify_model.entity_type_emb.weight.data[
            0:vocab_size, :] = old_classify_model.entity_type_emb.weight.data[0:vocab_size, :]
        vocab_size = old_classify_model.entity_emb.num_embeddings
        classify_model.entity_emb.weight.data[
            0:vocab_size, :] = old_classify_model.entity_emb.weight.data[0:vocab_size, :]
        vocab_size = old_classify_model.tok_num_betw_emb.num_embeddings
        classify_model.tok_num_betw_emb.weight.data[
            0:vocab_size, :] = old_classify_model.tok_num_betw_emb.weight.data[0:vocab_size, :]
        vocab_size = old_classify_model.et_num_emb.num_embeddings
        classify_model.et_num_emb.weight.data[
            0:vocab_size, :] = old_classify_model.et_num_emb.weight.data[0:vocab_size, :]
    else:
        # Cold start: fresh models for both tasks.
        seq_model = SeqModel(data)
        seq_wordseq = WordSequence(data, False, True, True, True)
        classify_wordseq = WordSequence(data, True, False, True, False)
        classify_model = ClassifyModel(data)

    # One Adam optimizer per task.
    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(), seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(), classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        # Keep pretrained word embeddings fixed.
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    # Split RE training data into positives and </unk> negatives; the negatives
    # are re-sampled every epoch by makeRelationDataset.
    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_dev_loader = DataLoader(my_utils.RelationDataset(data.re_dev_X, data.re_dev_Y),
                               data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1
    count_performance_not_grow = 0

    for idx in range(data.HP_iteration):
        epoch_start = time.time()
        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()
        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size
        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        ner_total_batch = ner_train_num // batch_size + 1
        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)
        # Interleave the two tasks; the longer task keeps stepping alone after
        # the shorter one runs out of batches.
        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):
            if batch_id < ner_total_batch:
                # One NER step on the next slice of shuffled train_Ids.
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)
                hidden = seq_wordseq.forward(batch_word, batch_features, batch_wordlen,
                                             batch_char, batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None
                loss, tag_seq = seq_model.neg_log_likelihood_loss(hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            if batch_id < re_total_batch:
                # One RE step on the next relation batch.
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)
                hidden = classify_wordseq.forward(batch_word, batch_features, batch_wordlen,
                                                  batch_char, batch_charlen, batch_charrecover,
                                                  position1_seq_tensor, position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length, e2_token,
                    e2_length, e1_type, e2_type, tok_num_betw, et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        logging.info("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start))

        # Evaluate both tasks on dev; checkpoint each task independently.
        _, _, _, _, ner_score, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "dev")
        logging.info("ner evaluate: f: %.4f" % (ner_score))
        if ner_score > best_ner_score:
            logging.info("new best score: ner: %.4f" % (ner_score))
            best_ner_score = ner_score
            torch.save(seq_wordseq.state_dict(), os.path.join(opt.output, 'ner_wordseq.pkl'))
            torch.save(seq_model.state_dict(), os.path.join(opt.output, 'ner_model.pkl'))
            count_performance_not_grow = 0
        else:
            count_performance_not_grow += 1

        re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_dev_loader)
        logging.info("re evaluate: f: %.4f" % (re_score))
        if re_score > best_re_score:
            logging.info("new best score: re: %.4f" % (re_score))
            best_re_score = re_score
            torch.save(classify_wordseq.state_dict(), os.path.join(opt.output, 're_wordseq.pkl'))
            torch.save(classify_model.state_dict(), os.path.join(opt.output, 're_model.pkl'))
            count_performance_not_grow = 0
        else:
            count_performance_not_grow += 1

        # Stop when neither task improved for 2 * patience consecutive checks.
        if count_performance_not_grow > 2 * data.patience:
            logging.info("early stop")
            break

    logging.info("train finished")
def train():
    """Train a relation classifier (LSTM/CNN feature extractor + capsule/MLP head)
    from pickled, pre-vocabularized data in opt.pretrain.

    Reads all inputs from the module-level `opt`; writes the best checkpoint
    (by test accuracy) and the test metadata into opt.output.

    Fix: the original built `iter_parameter = itertools.chain(...)` (a one-shot
    iterator), handed it to optim.Adam — which exhausts it — and then called
    torch.nn.utils.clip_grad_norm_ on the now-empty iterator every step, so
    gradient clipping silently did nothing. The parameter list is now
    materialized once and reused for both Adam and clipping.

    NOTE(review): this function shares its name with other train() definitions
    in this file — presumably these chunks come from different modules; confirm.
    NOTE(review): makeDatasetWithoutUnknown/makeDatasetUnknown are called here
    without a batch_size argument; the 6-argument versions defined earlier in
    this file would raise TypeError, so this presumably resolves to a local
    variant — confirm against the originating module.
    """
    logging.info("loading ... vocab")
    word_vocab = pickle.load(open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))

    # One relation instance is composed of X (a pair of entities and their
    # context) and Y (relation label).
    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'), 'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'), 'rb'))
    logging.info("total training instance {}".format(len(train_Y)))

    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    # Build the training loader(s) according to the sampling strategy.
    if opt.strategy == 'all':
        train_loader = DataLoader(my_utils.RelationDataset(train_X, train_Y),
                                  opt.batch_size, shuffle=True, collate_fn=my_collate)
        train_iter = iter(train_loader)
        num_iter = len(train_loader)
    elif opt.strategy == 'no-unk':
        train_loader, train_iter = makeDatasetWithoutUnknown(
            train_X, train_Y, relation_vocab, True, my_collate)
        num_iter = len(train_loader)
    elif opt.strategy == 'balance':
        train_loaders, train_iters, train_numbers = makeDatasetForEachClass(
            train_X, train_Y, relation_vocab, my_collate)
        for t in train_numbers:
            logging.info(t)
        # Use the median instance count of all classes except unknown as the
        # number of iterations per epoch.
        num_iter = int(
            np.median(
                np.array([
                    num for class_name, num in train_numbers
                    if class_name != relation_vocab.unk_tok
                ])))
    elif opt.strategy == 'part-unk':
        train_loader, train_iter = makeDatasetWithoutUnknown(
            train_X, train_Y, relation_vocab, True, my_collate)
        num_iter = len(train_loader)
        unk_loader, unk_iter = makeDatasetUnknown(train_X, train_Y, relation_vocab,
                                                  my_collate, opt.unk_ratio)
    else:
        raise RuntimeError("unsupport training strategy")

    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                             opt.batch_size, shuffle=False, collate_fn=my_collate)

    # Low-level feature extractor.
    if opt.model.lower() == 'lstm':
        feature_extractor = LSTMFeatureExtractor(word_vocab, position_vocab1,
                                                 position_vocab2, opt.F_layers,
                                                 opt.shared_hidden_size, opt.dropout)
    elif opt.model.lower() == 'cnn':
        feature_extractor = CNNFeatureExtractor(word_vocab, postag_vocab,
                                                position_vocab1, position_vocab2,
                                                opt.F_layers, opt.shared_hidden_size,
                                                opt.kernel_num, opt.kernel_sizes,
                                                opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))
    if torch.cuda.is_available():
        feature_extractor = feature_extractor.cuda(opt.gpu)

    # High-level classification head.
    if opt.model_high == 'capsule':
        m = capsule.CapsuleNet(opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                               entity_vocab, tok_num_betw_vocab, et_num_vocab)
    elif opt.model_high == 'capsule_em':
        m = capsule_em.CapsuleNet_EM(opt.shared_hidden_size, relation_vocab)
    elif opt.model_high == 'mlp':
        m = baseline.MLP(opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                         entity_vocab, tok_num_betw_vocab, et_num_vocab)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))
    if torch.cuda.is_available():
        m = m.cuda(opt.gpu)

    # Materialize the parameter list ONCE: it is needed both by the optimizer
    # and by gradient clipping on every step (an itertools.chain would be
    # exhausted after the first use).
    all_parameters = list(itertools.chain(feature_extractor.parameters(), m.parameters()))
    optimizer = optim.Adam(all_parameters, lr=opt.learning_rate)

    if opt.tune_wordemb == False:
        # Keep pretrained word embeddings fixed.
        my_utils.freeze_layer(feature_extractor.word_emb)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):
        feature_extractor.train()
        m.train()
        correct, total = 0, 0
        for i in tqdm(range(num_iter)):
            # Fetch the next batch according to the sampling strategy.
            if opt.strategy == 'all' or opt.strategy == 'no-unk' or opt.strategy == 'part-unk':
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    train_loader, train_iter)
            elif opt.strategy == 'balance':
                x2, x1, targets = my_utils.endless_get_next_batch(
                    train_loaders, train_iters)
            else:
                raise RuntimeError("unsupport training strategy")

            hidden_features = feature_extractor.forward(x2, x1)
            if opt.model_high == 'capsule':
                outputs, x_recon = m.forward(hidden_features, x2, x1, targets)
                loss = m.loss(targets, outputs, hidden_features, x2, x1, x_recon, opt.lam_recon)
            elif opt.model_high == 'mlp':
                outputs = m.forward(hidden_features, x2, x1)
                loss = m.loss(targets, outputs)
            else:
                raise RuntimeError('Unknown model {}'.format(opt.model_high))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_parameters, opt.grad_clip)
            optimizer.step()

            total += targets.size(0)
            _, pred = torch.max(outputs, 1)
            correct += (pred == targets).sum().item()

            # Extra step on a batch of </unk> instances so the model also
            # learns the "no relation" class (accuracy not accumulated here).
            if opt.strategy == 'part-unk':
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    unk_loader, unk_iter)
                hidden_features = feature_extractor.forward(x2, x1)
                if opt.model_high == 'capsule':
                    outputs, x_recon = m.forward(hidden_features, x2, x1, targets)
                    loss = m.loss(targets, outputs, hidden_features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    outputs = m.forward(hidden_features, x2, x1)
                    loss = m.loss(targets, outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_parameters, opt.grad_clip)
                optimizer.step()

        # Re-sample the </unk> subset each epoch.
        if opt.strategy == 'part-unk':
            unk_loader, unk_iter = makeDatasetUnknown(train_X, train_Y, relation_vocab,
                                                      my_collate, opt.unk_ratio)

        logging.info('epoch {} end'.format(epoch))
        logging.info('Train Accuracy: {}%'.format(100.0 * correct / total))
        test_accuracy = evaluate(feature_extractor, m, test_loader, test_Other)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        # Checkpoint only when test accuracy improves.
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(feature_extractor.state_dict(), '{}/feature_extractor.pth'.format(opt.output))
            torch.save(m.state_dict(), '{}/model.pth'.format(opt.output))
            pickle.dump(test_Other, open(os.path.join(opt.output, 'results.pkl'), "wb"), True)
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")
def train(data, ner_dir, re_dir):
    # Jointly train an NER tagger and a relation-extraction (RE) classifier.
    # The two tasks keep separate word representations and hidden stacks, but on
    # batches where both tasks step together each hidden layer mixes its own LSTM
    # output with the *other* task's attention output, weighted by `alpha`.
    # `alpha` ramps linearly from 1/task_num (maximal sharing) at the bottom
    # layer toward 1.0 (task-specific) at the top layer.
    #
    # Args:
    #   data:    project config/data holder (vocabularies, train ids, HP_* hyperparams).
    #   ner_dir: output directory for best NER checkpoints.
    #   re_dir:  output directory for best RE checkpoints.
    # Side effects: shuffles data.train_Ids in place, prints progress, saves
    # state_dicts whenever the summed NER+RE f-score improves.
    task_num = 2
    init_alpha = 1.0 / task_num
    if opt.hidden_num <= 1:
        step_alpha = 0
    else:
        # linear schedule so alpha reaches 1.0 at the last hidden layer
        step_alpha = (1.0 - init_alpha) / (opt.hidden_num - 1)

    # --- NER branch: word rep (with char features) + hidden stack + sequence model
    ner_wordrep = WordRep(data, False, True, True, data.use_char)
    ner_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            # first layer consumes word emb + char hidden + [Cap] + [POS] feature embeddings
            input_size = data.word_emb_dim + data.HP_char_hidden_dim + \
                         data.feature_emb_dims[data.feature_name2id['[Cap]']] + \
                         data.feature_emb_dims[data.feature_name2id['[POS]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim
        temp = HiddenLayer(data, input_size, output_size)
        ner_hiddenlist.append(temp)
    seq_model = SeqModel(data)

    # --- RE branch: its own word rep (position features, no char) + hidden stack + classifier
    re_wordrep = WordRep(data, True, False, True, False)
    re_hiddenlist = []
    for i in range(opt.hidden_num):
        if i == 0:
            # first layer consumes word emb + [POS] + the two entity-position embeddings
            input_size = data.word_emb_dim + data.feature_emb_dims[data.feature_name2id['[POS]']] + \
                         2 * data.re_feature_emb_dims[data.re_feature_name2id['[POSITION]']]
            output_size = data.HP_hidden_dim
        else:
            input_size = data.HP_hidden_dim
            output_size = data.HP_hidden_dim
        temp = HiddenLayer(data, input_size, output_size)
        re_hiddenlist.append(temp)
    classify_model = ClassifyModel(data)

    # One optimizer per task over that task's full parameter set.
    iter_parameter = itertools.chain(
        *map(list, [ner_wordrep.parameters(), seq_model.parameters()] +
                   [f.parameters() for f in ner_hiddenlist]))
    ner_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(
        *map(list, [re_wordrep.parameters(), classify_model.parameters()] +
                   [f.parameters() for f in re_hiddenlist]))
    re_optimizer = optim.Adam(iter_parameter, lr=data.HP_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(ner_wordrep.word_embedding)
        my_utils.freeze_net(re_wordrep.word_embedding)

    # Split RE training data into positive instances and "</unk>" negatives;
    # makeRelationDataset later subsamples the negatives each epoch.
    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        # put every module into train mode and clear any stale gradients
        ner_wordrep.train()
        ner_wordrep.zero_grad()
        for hidden_layer in ner_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        seq_model.train()
        seq_model.zero_grad()

        re_wordrep.train()
        re_wordrep.zero_grad()
        for hidden_layer in re_hiddenlist:
            hidden_layer.train()
            hidden_layer.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        # NOTE(review): when ner_train_num divides batch_size exactly, the final
        # batch slice is empty — verify batchify_with_label tolerates [].
        ner_total_batch = ner_train_num // batch_size + 1

        # fresh negative subsample of RE data every epoch
        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)

        for batch_id in range(total_batch):
            if batch_id < min_batch:
                # --- joint step: both tasks have a batch; hidden layers cross-talk
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                ner_batch_word, ner_batch_features, ner_batch_wordlen, ner_batch_wordrecover, ner_batch_char, ner_batch_charlen, \
                ner_batch_charrecover, ner_batch_label, ner_mask, ner_batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                [re_batch_word, re_batch_features, re_batch_wordlen, re_batch_wordrecover, re_batch_char, re_batch_charlen,
                 re_batch_charrecover, re_position1_seq_tensor, re_position2_seq_tensor, re_e1_token, re_e1_length, re_e2_token,
                 re_e2_length, re_e1_type, re_e2_type, re_tok_num_betw, re_et_num], [re_targets, re_targets_permute] = \
                    my_utils.endless_get_next_batch_without_rebatch1(re_train_loader, re_train_iter)

                if ner_batch_word.size(0) != re_batch_word.size(0):
                    continue  # if batch size is not equal, we ignore such batch

                ner_word_rep = ner_wordrep.forward(
                    ner_batch_word, ner_batch_features, ner_batch_wordlen,
                    ner_batch_char, ner_batch_charlen, ner_batch_charrecover,
                    None, None)
                re_word_rep = re_wordrep.forward(
                    re_batch_word, re_batch_features, re_batch_wordlen,
                    re_batch_char, re_batch_charlen, re_batch_charrecover,
                    re_position1_seq_tensor, re_position2_seq_tensor)

                alpha = init_alpha
                ner_hidden = ner_word_rep
                re_hidden = re_word_rep
                for i in range(opt.hidden_num):
                    if alpha > 0.99:
                        # alpha ~ 1: no sharing; skip attention entirely
                        use_attn = False
                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)
                        ner_hidden = ner_lstm_out
                        re_hidden = re_lstm_out
                    else:
                        # mix each task's LSTM output with the other task's attention output
                        use_attn = True
                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, ner_batch_wordlen, use_attn)
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, re_batch_wordlen, use_attn)
                        ner_hidden = alpha * ner_lstm_out + (
                            1 - alpha) * re_att_out.unsqueeze(1)
                        re_hidden = alpha * re_lstm_out + (
                            1 - alpha) * ner_att_out.unsqueeze(1)
                    alpha += step_alpha

                # NOTE(review): here seq_model.neg_log_likelihood_loss takes 3 args,
                # while pipeline() calls it with 4 (extra hidden_adv) — confirm the
                # SeqModel variants used by each entry point.
                ner_loss, ner_tag_seq = seq_model.neg_log_likelihood_loss(
                    ner_hidden, ner_batch_label, ner_mask)
                re_loss, re_pred = classify_model.neg_log_likelihood_loss(
                    re_hidden, re_batch_wordlen, re_e1_token, re_e1_length,
                    re_e2_token, re_e2_length, re_e1_type, re_e2_type,
                    re_tok_num_betw, re_et_num, re_targets)

                # retain_graph: the two losses share the cross-task graph, so the
                # first backward must keep it alive for the second.
                ner_loss.backward(retain_graph=True)
                re_loss.backward()
                ner_optimizer.step()
                re_optimizer.step()

                ner_wordrep.zero_grad()
                for hidden_layer in ner_hiddenlist:
                    hidden_layer.zero_grad()
                seq_model.zero_grad()
                re_wordrep.zero_grad()
                for hidden_layer in re_hiddenlist:
                    hidden_layer.zero_grad()
                classify_model.zero_grad()
            else:
                # --- leftover batches: whichever task still has data steps alone,
                # without cross-task attention (use_attn=False).
                if batch_id < ner_total_batch:
                    start = batch_id * batch_size
                    end = (batch_id + 1) * batch_size
                    if end > ner_train_num:
                        end = ner_train_num
                    instance = data.train_Ids[start:end]
                    batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                    batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                    ner_word_rep = ner_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover, None, None)
                    ner_hidden = ner_word_rep
                    for i in range(opt.hidden_num):
                        ner_lstm_out, ner_att_out = ner_hiddenlist[i].forward(
                            ner_hidden, batch_wordlen, False)
                        ner_hidden = ner_lstm_out

                    loss, tag_seq = seq_model.neg_log_likelihood_loss(
                        ner_hidden, batch_label, mask)
                    loss.backward()
                    ner_optimizer.step()
                    ner_wordrep.zero_grad()
                    for hidden_layer in ner_hiddenlist:
                        hidden_layer.zero_grad()
                    seq_model.zero_grad()

                if batch_id < re_total_batch:
                    [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                     batch_char, batch_charlen, batch_charrecover, \
                     position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                     tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                        re_train_loader, re_train_iter)

                    re_word_rep = re_wordrep.forward(
                        batch_word, batch_features, batch_wordlen, batch_char,
                        batch_charlen, batch_charrecover,
                        position1_seq_tensor, position2_seq_tensor)
                    re_hidden = re_word_rep
                    for i in range(opt.hidden_num):
                        re_lstm_out, re_att_out = re_hiddenlist[i].forward(
                            re_hidden, batch_wordlen, False)
                        re_hidden = re_lstm_out

                    loss, pred = classify_model.neg_log_likelihood_loss(
                        re_hidden, batch_wordlen, e1_token, e1_length, e2_token,
                        e2_length, e1_type, e2_type, tok_num_betw, et_num,
                        targets)
                    loss.backward()
                    re_optimizer.step()
                    re_wordrep.zero_grad()
                    for hidden_layer in re_hiddenlist:
                        hidden_layer.zero_grad()
                    classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start))

        # evaluate both tasks; checkpoint when the summed f-score improves
        ner_score = ner_evaluate(data, ner_wordrep, ner_hiddenlist, seq_model, "test")
        print("ner evaluate: f: %.4f" % (ner_score))
        re_score = re_evaluate(re_wordrep, re_hiddenlist, classify_model, re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" % (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score
            torch.save(ner_wordrep.state_dict(), os.path.join(ner_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(ner_hiddenlist):
                torch.save(hidden_layer.state_dict(), os.path.join(ner_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(seq_model.state_dict(), os.path.join(ner_dir, 'model.pkl'))
            torch.save(re_wordrep.state_dict(), os.path.join(re_dir, 'wordrep.pkl'))
            for i, hidden_layer in enumerate(re_hiddenlist):
                torch.save(hidden_layer.state_dict(), os.path.join(re_dir, 'hidden_{}.pkl'.format(i)))
            torch.save(classify_model.state_dict(), os.path.join(re_dir, 'model.pkl'))
def pipeline(data, ner_dir, re_dir):
    # Train the NER tagger and the RE classifier as two *independent* models
    # (no parameter sharing), interleaving their batches within each epoch.
    #
    # Args:
    #   data:    project config/data holder (vocabularies, train ids, HP_* hyperparams).
    #   ner_dir: output directory for the best NER checkpoint.
    #   re_dir:  output directory for the best RE checkpoint.
    # Side effects: shuffles data.train_Ids in place, prints progress, saves
    # state_dicts whenever the summed NER+RE f-score improves.
    seq_model = SeqModel(data)
    seq_wordseq = WordSequence(data, False, True, True, data.use_char)

    classify_wordseq = WordSequence(data, True, False, True, False)
    classify_model = ClassifyModel(data)
    # NOTE(review): only classify_model is moved to GPU here — presumably the
    # other modules handle device placement internally; confirm.
    if torch.cuda.is_available():
        classify_model = classify_model.cuda(data.HP_gpu)

    # separate optimizers (and learning rates) per task
    iter_parameter = itertools.chain(
        *map(list, [seq_wordseq.parameters(), seq_model.parameters()]))
    seq_optimizer = optim.Adam(iter_parameter, lr=opt.ner_lr, weight_decay=data.HP_l2)
    iter_parameter = itertools.chain(*map(
        list, [classify_wordseq.parameters(), classify_model.parameters()]))
    classify_optimizer = optim.Adam(iter_parameter, lr=opt.re_lr, weight_decay=data.HP_l2)

    if data.tune_wordemb == False:
        my_utils.freeze_net(seq_wordseq.wordrep.word_embedding)
        my_utils.freeze_net(classify_wordseq.wordrep.word_embedding)

    # Split RE training data into positives and "</unk>" negatives; the
    # negatives get subsampled afresh each epoch by makeRelationDataset.
    re_X_positive = []
    re_Y_positive = []
    re_X_negative = []
    re_Y_negative = []
    relation_vocab = data.re_feature_alphabets[
        data.re_feature_name2id['[RELATION]']]
    my_collate = my_utils.sorted_collate1
    for i in range(len(data.re_train_X)):
        x = data.re_train_X[i]
        y = data.re_train_Y[i]
        if y != relation_vocab.get_index("</unk>"):
            re_X_positive.append(x)
            re_Y_positive.append(y)
        else:
            re_X_negative.append(x)
            re_Y_negative.append(y)

    re_test_loader = DataLoader(my_utils.RelationDataset(
        data.re_test_X, data.re_test_Y), data.HP_batch_size, shuffle=False, collate_fn=my_collate)

    best_ner_score = -1
    best_re_score = -1

    for idx in range(data.HP_iteration):
        epoch_start = time.time()

        # train mode + clear stale gradients on all four modules
        seq_wordseq.train()
        seq_wordseq.zero_grad()
        seq_model.train()
        seq_model.zero_grad()
        classify_wordseq.train()
        classify_wordseq.zero_grad()
        classify_model.train()
        classify_model.zero_grad()

        batch_size = data.HP_batch_size

        random.shuffle(data.train_Ids)
        ner_train_num = len(data.train_Ids)
        # NOTE(review): when ner_train_num divides batch_size exactly, the final
        # batch slice is empty — verify batchify_with_label tolerates [].
        ner_total_batch = ner_train_num // batch_size + 1

        re_train_loader, re_train_iter = makeRelationDataset(
            re_X_positive, re_Y_positive, re_X_negative, re_Y_negative,
            data.unk_ratio, True, my_collate, data.HP_batch_size)
        re_total_batch = len(re_train_loader)

        total_batch = max(ner_total_batch, re_total_batch)
        min_batch = min(ner_total_batch, re_total_batch)  # NOTE(review): unused here

        for batch_id in range(total_batch):
            # --- NER step (while NER batches remain)
            if batch_id < ner_total_batch:
                start = batch_id * batch_size
                end = (batch_id + 1) * batch_size
                if end > ner_train_num:
                    end = ner_train_num
                instance = data.train_Ids[start:end]
                batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask, \
                batch_permute_label = batchify_with_label(instance, data.HP_gpu)

                hidden = seq_wordseq.forward(batch_word, batch_features,
                                             batch_wordlen, batch_char,
                                             batch_charlen, batch_charrecover,
                                             None, None)
                hidden_adv = None  # no adversarial representation in the pipeline setting
                loss, tag_seq = seq_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_label, mask)
                loss.backward()
                seq_optimizer.step()
                seq_wordseq.zero_grad()
                seq_model.zero_grad()

            # --- RE step (while RE batches remain)
            if batch_id < re_total_batch:
                [batch_word, batch_features, batch_wordlen, batch_wordrecover, \
                 batch_char, batch_charlen, batch_charrecover, \
                 position1_seq_tensor, position2_seq_tensor, e1_token, e1_length, e2_token, e2_length, e1_type, e2_type, \
                 tok_num_betw, et_num], [targets, targets_permute] = my_utils.endless_get_next_batch_without_rebatch1(
                    re_train_loader, re_train_iter)

                hidden = classify_wordseq.forward(batch_word, batch_features,
                                                  batch_wordlen, batch_char,
                                                  batch_charlen,
                                                  batch_charrecover,
                                                  position1_seq_tensor,
                                                  position2_seq_tensor)
                hidden_adv = None
                loss, pred = classify_model.neg_log_likelihood_loss(
                    hidden, hidden_adv, batch_wordlen, e1_token, e1_length,
                    e2_token, e2_length, e1_type, e2_type, tok_num_betw,
                    et_num, targets)
                loss.backward()
                classify_optimizer.step()
                classify_wordseq.zero_grad()
                classify_model.zero_grad()

        epoch_finish = time.time()
        print("epoch: %s training finished. Time: %.2fs" % (idx, epoch_finish - epoch_start))

        # evaluate both tasks; checkpoint when the summed f-score improves
        # _, _, _, _, f, _, _ = ner.evaluate(data, seq_wordseq, seq_model, "test")
        ner_score = ner.evaluate1(data, seq_wordseq, seq_model, "test")
        print("ner evaluate: f: %.4f" % (ner_score))

        re_score = relation_extraction.evaluate(classify_wordseq, classify_model, re_test_loader)
        print("re evaluate: f: %.4f" % (re_score))

        if ner_score + re_score > best_ner_score + best_re_score:
            print("new best score: ner: %.4f , re: %.4f" % (ner_score, re_score))
            best_ner_score = ner_score
            best_re_score = re_score
            torch.save(seq_wordseq.state_dict(), os.path.join(ner_dir, 'wordseq.pkl'))
            torch.save(seq_model.state_dict(), os.path.join(ner_dir, 'model.pkl'))
            torch.save(classify_wordseq.state_dict(), os.path.join(re_dir, 'wordseq.pkl'))
            torch.save(classify_model.state_dict(), os.path.join(re_dir, 'model.pkl'))
def train(other_dir):
    # Multi-domain adversarial training for relation extraction: a shared
    # feature extractor F_s, per-domain extractors F_d, a relation classifier C,
    # and (optionally) a domain discriminator D trained adversarially.
    #
    # Args:
    #   other_dir: iterable of auxiliary-domain names; pickles named
    #              'other_{name}_X.pkl' / 'other_{name}_Y.pkl' are loaded from
    #              opt.pretrain for each.
    # Side effects: sets opt.domains, logs progress, saves best checkpoints and
    # results.pkl under opt.output.
    #
    # Fixes vs previous revision:
    #  * makeDatasetWithoutUnknown/makeDatasetUnknown were called without the
    #    required batch_size argument (TypeError) — now pass opt.batch_size.
    #  * map(freeze_net, ...) / map(unfreeze_net, ...) are lazy in Python 3 and
    #    never ran — replaced with explicit loops.
    #  * clip_grad_norm_ was fed the already-exhausted itertools.chain that Adam
    #    had consumed, so gradients were never clipped — parameters are now
    #    materialized into a list shared by the optimizer and the clipper.
    #  * end-of-epoch regeneration of the unknown sets used the stale other_X /
    #    other_Y from the last loop iteration for every auxiliary domain — the
    #    per-domain (X, Y) pairs are now kept in a dict.
    logging.info("loading ... vocab")
    word_vocab = pickle.load(open(os.path.join(opt.pretrain, 'word_vocab.pkl'), 'rb'))
    postag_vocab = pickle.load(open(os.path.join(opt.pretrain, 'postag_vocab.pkl'), 'rb'))
    relation_vocab = pickle.load(open(os.path.join(opt.pretrain, 'relation_vocab.pkl'), 'rb'))
    entity_type_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_type_vocab.pkl'), 'rb'))
    entity_vocab = pickle.load(open(os.path.join(opt.pretrain, 'entity_vocab.pkl'), 'rb'))
    position_vocab1 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab1.pkl'), 'rb'))
    position_vocab2 = pickle.load(open(os.path.join(opt.pretrain, 'position_vocab2.pkl'), 'rb'))
    tok_num_betw_vocab = pickle.load(open(os.path.join(opt.pretrain, 'tok_num_betw_vocab.pkl'), 'rb'))
    et_num_vocab = pickle.load(open(os.path.join(opt.pretrain, 'et_num_vocab.pkl'), 'rb'))

    my_collate = my_utils.sorted_collate if opt.model == 'lstm' else my_utils.unsorted_collate

    # test only on the main domain
    test_X = pickle.load(open(os.path.join(opt.pretrain, 'test_X.pkl'), 'rb'))
    test_Y = pickle.load(open(os.path.join(opt.pretrain, 'test_Y.pkl'), 'rb'))
    test_Other = pickle.load(open(os.path.join(opt.pretrain, 'test_Other.pkl'), 'rb'))
    logging.info("total test instance {}".format(len(test_Y)))
    test_loader = DataLoader(my_utils.RelationDataset(test_X, test_Y),
                             opt.batch_size, shuffle=False, collate_fn=my_collate)  # drop_last=True

    # train on the main as well as other domains
    domains = ['main']
    train_loaders, train_iters, unk_loaders, unk_iters = {}, {}, {}, {}
    # keep each domain's raw (X, Y) so the unknown sets can be regenerated
    # per-domain at the end of every epoch
    domain_data = {}

    train_X = pickle.load(open(os.path.join(opt.pretrain, 'train_X.pkl'), 'rb'))
    train_Y = pickle.load(open(os.path.join(opt.pretrain, 'train_Y.pkl'), 'rb'))
    logging.info("total training instance {}".format(len(train_Y)))
    domain_data['main'] = (train_X, train_Y)
    train_loaders['main'], train_iters['main'] = makeDatasetWithoutUnknown(
        train_X, train_Y, relation_vocab, True, my_collate, opt.batch_size)
    unk_loaders['main'], unk_iters['main'] = makeDatasetUnknown(
        train_X, train_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

    for other in other_dir:
        domains.append(other)
        other_X = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_X.pkl'.format(other)), 'rb'))
        other_Y = pickle.load(open(os.path.join(opt.pretrain, 'other_{}_Y.pkl'.format(other)), 'rb'))
        logging.info("other {} instance {}".format(other, len(other_Y)))
        domain_data[other] = (other_X, other_Y)
        train_loaders[other], train_iters[other] = makeDatasetWithoutUnknown(
            other_X, other_Y, relation_vocab, True, my_collate, opt.batch_size)
        unk_loaders[other], unk_iters[other] = makeDatasetUnknown(
            other_X, other_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

    opt.domains = domains

    # shared extractor + one private extractor per domain
    F_s = None
    F_d = {}
    if opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                   opt.F_layers, opt.shared_hidden_size, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(word_vocab, position_vocab1, position_vocab2,
                                               opt.F_layers, opt.shared_hidden_size, opt.dropout)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                  opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes, opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(word_vocab, postag_vocab, position_vocab1, position_vocab2,
                                              opt.F_layers, opt.shared_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes, opt.dropout)
    else:
        raise RuntimeError('Unknown feature extractor {}'.format(opt.model))

    # classifier over concatenated [shared ; domain] features
    if opt.model_high == 'capsule':
        C = capsule.CapsuleNet(2 * opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                               entity_vocab, tok_num_betw_vocab, et_num_vocab)
    elif opt.model_high == 'mlp':
        C = baseline.MLP(2 * opt.shared_hidden_size, relation_vocab, entity_type_vocab,
                         entity_vocab, tok_num_betw_vocab, et_num_vocab)
    else:
        raise RuntimeError('Unknown model {}'.format(opt.model_high))

    if opt.adv:
        D = DomainClassifier(1, opt.shared_hidden_size, opt.shared_hidden_size,
                             len(opt.domains), opt.loss, opt.dropout, True)

    if torch.cuda.is_available():
        F_s, C = F_s.cuda(opt.gpu), C.cuda(opt.gpu)
        for f_d in F_d.values():
            f_d.cuda(opt.gpu)  # Module.cuda() moves parameters in place
        if opt.adv:
            D = D.cuda(opt.gpu)

    # Materialize the parameter list once: Adam iterates it, and the same list
    # is reused for gradient clipping each step (a bare itertools.chain would be
    # exhausted after the first consumer).
    main_params = list(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [], C.parameters()] +
                   [f.parameters() for f in F_d.values()])))
    optimizer = optim.Adam(main_params, lr=opt.learning_rate)
    if opt.adv:
        optimizerD = optim.Adam(D.parameters(), lr=opt.learning_rate)

    best_acc = 0.0
    logging.info("start training ...")
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        for f in F_d.values():
            f.train()
        if opt.adv:
            D.train()

        # domain accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        if opt.adv:
            d_correct, d_total = 0, 0

        # conceptually view 1 epoch as 1 epoch of the main domain
        num_iter = len(train_loaders['main'])
        for i in tqdm(range(num_iter)):
            if opt.adv:
                # --- D iterations: train discriminator with everything else frozen
                my_utils.freeze_net(F_s)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d)
                my_utils.freeze_net(C)
                my_utils.unfreeze_net(D)
                if opt.tune_wordemb == False:
                    my_utils.freeze_net(F_s.word_emb)
                    for f_d in F_d.values():
                        my_utils.freeze_net(f_d.word_emb)

                # WGAN n_critic trick since D trains slower
                n_critic = opt.n_critic
                if opt.wgan_trick:
                    if opt.n_critic > 0 and ((epoch == 0 and i < 25) or i % 500 == 0):
                        n_critic = 100

                for _ in range(n_critic):
                    D.zero_grad()
                    # train on both labeled and unlabeled domains
                    for domain in opt.domains:
                        # targets not used
                        x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(
                            train_loaders[domain], train_iters[domain])
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        shared_feat = F_s(x2, x1)
                        d_outputs = D(shared_feat)
                        # D accuracy
                        _, pred = torch.max(d_outputs, 1)
                        d_total += len(x2[1])
                        if opt.loss.lower() == 'l2':
                            _, tgt_indices = torch.max(d_targets, 1)
                            d_correct += (pred == tgt_indices).sum().data.item()
                            l_d = functional.mse_loss(d_outputs, d_targets)
                            l_d.backward()
                        else:
                            d_correct += (pred == d_targets).sum().data.item()
                            l_d = functional.nll_loss(d_outputs, d_targets)
                            l_d.backward()
                    optimizerD.step()

            # --- F&C iteration: train extractors + classifier, D frozen
            my_utils.unfreeze_net(F_s)
            for f_d in F_d.values():
                my_utils.unfreeze_net(f_d)
            my_utils.unfreeze_net(C)
            if opt.adv:
                my_utils.freeze_net(D)
            if opt.tune_wordemb == False:
                my_utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    my_utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()

            for domain in opt.domains:
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    train_loaders[domain], train_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

                # training with unknown
                x2, x1, targets = my_utils.endless_get_next_batch_without_rebatch(
                    unk_loaders[domain], unk_iters[domain])
                shared_feat = F_s(x2, x1)
                domain_feat = F_d[domain](x2, x1)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                if opt.model_high == 'capsule':
                    c_outputs, x_recon = C.forward(features, x2, x1, targets)
                    l_c = C.loss(targets, c_outputs, features, x2, x1, x_recon, opt.lam_recon)
                elif opt.model_high == 'mlp':
                    c_outputs = C.forward(features, x2, x1)
                    l_c = C.loss(targets, c_outputs)
                else:
                    raise RuntimeError('Unknown model {}'.format(opt.model_high))
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().data.item()

            if opt.adv:
                # update F with D gradients on all domains (reversed / confusion objective)
                for domain in opt.domains:
                    x2, x1, _ = my_utils.endless_get_next_batch_without_rebatch(
                        train_loaders[domain], train_iters[domain])
                    shared_feat = F_s(x2, x1)
                    d_outputs = D(shared_feat)
                    if opt.loss.lower() == 'gr':
                        d_targets = my_utils.get_domain_label(opt.loss, domain, len(x2[1]))
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= -opt.lambd  # gradient reversal: maximize D's loss
                    elif opt.loss.lower() == 'bs':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.kl_div(d_outputs, d_targets, size_average=False)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    elif opt.loss.lower() == 'l2':
                        d_targets = my_utils.get_random_domain_label(opt.loss, len(x2[1]))
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        if opt.lambd > 0:
                            l_d *= opt.lambd
                    l_d.backward()

            torch.nn.utils.clip_grad_norm_(main_params, opt.grad_clip)
            optimizer.step()

        # regenerate each domain's unknown dataset after one epoch, from that
        # domain's own data
        for domain in opt.domains:
            dom_X, dom_Y = domain_data[domain]
            unk_loaders[domain], unk_iters[domain] = makeDatasetUnknown(
                dom_X, dom_Y, relation_vocab, my_collate, opt.unk_ratio, opt.batch_size)

        logging.info('epoch {} end'.format(epoch))
        if opt.adv and d_total > 0:
            logging.info('D Training Accuracy: {}%'.format(100.0 * d_correct / d_total))
        logging.info('Training accuracy:')
        logging.info('\t'.join(opt.domains))
        logging.info('\t'.join([str(100.0 * correct[d] / total[d]) for d in opt.domains]))

        test_accuracy = evaluate(F_s, F_d['main'], C, test_loader, test_Other)
        logging.info('Test Accuracy: {}%'.format(test_accuracy))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            torch.save(F_s.state_dict(), '{}/F_s.pth'.format(opt.output))
            for d in opt.domains:
                torch.save(F_d[d].state_dict(), '{}/F_d_{}.pth'.format(opt.output, d))
            torch.save(C.state_dict(), '{}/C.pth'.format(opt.output))
            if opt.adv:
                torch.save(D.state_dict(), '{}/D.pth'.format(opt.output))
            pickle.dump(test_Other, open(os.path.join(opt.output, 'results.pkl'), "wb"), True)
            logging.info('New best accuracy: {}'.format(best_acc))

    logging.info("training completed")