def train(training_data, valid_data, vocabulary, embedding_dim, hidden_dim,
          device, batch_size, lr, embedding, all_relations, model=None,
          epoch=100, memory_data=[], loss_margin=0.5, alignment_model=None):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    memory_index = 0
    for epoch_i in range(epoch):
        for i in range((len(training_data) - 1) // batch_size + 1):
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            seed_rels = []
            for item in samples:
                if item[0] not in seed_rels:
                    seed_rels.append(item[0])
            if len(memory_data) > 0:
                all_seen_data = []
                for this_memory in memory_data:
                    all_seen_data += this_memory
                memory_batch = memory_data[memory_index]
                scores, loss = feed_samples(model, memory_batch, loss_function,
                                            all_relations, device, alignment_model)
                optimizer.step()
                memory_index = (memory_index + 1) % len(memory_data)
            scores, loss = feed_samples(model, samples, loss_function,
                                        all_relations, device, alignment_model)
            optimizer.step()
            del scores
            del loss
    return model
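# A minimal, self-contained illustration (not part of the training code above) of the
# nn.MarginRankingLoss convention this repository optimizes: with target y = 1 the loss
# is mean(max(0, -(pos - neg) + margin)), i.e. gold-relation scores are pushed above
# negative-candidate scores by at least `loss_margin`. All tensor values are toy inputs.
import torch
import torch.nn as nn

margin_loss = nn.MarginRankingLoss(margin=0.5)
pos_scores = torch.tensor([0.9, 0.2])   # scores of the gold relation
neg_scores = torch.tensor([0.1, 0.4])   # scores of sampled negative relations
target = torch.ones(2)                  # y = 1: first argument should rank higher
print(margin_loss(pos_scores, neg_scores, target))  # tensor(0.3500)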
def train(training_data, valid_data, vocabulary, embedding_dim, hidden_dim,
          device, batch_size, lr, model_path, embedding, all_relations,
          model=None, epoch=100, all_seen_samples=[], task_memory_size=100,
          loss_margin=0.5, all_seen_rels=[]):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    for epoch_i in range(epoch):
        # print('epoch', epoch_i)
        # training_data = training_data[0:100]
        for i in range((len(training_data) - 1) // batch_size + 1):
            memory_data = sample_memory_data(all_seen_samples, task_memory_size)
            for this_sample in memory_data:
                rel_cands = [rel for rel in all_seen_rels if rel != this_sample[0]]
                this_sample[1] = random.sample(rel_cands,
                                               min(len(rel_cands), num_cands))
            memory_data_grads = get_grads_memory_data(model, memory_data,
                                                      loss_function,
                                                      all_relations, device)
            # print(memory_data_grads)
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            feed_samples(model, samples, loss_function, all_relations, device)
            sample_grad = copy_grad_data(model)
            if len(memory_data_grads) > 0:
                if torch.matmul(memory_data_grads,
                                torch.t(sample_grad.view(1, -1))) < 0:
                    project2cone2(sample_grad, memory_data_grads)
                    grad_params = get_grad_params(model)
                    grad_dims = [param.data.numel() for param in grad_params]
                    overwrite_grad(grad_params, sample_grad, grad_dims)
            optimizer.step()
    '''
    acc = evaluate_model(model, valid_data, batch_size, all_relations, device)
    if acc > best_acc:
        torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    return model
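# Hedged sketch only: the kind of gradient projection project2cone2 is expected to perform
# in the train() above. GEM proper solves a quadratic program over all memory gradients;
# the single-constraint (A-GEM-like) special case below is shown purely for intuition and
# is not the repository's actual implementation.
import torch

def project_single_constraint(grad: torch.Tensor, mem_grad: torch.Tensor) -> torch.Tensor:
    """If grad conflicts with mem_grad (negative dot product), remove the conflicting
    component so the update no longer increases the loss on the memory samples."""
    dot = torch.dot(grad, mem_grad)
    if dot < 0:
        grad = grad - (dot / torch.dot(mem_grad, mem_grad)) * mem_grad
    return grad

g = torch.tensor([1.0, -2.0])      # gradient on the current batch
g_mem = torch.tensor([0.0, 1.0])   # gradient on the memory batch
print(project_single_constraint(g, g_mem))  # tensor([1., 0.])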
def main(opt): print(opt) random.seed(opt.random_seed) torch.manual_seed(opt.random_seed) np.random.seed(opt.random_seed) np.random.RandomState(opt.random_seed) start_time = time.time() checkpoint_dir = os.path.join(opt.checkpoint_dir, '%.f' % start_time) device = torch.device(('cuda:%d' % opt.cuda_id) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu') # do following process split_train_data, train_data_dict, split_test_data, split_valid_data, relation_numbers, rel_features, \ split_train_relations, vocabulary, embedding = \ load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file, opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num, opt.train_instance_num, opt.dataset) print(split_train_relations) # ------------------------------------------------------------------------ # save cluster results our_tasks = split_train_relations count = 0 for i in our_tasks: count += len(i) portion = np.zeros(10000) portion = portion-1 for i in range(len(our_tasks)): for j in our_tasks[i]: portion[j - 1] = int(i) np.save("dataset/tacred/CML_tacred_random.npy", np.array(portion).astype(int)) # ------------------------------------------------------------------------- print('\n'.join( ['Task %d\t%s' % (index, ', '.join(['%d' % rel for rel in split_train_relations[index]])) for index in range(len(split_train_relations))])) task_sequence = list(range(opt.task_num)) if opt.random_idx: for i in range(opt.random_times): random.shuffle(task_sequence) offset_seq = task_sequence[-opt.sequence_index:] + task_sequence[:-opt.sequence_index] split_train_data = resort_list(split_train_data, offset_seq) split_test_data = resort_list(split_test_data, offset_seq) split_valid_data = resort_list(split_valid_data, offset_seq) split_train_relations = resort_list(split_train_relations, offset_seq) print('[%s]' % ', '.join(['Task %d' % idx for idx in offset_seq])) relid2embedidx = {} embedidx2relid = {} if opt.similarity == 'kl_similarity': kl_dist_ht = read_json(opt.kl_dist_file) sorted_similarity_index = np.argsort(np.asarray(kl_dist_ht), axis=1) + 1 elif opt.similarity == 'glove_similarity': glove_embedding = [] embed_id = 0 for rel_id in rel_features: glove_embedding.append(rel_features[rel_id]) relid2embedidx[rel_id] = embed_id embedidx2relid[embed_id] = rel_id embed_id += 1 glove_similarity = cosine_similarity(np.asarray(glove_embedding)) glove_dist = np.sqrt(1 - np.power(np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2)) sorted_embed_index = np.argsort(np.asarray(glove_dist), axis=1) sorted_similarity_index = np.zeros(sorted_embed_index.shape) for i in range(sorted_embed_index.shape[0]): for j in range(sorted_embed_index.shape[1]): sorted_similarity_index[i][j] = embedidx2relid[sorted_embed_index[i][j]] else: raise Exception('similarity method not implemented') # prepare model inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim, len(vocabulary), np.array(embedding), 1, device) memory_data = [] memory_pool = [] memory_question_embed = [] memory_relation_embed = [] sequence_results = [] result_whole_test = [] seen_relations = [] all_seen_relations = [] rel2instance_memory = {} memory_index = 0 seen_task_relations = [] rel_embeddings = [] for task_ix in range(opt.task_num): # outside loop # reptile start model parameters pi weights_before = deepcopy(inner_model.state_dict()) train_task = split_train_data[task_ix] test_task = split_test_data[task_ix] valid_task = split_valid_data[task_ix] train_relations = split_train_relations[task_ix] 
seen_task_relations.append(train_relations) # collect seen relations for data_item in train_task: if data_item[0] not in seen_relations: seen_relations.append(data_item[0]) # remove unseen relations current_train_data = remove_unseen_relation(train_task, seen_relations, dataset=opt.dataset) current_valid_data = remove_unseen_relation(valid_task, seen_relations, dataset=opt.dataset) current_test_data = [] for previous_task_id in range(task_ix + 1): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset)) for this_sample in current_train_data: if this_sample[0] not in all_seen_relations: all_seen_relations.append(this_sample[0]) update_rel_cands(memory_data, all_seen_relations, opt.num_cands) # train inner_model loss_function = nn.MarginRankingLoss(opt.loss_margin) inner_model = inner_model.to(device) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) t = tqdm(range(opt.outside_epoch)) best_valid_acc = 0.0 early_stop = 0 best_checkpoint = '' # resorted_memory_pool = [] for epoch in t: batch_num = (len(current_train_data) - 1) // opt.batch_size + 1 total_loss = 0.0 target_rel = -1 for batch in range(batch_num): batch_train_data = current_train_data[batch * opt.batch_size: (batch + 1) * opt.batch_size] if len(memory_data) > 0: # CML if target_rel == -1 or len(resorted_memory_pool) == 0: target_rel = batch_train_data[0][0] if opt.similarity == 'kl_similarity': target_rel_sorted_index = sorted_similarity_index[target_rel - 1] else: target_rel_sorted_index = sorted_similarity_index[relid2embedidx[target_rel]] resorted_memory_pool = resort_memory(memory_pool, target_rel_sorted_index) if len(resorted_memory_pool) >= opt.task_memory_size: current_memory = resorted_memory_pool[:opt.task_memory_size] resorted_memory_pool = resorted_memory_pool[opt.task_memory_size + 1:] # update rest memory batch_train_data.extend(current_memory) else: current_memory = resorted_memory_pool[:] resorted_memory_pool = [] batch_train_data.extend(current_memory) # MLLRE # all_seen_data = [] # for one_batch_memory in memory_data: # all_seen_data += one_batch_memory # # memory_batch = memory_data[memory_index] # batch_train_data.extend(memory_batch) # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device) # optimizer.step() # memory_index = (memory_index+1) % len(memory_data) if len(rel2instance_memory) > 0: # from the second task, this will not be empty if opt.is_curriculum_train == 'Y': current_train_rel = batch_train_data[0][0] current_rel_similarity_sorted_index = sorted_similarity_index[current_train_rel + 1] seen_relation_sorted_index = [] for rel in current_rel_similarity_sorted_index: if rel in rel2instance_memory.keys(): seen_relation_sorted_index.append(rel) curriculum_rel_list = [] if opt.sampled_rel_num >= len(seen_relation_sorted_index): curriculum_rel_list = seen_relation_sorted_index[:] else: step = len(seen_relation_sorted_index) // opt.sampled_rel_num for i in range(0, len(seen_relation_sorted_index), step): curriculum_rel_list.append(seen_relation_sorted_index[i]) # curriculum select relation instance_list = [] for sampled_relation in curriculum_rel_list: if opt.mini_batch_split == 'Y': instance_list.append(rel2instance_memory[sampled_relation]) else: instance_list.extend(rel2instance_memory[sampled_relation]) else: # randomly select relation instance_list = [] random_relation_list = random.sample(list(rel2instance_memory.keys()), min(opt.sampled_rel_num, len(rel2instance_memory))) 
for sampled_relation in random_relation_list: if opt.mini_batch_split == 'Y': instance_list.append(rel2instance_memory[sampled_relation]) else: instance_list.extend(rel2instance_memory[sampled_relation]) if opt.mini_batch_split == 'Y': for one_batch_instance in instance_list: scores, loss = feed_samples(inner_model, one_batch_instance, loss_function, relation_numbers, device, all_seen_relations) optimizer.step() else: scores, loss = feed_samples(inner_model, instance_list, loss_function, relation_numbers, device, all_seen_relations) optimizer.step() scores, loss = feed_samples(inner_model, batch_train_data, loss_function, relation_numbers, device, all_seen_relations) optimizer.step() total_loss += loss # valid test valid_acc = evaluate_model(inner_model, current_valid_data, opt.batch_size, relation_numbers, device) # checkpoint checkpoint = {'net_state': inner_model.state_dict(), 'optimizer': optimizer.state_dict()} if valid_acc > best_valid_acc: best_checkpoint = '%s/checkpoint_task%d_epoch%d.pth.tar' % (checkpoint_dir, task_ix + 1, epoch) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) torch.save(checkpoint, best_checkpoint) best_valid_acc = valid_acc early_stop = 0 else: early_stop += 1 # print() t.set_description('Task %i Epoch %i' % (task_ix + 1, epoch + 1)) t.set_postfix(loss=total_loss.item(), valid_acc=valid_acc, early_stop=early_stop, best_checkpoint=best_checkpoint) t.update(1) if early_stop >= opt.early_stop and task_ix != 0: # convergence break if task_ix == 0 and early_stop >= 40: break t.close() print('Load best check point from %s' % best_checkpoint) checkpoint = torch.load(best_checkpoint) weights_after = checkpoint['net_state'] if opt.outer_step_formula == 'fixed': outer_step_size = opt.step_size elif opt.outer_step_formula == 'linear': outer_step_size = opt.step_size * (1 - task_ix / opt.task_num) elif opt.outer_step_formula == 'square_root': outer_step_size = math.sqrt(opt.step_size * (1 - task_ix / opt.task_num)) # outer_step_size = 0.4 inner_model.load_state_dict( {name: weights_before[name] + (weights_after[name] - weights_before[name]) * outer_step_size for name in weights_before}) results = [evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data] # sample memory from current_train_data if opt.memory_select_method == 'select_for_relation': # sample instance for one relation for rel in train_relations: rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset) rel_memo = select_data(inner_model, rel_items, int(opt.sampled_instance_num), relation_numbers, opt.batch_size, device) rel2instance_memory[rel] = rel_memo if opt.memory_select_method == 'select_for_task': # sample instance for one Task rel_instance_num = math.ceil(opt.sampled_instance_num_total / len(train_relations)) for rel in train_relations: rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset) rel_memo = select_data(inner_model, rel_items, rel_instance_num, relation_numbers, opt.batch_size, device) rel2instance_memory[rel] = rel_memo if opt.task_memory_size > 0: # sample memory from current_train_data if opt.memory_select_method == 'random': memory_data.append(random_select_data(current_train_data, int(opt.task_memory_size))) elif opt.memory_select_method == 'vec_cluster': selected_memo = select_data(inner_model, current_train_data, int(opt.task_memory_size), relation_numbers, opt.batch_size, device) memory_data.append(selected_memo) # memorydata-list 
memory_pool.extend(selected_memo) elif opt.memory_select_method == 'difficulty': memory_data.append() print_list(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([results[i] * test_set_size[i] for i in range(len(current_test_data))]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result)) print('test_all:') result_total_for_avg = [] result_total_for_whole = [] for epoch in range(10): current_test_data = [] for previous_task_id in range(opt.task_num): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset)) loss_function = nn.MarginRankingLoss(opt.loss_margin) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) optimizer.zero_grad() for one_batch_memory in memory_data: scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device, all_seen_relations) optimizer.step() results = [evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data] print(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([results[i] * test_set_size[i] for i in range(len(current_test_data))]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result)) result_total_for_avg.append(results) result_total_for_whole.append(whole_result) # clean saved parameters files = os.listdir(checkpoint_dir) for weigths_file in files: os.remove(os.path.join(checkpoint_dir, weigths_file)) os.removedirs(checkpoint_dir) # ----------------------------------------------------------- # 输出结果 avg_total = np.mean(np.array(result_total_for_avg), 1) avg_mean, avg_interval = interval(avg_total) whole_mean, whole_interval = interval(np.array(result_total_for_whole)) result_total = {"avg_acc": result_total_for_avg, "whole_acc": result_total_for_whole, "avg_mean": avg_mean, "avg_interval": avg_interval.tolist(), "whole_mean": whole_mean, "whole_interval": whole_interval.tolist()} print(result_total) with open(opt.result_file, "w") as file_out: json.dump(result_total, file_out)
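# Assumed implementation of the `interval` helper used when the main() above aggregates
# the ten evaluation runs into avg_mean/avg_interval and whole_mean/whole_interval: the
# mean plus a 95% t-distribution confidence half-width. This is a sketch of a plausible
# definition (using scipy.stats), not necessarily the repository's exact code.
import numpy as np
from scipy import stats

def interval(values, confidence=0.95):
    values = np.asarray(values, dtype=np.float64)
    mean = values.mean(axis=0)
    sem = stats.sem(values, axis=0)  # standard error of the mean
    half_width = sem * stats.t.ppf((1 + confidence) / 2, len(values) - 1)
    return mean, half_width

mean, half = interval([0.71, 0.69, 0.73, 0.70, 0.72])
print('%.3f +/- %.3f' % (mean, half))  # e.g. 0.710 +/- 0.020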
def main(opt): print(opt) # print('线性outer step formula,0.6 step size, 每task聚类取50个memo') random.seed(opt.random_seed) torch.manual_seed(opt.random_seed) np.random.seed(opt.random_seed) np.random.RandomState(opt.random_seed) start_time = time.time() checkpoint_dir = os.path.join(opt.checkpoint_dir, '%.f' % start_time) device = torch.device(( 'cuda:%d' % opt.cuda_id ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu') # do following process split_train_data, train_data_dict, split_test_data, split_valid_data, relation_numbers, rel_features, \ split_train_relations, vocabulary, embedding = \ load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file, opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num, opt.train_instance_num, opt.dataset) print('\n'.join([ 'Task %d\t%s' % (index, ', '.join(['%d' % rel for rel in split_train_relations[index]])) for index in range(len(split_train_relations)) ])) # offset tasks # split_train_data = offset_list(split_train_data, opt.task_offset) # split_test_data = offset_list(split_test_data, opt.task_offset) # split_valid_data = offset_list(split_valid_data, opt.task_offset) # task_sq = [None] * len(split_train_relations) # for i in range(len(split_train_relations)): # task_sq[(i + opt.task_offset) % len(split_train_relations)] = i # print('[%s]' % ', '.join(['Task %d' % i for i in task_sq])) # insert 6th-task # task_index = [[6, 0, 1, 2, 3, 4, 5, 7, 8, 9], # [0, 6, 1, 2, 3, 4, 5, 7, 8, 9], # [0, 1, 6, 2, 3, 4, 5, 7, 8, 9], # [0, 1, 2, 6, 3, 4, 5, 7, 8, 9], # [0, 1, 2, 3, 6, 4, 5, 7, 8, 9], # [0, 1, 2, 3, 4, 6, 5, 7, 8, 9], # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # [0, 1, 2, 3, 4, 5, 7, 6, 8, 9], # [0, 1, 2, 3, 4, 5, 7, 8, 6, 9], # [0, 1, 2, 3, 4, 5, 7, 8, 9, 6]] # task_sequence = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # [9, 0, 1, 2, 3, 4, 5, 6, 7, 8], # [8, 9, 0, 1, 2, 3, 4, 5, 6, 7], # [7, 8, 9, 0, 1, 2, 3, 4, 5, 6], # [6, 7, 8, 9, 0, 1, 2, 3, 4, 5], # [5, 6, 7, 8, 9, 0, 1, 2, 3, 4], # [4, 5, 6, 7, 8, 9, 0, 1, 2, 3], # [3, 4, 5, 6, 7, 8, 9, 0, 1, 2], # [2, 3, 4, 5, 6, 7, 8, 9, 0, 1], # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]] task_sequence = list(range(opt.task_num)) if opt.random_idx: for i in range(opt.random_times): random.shuffle(task_sequence) offset_seq = task_sequence[ -opt.sequence_index:] + task_sequence[:-opt.sequence_index] # task_sequence = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # [9, 0, 1, 2, 3, 4, 5, 6, 7, 8], # [8, 9, 0, 1, 2, 3, 4, 5, 6, 7], # [7, 8, 9, 0, 1, 2, 3, 4, 5, 6], # [6, 7, 8, 9, 0, 1, 2, 3, 4, 5], # [5, 6, 7, 8, 9, 0, 1, 2, 3, 4], # [4, 5, 6, 7, 8, 9, 0, 1, 2, 3], # [3, 4, 5, 6, 7, 8, 9, 0, 1, 2], # [2, 3, 4, 5, 6, 7, 8, 9, 0, 1], # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]] # # if opt.random_idx: # random_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # for i in range(opt.random_times): # random.shuffle(random_idx) # print(str(random_idx)) # # for i in range(len(task_sequence)): # new_sq = [0] * len(task_sequence[i]) # for j in range(len(task_sequence[i])): # new_sq[j] = random_idx[task_sequence[i][j]] # task_sequence[i] = new_sq split_train_data = resort_list(split_train_data, offset_seq) split_test_data = resort_list(split_test_data, offset_seq) split_valid_data = resort_list(split_valid_data, offset_seq) split_train_relations = resort_list(split_train_relations, offset_seq) print('[%s]' % ', '.join(['Task %d' % idx for idx in offset_seq])) relid2embedidx = {} embedidx2relid = {} if opt.similarity == 'kl_similarity': kl_dist_ht = read_json(opt.kl_dist_file) # tmp = [[0, 1, 2, 3], [1, 0, 4, 6], [2, 4, 0, 5], [3, 6, 5, 0]] 
sorted_similarity_index = np.argsort(np.asarray(kl_dist_ht), axis=1) + 1 elif opt.similarity == 'glove_similarity': glove_embedding = [] embed_id = 0 for rel_id in rel_features: glove_embedding.append(rel_features[rel_id]) relid2embedidx[rel_id] = embed_id embedidx2relid[embed_id] = rel_id embed_id += 1 glove_similarity = cosine_similarity(np.asarray(glove_embedding)) glove_dist = np.sqrt(1 - np.power( np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2)) sorted_embed_index = np.argsort(np.asarray(glove_dist), axis=1) sorted_similarity_index = np.zeros(sorted_embed_index.shape) for i in range(sorted_embed_index.shape[0]): for j in range(sorted_embed_index.shape[1]): sorted_similarity_index[i][j] = embedidx2relid[ sorted_embed_index[i][j]] # print() # for i in range(1, len(rel_features) + 1): # rel_embed = rel_features[i] # 1-80 # glove_embedding.append(rel_embed) # index 0-79, total 80 # # glove_similarity = cosine_similarity(np.asarray(glove_embedding)) # glove_dist = np.sqrt(1 - np.power(np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2)) # sorted_similarity_index = np.argsort(np.asarray(glove_dist), axis=1) + 1 else: raise Exception('similarity method not implemented') # prepare model inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim, len(vocabulary), np.array(embedding), 1, device) memory_data = [] memory_pool = [] memory_question_embed = [] memory_relation_embed = [] sequence_results = [] result_whole_test = [] seen_relations = [] all_seen_relations = [] rel2instance_memory = {} memory_index = 0 seen_task_relations = [] rel_embeddings = [] for task_ix in range(opt.task_num): # outside loop # reptile start model parameters pi weights_before = deepcopy(inner_model.state_dict()) train_task = split_train_data[task_ix] test_task = split_test_data[task_ix] valid_task = split_valid_data[task_ix] train_relations = split_train_relations[task_ix] seen_task_relations.append(train_relations) # collect seen relations for data_item in train_task: if data_item[0] not in seen_relations: seen_relations.append(data_item[0]) # remove unseen relations current_train_data = remove_unseen_relation(train_task, seen_relations, dataset=opt.dataset) current_valid_data = remove_unseen_relation(valid_task, seen_relations, dataset=opt.dataset) current_test_data = [] for previous_task_id in range(task_ix + 1): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset)) for this_sample in current_train_data: if this_sample[0] not in all_seen_relations: all_seen_relations.append(this_sample[0]) update_rel_cands(memory_data, all_seen_relations, opt.num_cands) # train inner_model loss_function = nn.MarginRankingLoss(opt.loss_margin) inner_model = inner_model.to(device) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) t = tqdm(range(opt.outside_epoch)) best_valid_acc = 0.0 early_stop = 0 best_checkpoint = '' # resorted_memory_pool = [] for epoch in t: batch_num = (len(current_train_data) - 1) // opt.batch_size + 1 total_loss = 0.0 target_rel = -1 for batch in range(batch_num): batch_train_data = current_train_data[batch * opt.batch_size:(batch + 1) * opt.batch_size] if len(memory_data) > 0: # curriculum select and organize memory if target_rel == -1 or len(resorted_memory_pool) == 0: target_rel = batch_train_data[0][0] if opt.similarity == 'kl_similarity': target_rel_sorted_index = sorted_similarity_index[ target_rel - 1] else: target_rel_sorted_index = sorted_similarity_index[ relid2embedidx[target_rel]] 
resorted_memory_pool = resort_memory( memory_pool, target_rel_sorted_index) if len(resorted_memory_pool) >= opt.task_memory_size: current_memory = resorted_memory_pool[:opt. task_memory_size] resorted_memory_pool = resorted_memory_pool[ opt.task_memory_size + 1:] # 更新剩余的memory batch_train_data.extend(current_memory) else: current_memory = resorted_memory_pool[:] resorted_memory_pool = [] # 更新剩余的memory batch_train_data.extend(current_memory) # 淘汰的做法 # if len(resorted_memory_pool) != 0: # current_memory = resorted_memory_pool[:opt.task_memory_size] # resorted_memory_pool = resorted_memory_pool[opt.task_memory_size + 1:] # 更新剩余的memory # batch_train_data.extend(current_memory) # else: # target_rel = batch_train_data[0][0] # target_rel_sorted_index = sorted_similarity_index[target_rel - 1] # resorted_memory_pool = resort_memory(memory_pool, target_rel_sorted_index) # MLLRE的做法 # all_seen_data = [] # for one_batch_memory in memory_data: # all_seen_data += one_batch_memory # # memory_batch = memory_data[memory_index] # batch_train_data.extend(memory_batch) # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device) # optimizer.step() # memory_index = (memory_index+1) % len(memory_data) # random.shuffle(batch_train_data) if len(rel2instance_memory ) > 0: # from the second task, this will not be empty if opt.is_curriculum_train == 'Y': current_train_rel = batch_train_data[0][0] current_rel_similarity_sorted_index = sorted_similarity_index[ current_train_rel + 1] seen_relation_sorted_index = [] for rel in current_rel_similarity_sorted_index: if rel in rel2instance_memory.keys(): seen_relation_sorted_index.append(rel) curriculum_rel_list = [] if opt.sampled_rel_num >= len( seen_relation_sorted_index): curriculum_rel_list = seen_relation_sorted_index[:] else: step = len(seen_relation_sorted_index ) // opt.sampled_rel_num for i in range(0, len(seen_relation_sorted_index), step): curriculum_rel_list.append( seen_relation_sorted_index[i]) # curriculum select relation instance_list = [] for sampled_relation in curriculum_rel_list: if opt.mini_batch_split == 'Y': instance_list.append( rel2instance_memory[sampled_relation]) else: instance_list.extend( rel2instance_memory[sampled_relation]) else: # randomly select relation instance_list = [] random_relation_list = random.sample( list(rel2instance_memory.keys()), min(opt.sampled_rel_num, len(rel2instance_memory))) for sampled_relation in random_relation_list: if opt.mini_batch_split == 'Y': instance_list.append( rel2instance_memory[sampled_relation]) else: instance_list.extend( rel2instance_memory[sampled_relation]) if opt.mini_batch_split == 'Y': for one_batch_instance in instance_list: # curriculum_instance_list = remove_unseen_relation(curriculum_instance_list, seen_relations) scores, loss = feed_samples( inner_model, one_batch_instance, loss_function, relation_numbers, device) optimizer.step() else: # curriculum_instance_list = remove_unseen_relation(curriculum_instance_list, seen_relations) scores, loss = feed_samples(inner_model, instance_list, loss_function, relation_numbers, device) optimizer.step() scores, loss = feed_samples(inner_model, batch_train_data, loss_function, relation_numbers, device) optimizer.step() total_loss += loss # valid test valid_acc = evaluate_model(inner_model, current_valid_data, opt.batch_size, relation_numbers, device) # checkpoint checkpoint = { 'net_state': inner_model.state_dict(), 'optimizer': optimizer.state_dict() } if valid_acc > best_valid_acc: best_checkpoint = 
'%s/checkpoint_task%d_epoch%d.pth.tar' % ( checkpoint_dir, task_ix + 1, epoch) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) torch.save(checkpoint, best_checkpoint) best_valid_acc = valid_acc early_stop = 0 else: early_stop += 1 # print() t.set_description('Task %i Epoch %i' % (task_ix + 1, epoch + 1)) t.set_postfix(loss=total_loss.item(), valid_acc=valid_acc, early_stop=early_stop, best_checkpoint=best_checkpoint) t.update(1) if early_stop >= opt.early_stop and task_ix != 0: # 已经充分训练了 break if task_ix == 0 and early_stop >= 40: # 防止大数据量的task在第一轮得不到充分训练 break t.close() print('Load best check point from %s' % best_checkpoint) checkpoint = torch.load(best_checkpoint) weights_after = checkpoint['net_state'] # weights_after = inner_model.state_dict() # 经过inner_epoch轮次的梯度更新后weights # outer_step_size = opt.step_size * (1 - task_index / opt.task_num) # linear schedule # if outer_step_size < opt.step_size * 0.5: # outer_step_size = opt.step_size * 0.5 if opt.outer_step_formula == 'fixed': outer_step_size = opt.step_size elif opt.outer_step_formula == 'linear': outer_step_size = opt.step_size * (1 - task_ix / opt.task_num) elif opt.outer_step_formula == 'square_root': outer_step_size = math.sqrt(opt.step_size * (1 - task_ix / opt.task_num)) # outer_step_size = 0.4 inner_model.load_state_dict({ name: weights_before[name] + (weights_after[name] - weights_before[name]) * outer_step_size for name in weights_before }) # 用memory进行训练: # for i in range(5): # for one_batch_memory in memory_data: # scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device) # optimizer.step() results = [ evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data ] # 使用current model和alignment model对test data进行一个预测 # sample memory from current_train_data if opt.memory_select_method == 'select_for_relation': # 每个关系sample k个 for rel in train_relations: rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset) rel_memo = select_data(inner_model, rel_items, int(opt.sampled_instance_num), relation_numbers, opt.batch_size, device) rel2instance_memory[rel] = rel_memo if opt.memory_select_method == 'select_for_task': # 为每个task sample k个 rel_instance_num = math.ceil(opt.sampled_instance_num_total / len(train_relations)) for rel in train_relations: rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset) rel_memo = select_data(inner_model, rel_items, rel_instance_num, relation_numbers, opt.batch_size, device) rel2instance_memory[rel] = rel_memo if opt.task_memory_size > 0: # sample memory from current_train_data if opt.memory_select_method == 'random': memory_data.append( random_select_data(current_train_data, int(opt.task_memory_size))) elif opt.memory_select_method == 'vec_cluster': selected_memo = select_data(inner_model, current_train_data, int(opt.task_memory_size), relation_numbers, opt.batch_size, device) memory_data.append( selected_memo ) # memorydata是一个list,list中的每个元素都是一个包含selected_num个sample的list memory_pool.extend(selected_memo) elif opt.memory_select_method == 'difficulty': memory_data.append() print_list(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([ results[i] * test_set_size[i] for i in range(len(current_test_data)) ]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, 
whole_acc: %.3f' % (avg_result, whole_result)) # end of each task, get embeddings of all # if len(all_seen_relations) > 1: # rel_embed = tsne_relations(inner_model, seen_task_relations, relation_numbers, device, task_sequence[opt.sequence_index]) # rel_embeddings.append(rel_embed) print('test_all:') for epoch in range(10): current_test_data = [] for previous_task_id in range(opt.task_num): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset)) loss_function = nn.MarginRankingLoss(opt.loss_margin) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) optimizer.zero_grad() for one_batch_memory in memory_data: scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device) optimizer.step() results = [ evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data ] print(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([ results[i] * test_set_size[i] for i in range(len(current_test_data)) ]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))
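# Minimal sketch of the Reptile outer update performed after each task in the main()
# above: theta <- theta_before + epsilon * (theta_after - theta_before), with epsilon
# following the 'fixed', 'linear' or 'square_root' schedule selected by
# --outer_step_formula. Tensors and hyper-parameters below are toy values for illustration.
import math
import torch

def reptile_outer_update(weights_before, weights_after, step):
    return {name: weights_before[name]
                  + (weights_after[name] - weights_before[name]) * step
            for name in weights_before}

def outer_step_size(formula, step_size, task_ix, task_num):
    if formula == 'fixed':
        return step_size
    if formula == 'linear':
        return step_size * (1 - task_ix / task_num)
    if formula == 'square_root':
        return math.sqrt(step_size * (1 - task_ix / task_num))
    raise ValueError('unknown outer step formula: %s' % formula)

before = {'w': torch.zeros(2)}
after = {'w': torch.ones(2)}
print(reptile_outer_update(before, after,
                           outer_step_size('linear', 0.6, task_ix=2, task_num=10)))
# {'w': tensor([0.4800, 0.4800])}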
def main(): parser = argparse.ArgumentParser() parser.add_argument('--cuda_id', default=0, type=int, help='cuda device index, -1 means use cpu') parser.add_argument('--train_file', default='dataset/training_data.txt', help='train file') parser.add_argument('--valid_file', default='dataset/val_data.txt', help='valid file') parser.add_argument('--test_file', default='dataset/val_data.txt', help='test file') parser.add_argument('--relation_file', default='dataset/relation_name.txt', help='relation name file') parser.add_argument('--glove_file', default='dataset/glove.6B.50d.txt', help='glove embedding file') parser.add_argument('--embedding_dim', default=50, type=int, help='word embeddings dimensional') parser.add_argument('--hidden_dim', default=200, type=int, help='BiLSTM hidden dimensional') parser.add_argument( '--task_arrange', default='random', help='task arrangement method, e.g. cluster_by_glove_embedding, random' ) parser.add_argument('--rel_encode', default='glove', help='relation encode method') parser.add_argument( '--meta_method', default='reptile', help='meta learning method, maml and reptile can be choose') parser.add_argument('--batch_size', default=50, type=float, help='Reptile inner loop batch size') parser.add_argument('--task_num', default=10, type=int, help='number of tasks') parser.add_argument( '--train_instance_num', default=200, type=int, help='number of instances for one relation, -1 means all.') parser.add_argument('--loss_margin', default=0.5, type=float, help='loss margin setting') parser.add_argument('--outside_epoch', default=200, type=float, help='task level epoch') parser.add_argument('--early_stop', default=20, type=float, help='task level epoch') parser.add_argument('--step_size', default=0.4, type=float, help='step size Epsilon') parser.add_argument('--learning_rate', default=2e-2, type=float, help='learning rate') parser.add_argument('--random_seed', default=317, type=int, help='random seed') parser.add_argument('--task_memory_size', default=50, type=int, help='number of samples for each task') parser.add_argument( '--memory_select_method', default='vec_cluster', help= 'the method of sample memory data, e.g. 
vec_cluster, random, difficulty' ) opt = parser.parse_args() print(opt) random.seed(opt.random_seed) torch.manual_seed(opt.random_seed) np.random.seed(opt.random_seed) np.random.RandomState(opt.random_seed) device = torch.device(( 'cuda:%d' % opt.cuda_id ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu') # do following process split_train_data, train_data_dict, split_test_data, test_data_dict, split_valid_data, valid_data_dict, \ relation_numbers, rel_features, vocabulary, embedding = \ load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file, opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num, opt.train_instance_num) # prepare model inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim, len(vocabulary), np.array(embedding), 1, device) memory_data = [] # B seen_relations = [] for task_index in range(opt.task_num): # outside loop # reptile start model parameters pi weights_before = deepcopy(inner_model.state_dict()) train_task = split_train_data[task_index] # test_task = split_test_data[task_index] valid_task = split_valid_data[task_index] # collect seen relations for data_item in train_task: if data_item[0] not in seen_relations: seen_relations.append(data_item[0]) # remove unseen relations current_train_data = remove_unseen_relation(train_task, seen_relations) current_valid_data = remove_unseen_relation(valid_task, seen_relations) current_test_data = [] for previous_task_id in range(task_index + 1): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations)) # train inner_model loss_function = nn.MarginRankingLoss(opt.loss_margin) inner_model = inner_model.to(device) t = tqdm(range(opt.outside_epoch)) best_valid_acc = 0.0 early_stop = 0 best_checkpoint = '' for epoch in t: weights_task = deepcopy(weights_before) # optimizer.zero_grad() # inner_model.load_state_dict({name: weights_before[name] for name in weights_before}) batch_num = (len(current_train_data) - 1) // opt.train_instance_num + 1 total_loss = 0.0 weights_list = [None] * (batch_num + len(memory_data)) for batch in range(batch_num): # one relation's train data batch_train_data = current_train_data[batch * opt.train_instance_num: (batch + 1) * opt.train_instance_num] inner_model.load_state_dict(weights_task) optimizer = optim.SGD(inner_model.parameters(), lr=opt.learning_rate) optimizer.zero_grad() scores, loss = feed_samples(inner_model, batch_train_data, loss_function, relation_numbers, device) loss.backward() # 计算反向传播梯度 # 更新参数 # for f in inner_model.parameters(): # f.data.sub_(f.grad.data * opt.learning_rate) optimizer.step() # 更新参数 total_loss += loss weights_list[batch] = deepcopy( inner_model.state_dict()) # 保存theta_t^i if len(memory_data) > 0: for i in range(len(memory_data)): one_batch_memory = memory_data[i] inner_model.load_state_dict(weights_task) optimizer = optim.SGD(inner_model.parameters(), lr=opt.learning_rate) optimizer.zero_grad() scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device) loss.backward() # 更新参数 # for f in inner_model.parameters(): # f.data.sub_(f.grad.data * opt.learning_rate) optimizer.step() total_loss += loss weights_list[batch_num + i] = deepcopy( inner_model.state_dict()) outer_step_size = opt.step_size * (1 / len(weights_list)) for name in weights_before: weights_task[ name] = weights_before[name] - outer_step_size * sum([ weights[name] - weights_before[name] for weights in weights_list ]) # load state dict of weights_after 
inner_model.load_state_dict(weights_task) # weights_before = deepcopy(inner_model.state_dict()) del weights_list valid_acc = evaluate_model(inner_model, current_valid_data, opt.batch_size, relation_numbers, device) checkpoint = {'net_state': inner_model.state_dict()} if valid_acc > best_valid_acc: best_checkpoint = './checkpoint/checkpoint_task%d_epoch%d.pth.tar' % ( task_index, epoch) torch.save(checkpoint, best_checkpoint) best_valid_acc = valid_acc early_stop = 0 else: early_stop += 1 # print() t.set_description('Task %i Epoch %i' % (task_index + 1, epoch + 1)) t.set_postfix(loss=total_loss.item(), valid_acc=valid_acc, early_stop=early_stop, best_checkpoint=best_checkpoint) t.update(1) if early_stop >= opt.early_stop: # 已经充分训练了 break t.close() # sample memory from current_train_data if opt.memory_select_method == 'random': memory_data.append( random_select_data(current_train_data, opt.task_memory_size)) elif opt.memory_select_method == 'vec_cluster': memory_data.append( select_data(inner_model, current_train_data, opt.task_memory_size, relation_numbers, opt.batch_size, device) ) # memorydata是一个list,list中的每个元素都是一个包含selected_num个sample的list elif opt.memory_select_method == 'difficulty': memory_data.append() results = [ evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data ] # 使用current model和alignment model对test data进行一个预测 print(results)
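# Sketch of the batched variant used in the main() above: the inner model is reset to the
# task weights for every batch (and every memory chunk), adapted with SGD, and the task
# weights are then moved by the combined parameter deltas. Canonical Reptile moves
# *toward* the mean of the adapted weights (a plus sign on the averaged delta); the code
# above subtracts the summed delta, which is worth double-checking against the intended
# update rule. The helper below shows the canonical form with toy values.
import torch

def batched_reptile_update(weights_before, adapted_weights_list, step_size):
    scale = step_size / len(adapted_weights_list)
    return {name: weights_before[name]
                  + scale * sum(w[name] - weights_before[name]
                                for w in adapted_weights_list)
            for name in weights_before}

before = {'w': torch.zeros(3)}
adapted = [{'w': torch.ones(3)}, {'w': 2 * torch.ones(3)}]
print(batched_reptile_update(before, adapted, step_size=0.4))
# {'w': tensor([0.6000, 0.6000, 0.6000])}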
transform_list = transforms.Compose(transform_list)
dset_train = MultiViewDataSet(args.data, transform=transform_list)
train_loader = DataLoader(dset_train, batch_size=args.batch_size, shuffle=True,
                          num_workers=8, pin_memory=True)
classes = dset_train.classes
# print(len(classes))

model = Model(class_num=len(classes), stride=2)
model = model.cuda()
cudnn.benchmark = True

sim_model = SimilarityModel()
sim_model = sim_model.cuda()
classifier = model.classifier.classifier
# logger = Logger('logs')

# Loss and Optimizer
lr = args.lr
n_epochs = args.epochs
criterion = nn.CrossEntropyLoss()
similarity_criterion = SimilarityLoss()
ignored_params = list(map(id, model.classifier.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.SGD([{'params': base_params,
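# The SGD construction above is truncated in this dump. A generic example of the
# per-parameter-group pattern it starts (separate learning rates for the backbone and the
# classifier head) is shown below; the module layout and learning rates are illustrative
# only and do not come from this repository.
import torch
import torch.nn as nn

toy_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
head_params = list(toy_model[1].parameters())
head_param_ids = set(map(id, head_params))
base_params_toy = [p for p in toy_model.parameters() if id(p) not in head_param_ids]

toy_optimizer = torch.optim.SGD(
    [{'params': base_params_toy, 'lr': 0.01},  # backbone: smaller learning rate
     {'params': head_params, 'lr': 0.1}],      # freshly initialised head: larger rate
    momentum=0.9, weight_decay=5e-4)
print([group['lr'] for group in toy_optimizer.param_groups])  # [0.01, 0.1]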
def train_memory(training_data, valid_data, vocabulary, embedding_dim, hidden_dim, device, batch_size, lr, model_path, embedding, all_relations, model=None, epoch=100, memory_data=[], loss_margin=0.5, past_fisher=None, rel_samples=[], relation_frequences=[], rel_embeds=None, rel_ques_cand=None, rel_acc_diff=None, all_seen_rels=None, update_rel_embed=None, reverse_model=None, memory_que_embed=[], memory_rel_embed=[], to_update_reverse=False): if model is None: torch.manual_seed(100) model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary), np.array(embedding), 1, device) loss_function = nn.MarginRankingLoss(loss_margin) model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=lr) reverse_optimiser = torch.optim.Adam(reverse_model.parameters(), lr=lr) best_acc = 0 #acc_pre=evaluate_model(model, valid_data, batch_size, all_relations, device) given_pro = None memory_index = 0 memory_data_grads = [] for epoch_i in range(epoch): #print('epoch', epoch_i) #training_data = training_data[0:100] #for i in range((len(training_data)-1)//batch_size+1): for i, samples in enumerate(memory_data): #samples = training_data[i*batch_size:(i+1)*batch_size] seed_rels = [] for item in samples: if item[0] not in seed_rels: seed_rels.append(item[0]) #start_time = time.time() ''' if len(memory_data) > 0: all_seen_data = [] for this_memory in memory_data: all_seen_data+=this_memory memory_batch = memory_data[memory_index] #memory_batch = random.sample(all_seen_data, # min(batch_size, len(all_seen_data))) #print(memory_data) scores, loss = feed_samples(model, memory_batch, loss_function, all_relations, device, reverse_model) if to_update_reverse: reverse_optimiser.step() else: optimizer.step() memory_index = (memory_index+1)%len(memory_data) ''' scores, loss = feed_samples(model, samples, loss_function, all_relations, device, reverse_model) #memory_que_embed[i], memory_rel_embed[i]) #end_time = time.time() #print('forward time:', end_time - start_time) sample_grad = copy_grad_data(model) if to_update_reverse: reverse_optimiser.step() else: optimizer.step() del scores del loss ''' acc=evaluate_model(model, valid_data, batch_size, all_relations, device) if acc > best_acc: torch.save(model, model_path) best_model = torch.load(model_path) return best_model ''' #acc_aft=evaluate_model(model, valid_data, batch_size, all_relations, device) #return model, max(0, acc_aft-acc_pre) if to_update_reverse: return reverse_model, 0 else: return model, 0
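# Assumed behaviour of the copy_grad_data helper called in train_memory above: collect the
# gradients of all trainable parameters into one flat vector (the representation the
# GEM-style projection operates on). Sketch only; the repository's helper may differ in
# detail.
import torch
import torch.nn as nn

def copy_grad_data_sketch(model: nn.Module) -> torch.Tensor:
    grads = []
    for param in model.parameters():
        if param.requires_grad:
            grad = param.grad if param.grad is not None else torch.zeros_like(param)
            grads.append(grad.detach().view(-1))
    return torch.cat(grads)

toy = nn.Linear(4, 2)
toy(torch.randn(3, 4)).sum().backward()
print(copy_grad_data_sketch(toy).shape)  # torch.Size([10]): 4*2 weights + 2 biases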
def train(training_data, valid_data, vocabulary, embedding_dim, hidden_dim, device, batch_size, lr, model_path, embedding, all_relations, model=None, epoch=100, memory_data=[], loss_margin=0.5, past_fisher=None, rel_samples=[], relation_frequences=[], rel_embeds=None, rel_ques_cand=None, rel_acc_diff=None, all_seen_rels=None, update_rel_embed=None, reverse_model=None, memory_que_embed=[], memory_rel_embed=[], to_update_reverse=False): if model is None: torch.manual_seed(100) model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary), np.array(embedding), 1, device) loss_function = nn.MarginRankingLoss(loss_margin) model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=lr) reverse_optimiser = torch.optim.Adam(reverse_model.parameters(), lr=lr) best_acc = 0 #acc_pre=evaluate_model(model, valid_data, batch_size, all_relations, device) given_pro = None memory_index = 0 memory_data_grads = [] for epoch_i in range(epoch): #print('epoch', epoch_i) #training_data = training_data[0:100] for i in range((len(training_data) - 1) // batch_size + 1): samples = training_data[i * batch_size:(i + 1) * batch_size] seed_rels = [] for item in samples: if item[0] not in seed_rels: seed_rels.append(item[0]) ''' for item_cand in item[1]: if item_cand not in seed_rels: seed_rels.append(item_cand) ''' if len(rel_samples) > 0: memory_data = sample_constrains(rel_samples, relation_frequences, rel_embeds, seed_rels, rel_ques_cand, rel_acc_diff, given_pro, all_seen_rels) ''' to_train_mem = memory_data if len(memory_data) > num_constrain: to_train_mem = random.sample(memory_data, num_constrain) memory_data_grads = get_grads_memory_data(model, to_train_mem, loss_function, all_relations, device, reverse_model, memory_que_embed, memory_rel_embed) ''' #print(memory_data_grads) #start_time = time.time() if len(memory_data) > 0: all_seen_data = [] for this_memory in memory_data: all_seen_data += this_memory memory_batch = memory_data[memory_index] #memory_batch = random.sample(all_seen_data, # min(batch_size, len(all_seen_data))) #print(memory_data) scores, loss = feed_samples(model, memory_batch, loss_function, all_relations, device, reverse_model) if to_update_reverse: reverse_optimiser.step() else: optimizer.step() memory_index = (memory_index + 1) % len(memory_data) scores, loss = feed_samples(model, samples, loss_function, all_relations, device, reverse_model) #end_time = time.time() #print('forward time:', end_time - start_time) sample_grad = copy_grad_data(model) if len(memory_data_grads) > 0: #if not check_constrain(memory_data_grads, sample_grad): if True: project2cone2(sample_grad, memory_data_grads) if past_fisher is None: grad_params = get_grad_params(model) grad_dims = [ param.data.numel() for param in grad_params ] overwrite_grad(grad_params, sample_grad, grad_dims) if past_fisher is not None: sample_grad = rescale_grad(sample_grad, past_fisher) grad_params = get_grad_params(model) grad_dims = [param.data.numel() for param in grad_params] overwrite_grad(grad_params, sample_grad, grad_dims) if to_update_reverse: reverse_optimiser.step() else: optimizer.step() #optimizer.step() if (epoch_i % 5 == 0) and len(relation_frequences) > 0 and False: update_rel_embed(model, all_seen_rels, all_relations, rel_embeds) samples = list(relation_frequences.keys()) #return random.sample(samples, min(len(samples), num_samples)) sample_embeds = torch.from_numpy( np.asarray([rel_embeds[i] for i in samples])) #seed_rel_embeds = torch.from_numpy(np.asarray( # [rel_embeds[i] for i in seed_rels])).to(device) 
sample_embeds_np = sample_embeds.cpu().double().numpy() #given_pro = kmeans_pro(sample_embeds_np, samples, num_constrain) if epoch_i % 5 == 0 and False: update_rel_embed(model, all_seen_rels, all_relations, rel_embeds) #update_rel_cands(memory_data, all_seen_rels, rel_embeds) del scores del loss ''' acc=evaluate_model(model, valid_data, batch_size, all_relations, device) if acc > best_acc: torch.save(model, model_path) best_model = torch.load(model_path) return best_model ''' #acc_aft=evaluate_model(model, valid_data, batch_size, all_relations, device) #return model, max(0, acc_aft-acc_pre) if to_update_reverse: return reverse_model, 0 else: return model, 0
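# Assumed shapes of the overwrite_grad and rescale_grad helpers used in the train() above:
# overwrite_grad writes a flat gradient vector back into the per-parameter .grad tensors,
# and rescale_grad down-weights directions that earlier tasks consider important (large
# Fisher values). Both are sketches of plausible implementations, not the repository's code.
import torch
import torch.nn as nn

def overwrite_grad_sketch(grad_params, flat_grad, grad_dims):
    offset = 0
    for param, numel in zip(grad_params, grad_dims):
        param.grad = flat_grad[offset:offset + numel].view_as(param).clone()
        offset += numel

def rescale_grad_sketch(flat_grad, fisher, eps=1e-8):
    return flat_grad / (fisher + eps)

toy = nn.Linear(3, 1)
params = [p for p in toy.parameters() if p.requires_grad]
dims = [p.numel() for p in params]
overwrite_grad_sketch(params, torch.arange(sum(dims), dtype=torch.float32), dims)
print([p.grad for p in params])  # [tensor([[0., 1., 2.]]), tensor([3.])]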
import numpy as np

from model import SimilarityModel

a = SimilarityModel(10, 10, 100, np.random.rand(100, 10), 1, 'cpu')
for param in a.named_parameters():
    print(param)
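# Follow-up to the smoke test above: besides dumping every named parameter tensor, it is
# often more useful to report the total trainable parameter count of a toy
# SimilarityModel instance built with the same constructor arguments.
import numpy as np

from model import SimilarityModel

toy_model = SimilarityModel(10, 10, 100, np.random.rand(100, 10), 1, 'cpu')
print(sum(p.numel() for p in toy_model.parameters() if p.requires_grad),
      'trainable parameters')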
def main(): parser = argparse.ArgumentParser() parser.add_argument('--cuda_id', default=0, type=int, help='cuda device index, -1 means use cpu') parser.add_argument('--train_file', default='dataset/training_data.txt', help='train file') parser.add_argument('--valid_file', default='dataset/val_data.txt', help='valid file') parser.add_argument('--test_file', default='dataset/val_data.txt', help='test file') parser.add_argument('--relation_file', default='dataset/relation_name.txt', help='relation name file') parser.add_argument('--glove_file', default='dataset/glove.6B.300d.txt', help='glove embedding file') parser.add_argument('--kl_dist_file', default='dataset/kl_dist_ht.json', help='glove embedding file') parser.add_argument('--embedding_dim', default=300, type=int, help='word embeddings dimensional') parser.add_argument('--hidden_dim', default=200, type=int, help='BiLSTM hidden dimensional') parser.add_argument( '--task_arrange', default='cluster_by_glove_embedding', help='task arrangement method, e.g. cluster_by_glove_embedding, random' ) parser.add_argument('--rel_encode', default='glove', help='relation encode method') parser.add_argument( '--meta_method', default='reptile', help='meta learning method, maml and reptile can be choose') parser.add_argument('--batch_size', default=50, type=float, help='Reptile inner loop batch size') parser.add_argument('--task_num', default=10, type=int, help='number of tasks') parser.add_argument( '--train_instance_num', default=200, type=int, help='number of instances for one relation, -1 means all.') parser.add_argument('--loss_margin', default=0.5, type=float, help='loss margin setting') parser.add_argument('--outside_epoch', default=200, type=float, help='task level epoch') parser.add_argument('--early_stop', default=20, type=float, help='task level epoch') parser.add_argument('--step_size', default=0.6, type=float, help='step size Epsilon') parser.add_argument('--learning_rate', default=2e-3, type=float, help='learning rate') parser.add_argument('--random_seed', default=226, type=int, help='random seed') parser.add_argument('--task_memory_size', default=50, type=int, help='number of samples for each task') parser.add_argument( '--memory_select_method', default='vec_cluster', help= 'the method of sample memory data, e.g. 
vec_cluster, random, difficulty' ) parser.add_argument( '--curriculum_rel_num', default=3, help= 'curriculum learning relation sampled number for current training relation' ) parser.add_argument( '--curriculum_instance_num', default=5, help= 'curriculum learning instance sampled number for a sampled relation') opt = parser.parse_args() print(opt) random.seed(opt.random_seed) torch.manual_seed(opt.random_seed) np.random.seed(opt.random_seed) np.random.RandomState(opt.random_seed) device = torch.device(( 'cuda:%d' % opt.cuda_id ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu') # do following process split_train_data, train_data_dict, split_test_data, test_data_dict, split_valid_data, valid_data_dict, \ relation_numbers, rel_features, vocabulary, embedding = \ load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file, opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num, opt.train_instance_num) # kl similarity of the joint distribution of head and tail kl_dist_ht = read_json(opt.kl_dist_file) # tmp = [[0, 1, 2, 3], [1, 0, 4, 6], [2, 4, 0, 5], [3, 6, 5, 0]] sorted_sililarity_index = np.argsort(-np.asarray(kl_dist_ht), axis=1) + 1 # prepare model inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim, len(vocabulary), np.array(embedding), 1, device) memory_data = [] memory_question_embed = [] memory_relation_embed = [] sequence_results = [] result_whole_test = [] seen_relations = [] all_seen_relations = [] memory_index = 0 for task_index in range(opt.task_num): # outside loop # reptile start model parameters pi weights_before = deepcopy(inner_model.state_dict()) train_task = split_train_data[task_index] test_task = split_test_data[task_index] valid_task = split_valid_data[task_index] # collect seen relations for data_item in train_task: if data_item[0] not in seen_relations: seen_relations.append(data_item[0]) # remove unseen relations current_train_data = remove_unseen_relation(train_task, seen_relations) current_valid_data = remove_unseen_relation(valid_task, seen_relations) current_test_data = [] for previous_task_id in range(task_index + 1): current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations)) # train inner_model loss_function = nn.MarginRankingLoss(opt.loss_margin) inner_model = inner_model.to(device) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) t = tqdm(range(opt.outside_epoch)) best_valid_acc = 0.0 early_stop = 0 best_checkpoint = '' for epoch in t: batch_num = (len(current_train_data) - 1) // opt.batch_size + 1 total_loss = 0.0 for batch in range(batch_num): batch_train_data = current_train_data[batch * opt.batch_size:(batch + 1) * opt.batch_size] if len(memory_data) > 0: all_seen_data = [] for one_batch_memory in memory_data: all_seen_data += one_batch_memory memory_batch = memory_data[memory_index] batch_train_data.extend(memory_batch) # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device) # optimizer.step() memory_index = (memory_index + 1) % len(memory_data) # random.shuffle(batch_train_data) # curriculum before batch_train if task_index > 0: current_train_rel = batch_train_data[0][0] current_rel_similarity_sorted_index = sorted_sililarity_index[ current_train_rel + 1] seen_relation_sorted_index = [] for rel in current_rel_similarity_sorted_index: if rel in seen_relations: seen_relation_sorted_index.append(rel) curriculum_rel_list = [] if opt.curriculum_rel_num >= len( seen_relation_sorted_index): 
curriculum_rel_list = seen_relation_sorted_index[:] else: step = len(seen_relation_sorted_index ) // opt.curriculum_rel_num for i in range(0, len(seen_relation_sorted_index), step): curriculum_rel_list.append( seen_relation_sorted_index[i]) curriculum_instance_list = [] for curriculum_rel in curriculum_rel_list: curriculum_instance_list.extend( random.sample(train_data_dict[curriculum_rel], opt.curriculum_instance_num)) curriculum_instance_list = remove_unseen_relation( curriculum_instance_list, seen_relations) # optimizer.zero_grad() scores, loss = feed_samples(inner_model, curriculum_instance_list, loss_function, relation_numbers, device) # loss.backward() optimizer.step() scores, loss = feed_samples(inner_model, batch_train_data, loss_function, relation_numbers, device) optimizer.step() total_loss += loss # valid test valid_acc = evaluate_model(inner_model, current_valid_data, opt.batch_size, relation_numbers, device) # checkpoint checkpoint = { 'net_state': inner_model.state_dict(), 'optimizer': optimizer.state_dict() } if valid_acc > best_valid_acc: best_checkpoint = './checkpoint/checkpoint_task%d_epoch%d.pth.tar' % ( task_index, epoch) torch.save(checkpoint, best_checkpoint) best_valid_acc = valid_acc early_stop = 0 else: early_stop += 1 # print() t.set_description('Task %i Epoch %i' % (task_index + 1, epoch + 1)) t.set_postfix(loss=total_loss.item(), valid_acc=valid_acc, early_stop=early_stop, best_checkpoint=best_checkpoint) t.update(1) if early_stop >= opt.early_stop: # 已经充分训练了 break t.close() print('Load best check point from %s' % best_checkpoint) checkpoint = torch.load(best_checkpoint) weights_after = checkpoint['net_state'] # weights_after = inner_model.state_dict() # 经过inner_epoch轮次的梯度更新后weights # if task_index == opt.task_num - 1: # outer_step_size = opt.step_size * (1 - 5 / opt.task_num) # else: outer_step_size = math.sqrt(opt.step_size * (1 - task_index / opt.task_num)) # outer_step_size = opt.step_size / opt.task_num # outer_step_size = 0.4 inner_model.load_state_dict({ name: weights_before[name] + (weights_after[name] - weights_before[name]) * outer_step_size for name in weights_before }) # 用memory进行训练: # for i in range(5): # for one_batch_memory in memory_data: # scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device) # optimizer.step() results = [ evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data ] # 使用current model和alignment model对test data进行一个预测 # sample memory from current_train_data if opt.memory_select_method == 'random': memory_data.append( random_select_data(current_train_data, int(opt.task_memory_size / results[-1]))) elif opt.memory_select_method == 'vec_cluster': memory_data.append( select_data(inner_model, current_train_data, int(opt.task_memory_size / results[-1]), relation_numbers, opt.batch_size, device) ) # memorydata是一个list,list中的每个元素都是一个包含selected_num个sample的list elif opt.memory_select_method == 'difficulty': memory_data.append() print_list(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([ results[i] * test_set_size[i] for i in range(len(current_test_data)) ]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result)) print('test_all:') for epoch in range(10): current_test_data = [] for previous_task_id in range(opt.task_num): 
current_test_data.append( remove_unseen_relation(split_test_data[previous_task_id], seen_relations)) loss_function = nn.MarginRankingLoss(opt.loss_margin) optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate) optimizer.zero_grad() for one_batch_memory in memory_data: scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device) optimizer.step() results = [ evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device) for test_data in current_test_data ] print(results) avg_result = sum(results) / len(results) test_set_size = [len(testdata) for testdata in current_test_data] whole_result = sum([ results[i] * test_set_size[i] for i in range(len(current_test_data)) ]) / sum(test_set_size) print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size])) print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))
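# Sketch of the curriculum sampling step in the training loop above: the seen relations
# are ordered by similarity to the relation currently being trained, and an evenly strided
# subset of them is selected as rehearsal partners. Pure illustration with toy inputs;
# the helper name below is not part of the repository.
def strided_curriculum_sample(seen_relation_sorted_index, sampled_rel_num):
    if sampled_rel_num >= len(seen_relation_sorted_index):
        return list(seen_relation_sorted_index)
    step = len(seen_relation_sorted_index) // sampled_rel_num
    return [seen_relation_sorted_index[i]
            for i in range(0, len(seen_relation_sorted_index), step)]

print(strided_curriculum_sample([7, 3, 12, 5, 9, 1], sampled_rel_num=3))  # [7, 12, 9]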
def train(training_data, valid_data, vocabulary, embedding_dim, hidden_dim,
          device, batch_size, lr, model_path, embedding, all_relations,
          model=None, epoch=100, grad_means=[], grad_fishers=[], loss_margin=2.0):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    for epoch_i in range(epoch):
        # print('epoch', epoch_i)
        # training_data = training_data[0:100]
        for i in range((len(training_data) - 1) // batch_size + 1):
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            questions, relations, relation_set_lengths = process_samples(
                samples, all_relations, device)
            # print('got data')
            ranked_questions, reverse_question_indexs = \
                ranking_sequence(questions)
            ranked_relations, reverse_relation_indexs = \
                ranking_sequence(relations)
            question_lengths = [len(question) for question in ranked_questions]
            relation_lengths = [len(relation) for relation in ranked_relations]
            # print(ranked_questions)
            pad_questions = torch.nn.utils.rnn.pad_sequence(ranked_questions)
            pad_relations = torch.nn.utils.rnn.pad_sequence(ranked_relations)
            # print(pad_questions)
            pad_questions = pad_questions.to(device)
            pad_relations = pad_relations.to(device)
            # print(pad_questions)

            model.zero_grad()
            model.init_hidden(device, sum(relation_set_lengths))
            all_scores = model(pad_questions, pad_relations, device,
                               reverse_question_indexs, reverse_relation_indexs,
                               question_lengths, relation_lengths)
            all_scores = all_scores.to('cpu')
            pos_scores = []
            neg_scores = []
            start_index = 0
            for length in relation_set_lengths:
                pos_scores.append(all_scores[start_index].expand(length - 1))
                neg_scores.append(all_scores[start_index + 1:start_index + length])
                start_index += length
            pos_scores = torch.cat(pos_scores)
            neg_scores = torch.cat(neg_scores)

            loss = loss_function(pos_scores, neg_scores,
                                 torch.ones(sum(relation_set_lengths)
                                            - len(relation_set_lengths)))
            loss = loss.sum()
            # loss.to(device)
            # print(loss)
            for i in range(len(grad_means)):
                grad_mean = grad_means[i]
                grad_fisher = grad_fishers[i]
                # print(param_loss(model, grad_mean, grad_fisher, p_lambda))
                loss += param_loss(model, grad_mean, grad_fisher, p_lambda).to('cpu')
            loss.backward()
            optimizer.step()
    '''
    acc = evaluate_model(model, valid_data, batch_size, all_relations, device)
    if acc > best_acc:
        torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    return model
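# Assumed form of the param_loss regulariser added in the train() above: an EWC-style
# quadratic penalty (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2, where theta* and F
# are the parameter means and Fisher values stored after an earlier task. Sketch only;
# the repository's param_loss may differ in detail.
import torch
import torch.nn as nn

def param_loss_sketch(model: nn.Module, grad_mean: dict, grad_fisher: dict,
                      p_lambda: float) -> torch.Tensor:
    penalty = torch.zeros(1)
    for name, param in model.named_parameters():
        if name in grad_mean:
            penalty = penalty + (grad_fisher[name] * (param - grad_mean[name]) ** 2).sum()
    return 0.5 * p_lambda * penalty

toy = nn.Linear(2, 1)
means = {name: p.detach().clone() for name, p in toy.named_parameters()}
fishers = {name: torch.ones_like(p) for name, p in toy.named_parameters()}
print(param_loss_sketch(toy, means, fishers, p_lambda=1.0))
# zero penalty when the parameters equal the stored means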