Example #1
def train(training_data,
          valid_data,
          vocabulary,
          embedding_dim,
          hidden_dim,
          device,
          batch_size,
          lr,
          embedding,
          all_relations,
          model=None,
          epoch=100,
          memory_data=[],
          loss_margin=0.5,
          alignment_model=None):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    memory_index = 0
    for epoch_i in range(epoch):
        for i in range((len(training_data) - 1) // batch_size + 1):
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            seed_rels = []
            for item in samples:
                if item[0] not in seed_rels:
                    seed_rels.append(item[0])

            if len(memory_data) > 0:
                all_seen_data = []
                for this_memory in memory_data:
                    all_seen_data += this_memory
                memory_batch = memory_data[memory_index]
                scores, loss = feed_samples(model, memory_batch, loss_function,
                                            all_relations, device,
                                            alignment_model)
                optimizer.step()
                memory_index = (memory_index + 1) % len(memory_data)
            scores, loss = feed_samples(model, samples, loss_function,
                                        all_relations, device, alignment_model)
            optimizer.step()
            del scores
            del loss
    return model
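Note that the training loop above calls optimizer.step() without an explicit loss.backward(), so feed_samples presumably zeroes the gradients and backpropagates internally. The following is only a minimal sketch of such a helper, assuming samples of the form (gold_relation, candidate_relations, question) and a hypothetical model.score(question, relation) API; it is not the repository's actual implementation.

import torch

def feed_samples_sketch(model, samples, loss_function, all_relations, device):
    # Assumed behaviour: margin-ranking loss between the gold relation and the
    # hardest negative candidate, then backpropagation; the caller only has to
    # run optimizer.step() afterwards.
    model.zero_grad()
    pos_scores, neg_scores = [], []
    for gold_rel, cand_rels, question in samples:  # assumed sample layout
        pos_scores.append(model.score(question, all_relations[gold_rel]))
        neg_scores.append(torch.stack(
            [model.score(question, all_relations[r]) for r in cand_rels]).max())
    pos, neg = torch.stack(pos_scores), torch.stack(neg_scores)
    target = torch.ones_like(pos, device=device)  # gold relation should rank higher
    loss = loss_function(pos, neg, target)
    loss.backward()
    return torch.cat([pos, neg]), loss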
Example #2
def train(training_data, valid_data, vocabulary, embedding_dim, hidden_dim,
          device, batch_size, lr, model_path, embedding, all_relations,
          model=None, epoch=100, all_seen_samples=[],
          task_memory_size=100, loss_margin=0.5, all_seen_rels=[]):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    for epoch_i in range(epoch):
        #print('epoch', epoch_i)
        #training_data = training_data[0:100]
        for i in range((len(training_data)-1)//batch_size+1):
            memory_data = sample_memory_data(all_seen_samples, task_memory_size)
            for this_sample in memory_data:
                rel_cands = [rel for rel in all_seen_rels if rel!=this_sample[0]]
                this_sample[1] = random.sample(rel_cands,
                                               min(len(rel_cands),num_cands))
            memory_data_grads = get_grads_memory_data(model, memory_data,
                                                      loss_function,
                                                      all_relations,
                                                      device)
            #print(memory_data_grads)
            samples = training_data[i*batch_size:(i+1)*batch_size]
            feed_samples(model, samples, loss_function, all_relations, device)
            sample_grad = copy_grad_data(model)
            if len(memory_data_grads) > 0:
                if torch.matmul(memory_data_grads,
                                torch.t(sample_grad.view(1,-1))) < 0:
                    project2cone2(sample_grad, memory_data_grads)
                    grad_params = get_grad_params(model)
                    grad_dims = [param.data.numel() for param in grad_params]
                    overwrite_grad(grad_params, sample_grad, grad_dims)
            optimizer.step()
        '''
        acc=evaluate_model(model, valid_data, batch_size, all_relations, device)
        if acc > best_acc:
            torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    return model
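This variant follows a GEM-style recipe: the gradient on the current batch is compared with gradients computed on memory samples and projected with project2cone2 when they conflict. Below is only a rough sketch of the two gradient helpers, under the assumption that copy_grad_data flattens all parameter gradients into one vector and overwrite_grad writes a (possibly projected) flat gradient back; the real helpers in the repository may differ.

import torch

def copy_grad_data(model):
    # Assumed: flatten every existing parameter gradient into a single 1-D vector.
    grads = [p.grad.data.view(-1) for p in model.parameters() if p.grad is not None]
    return torch.cat(grads)

def overwrite_grad(grad_params, new_grad, grad_dims):
    # Assumed: copy the projected flat gradient back into the parameter .grad fields.
    offset = 0
    for param, numel in zip(grad_params, grad_dims):
        param.grad.data.copy_(new_grad[offset:offset + numel].view_as(param))
        offset += numel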
Example #3
def main(opt):
    print(opt)
    random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    np.random.RandomState(opt.random_seed)
    start_time = time.time()
    checkpoint_dir = os.path.join(opt.checkpoint_dir, '%.f' % start_time)

    device = torch.device(('cuda:%d' % opt.cuda_id) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu')

    # load and split the data
    split_train_data, train_data_dict, split_test_data, split_valid_data, relation_numbers, rel_features, \
    split_train_relations, vocabulary, embedding = \
        load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file,
                  opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num,
                  opt.train_instance_num, opt.dataset)
    print(split_train_relations)

    # ------------------------------------------------------------------------
    # save cluster results
    our_tasks = split_train_relations

    count = 0
    for i in our_tasks:
        count += len(i)

    portion = np.zeros(10000)
    portion = portion-1
    for i in range(len(our_tasks)):
        for j in our_tasks[i]:
            portion[j - 1] = int(i)
    np.save("dataset/tacred/CML_tacred_random.npy", np.array(portion).astype(int))
    # -------------------------------------------------------------------------

    print('\n'.join(
        ['Task %d\t%s' % (index, ', '.join(['%d' % rel for rel in split_train_relations[index]])) for index in
         range(len(split_train_relations))]))

    task_sequence = list(range(opt.task_num))
    if opt.random_idx:
        for i in range(opt.random_times):
            random.shuffle(task_sequence)

    offset_seq = task_sequence[-opt.sequence_index:] + task_sequence[:-opt.sequence_index]

    split_train_data = resort_list(split_train_data, offset_seq)
    split_test_data = resort_list(split_test_data, offset_seq)
    split_valid_data = resort_list(split_valid_data, offset_seq)
    split_train_relations = resort_list(split_train_relations, offset_seq)
    print('[%s]' % ', '.join(['Task %d' % idx for idx in offset_seq]))

    relid2embedidx = {}
    embedidx2relid = {}
    if opt.similarity == 'kl_similarity':
        kl_dist_ht = read_json(opt.kl_dist_file)

        sorted_similarity_index = np.argsort(np.asarray(kl_dist_ht), axis=1) + 1
    elif opt.similarity == 'glove_similarity':
        glove_embedding = []

        embed_id = 0
        for rel_id in rel_features:
            glove_embedding.append(rel_features[rel_id])
            relid2embedidx[rel_id] = embed_id
            embedidx2relid[embed_id] = rel_id
            embed_id += 1

        glove_similarity = cosine_similarity(np.asarray(glove_embedding))
        glove_dist = np.sqrt(1 - np.power(np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2))
        sorted_embed_index = np.argsort(np.asarray(glove_dist), axis=1)
        sorted_similarity_index = np.zeros(sorted_embed_index.shape)
        for i in range(sorted_embed_index.shape[0]):
            for j in range(sorted_embed_index.shape[1]):
                sorted_similarity_index[i][j] = embedidx2relid[sorted_embed_index[i][j]]
    else:
        raise Exception('similarity method not implemented')

    # prepare model
    inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim, len(vocabulary),
                                  np.array(embedding), 1, device)

    memory_data = []
    memory_pool = []
    memory_question_embed = []
    memory_relation_embed = []
    sequence_results = []
    result_whole_test = []
    seen_relations = []
    all_seen_relations = []
    rel2instance_memory = {}
    memory_index = 0
    seen_task_relations = []
    rel_embeddings = []
    for task_ix in range(opt.task_num):  # outside loop
        # reptile start model parameters pi
        weights_before = deepcopy(inner_model.state_dict())

        train_task = split_train_data[task_ix]
        test_task = split_test_data[task_ix]
        valid_task = split_valid_data[task_ix]
        train_relations = split_train_relations[task_ix]
        seen_task_relations.append(train_relations)

        # collect seen relations
        for data_item in train_task:
            if data_item[0] not in seen_relations:
                seen_relations.append(data_item[0])

        # remove unseen relations
        current_train_data = remove_unseen_relation(train_task, seen_relations, dataset=opt.dataset)
        current_valid_data = remove_unseen_relation(valid_task, seen_relations, dataset=opt.dataset)

        current_test_data = []
        for previous_task_id in range(task_ix + 1):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset))

        for this_sample in current_train_data:
            if this_sample[0] not in all_seen_relations:
                all_seen_relations.append(this_sample[0])

        update_rel_cands(memory_data, all_seen_relations, opt.num_cands)

        # train inner_model
        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        inner_model = inner_model.to(device)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        t = tqdm(range(opt.outside_epoch))
        best_valid_acc = 0.0
        early_stop = 0
        best_checkpoint = ''

        #
        resorted_memory_pool = []
        for epoch in t:
            batch_num = (len(current_train_data) - 1) // opt.batch_size + 1
            total_loss = 0.0
            target_rel = -1
            for batch in range(batch_num):

                batch_train_data = current_train_data[batch * opt.batch_size: (batch + 1) * opt.batch_size]

                if len(memory_data) > 0:
                    # CML
                    if target_rel == -1 or len(resorted_memory_pool) == 0:
                        target_rel = batch_train_data[0][0]
                        if opt.similarity == 'kl_similarity':
                            target_rel_sorted_index = sorted_similarity_index[target_rel - 1]
                        else:
                            target_rel_sorted_index = sorted_similarity_index[relid2embedidx[target_rel]]
                        resorted_memory_pool = resort_memory(memory_pool, target_rel_sorted_index)

                    if len(resorted_memory_pool) >= opt.task_memory_size:
                        current_memory = resorted_memory_pool[:opt.task_memory_size]
                        resorted_memory_pool = resorted_memory_pool[opt.task_memory_size + 1:]  # update rest memory
                        batch_train_data.extend(current_memory)
                    else:
                        current_memory = resorted_memory_pool[:]
                        resorted_memory_pool = []
                        batch_train_data.extend(current_memory)

                    # MLLRE
                    # all_seen_data = []
                    # for one_batch_memory in memory_data:
                    #     all_seen_data += one_batch_memory
                    #
                    # memory_batch = memory_data[memory_index]
                    # batch_train_data.extend(memory_batch)
                    # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device)
                    # optimizer.step()
                    # memory_index = (memory_index+1) % len(memory_data)

                if len(rel2instance_memory) > 0:  # from the second task, this will not be empty
                    if opt.is_curriculum_train == 'Y':
                        current_train_rel = batch_train_data[0][0]
                        current_rel_similarity_sorted_index = sorted_similarity_index[current_train_rel + 1]
                        seen_relation_sorted_index = []
                        for rel in current_rel_similarity_sorted_index:
                            if rel in rel2instance_memory.keys():
                                seen_relation_sorted_index.append(rel)

                        curriculum_rel_list = []
                        if opt.sampled_rel_num >= len(seen_relation_sorted_index):
                            curriculum_rel_list = seen_relation_sorted_index[:]
                        else:
                            step = len(seen_relation_sorted_index) // opt.sampled_rel_num
                            for i in range(0, len(seen_relation_sorted_index), step):
                                curriculum_rel_list.append(seen_relation_sorted_index[i])

                        # curriculum select relation
                        instance_list = []
                        for sampled_relation in curriculum_rel_list:
                            if opt.mini_batch_split == 'Y':
                                instance_list.append(rel2instance_memory[sampled_relation])
                            else:
                                instance_list.extend(rel2instance_memory[sampled_relation])
                    else:
                        # randomly select relation
                        instance_list = []
                        random_relation_list = random.sample(list(rel2instance_memory.keys()),
                                                             min(opt.sampled_rel_num, len(rel2instance_memory)))
                        for sampled_relation in random_relation_list:
                            if opt.mini_batch_split == 'Y':
                                instance_list.append(rel2instance_memory[sampled_relation])
                            else:
                                instance_list.extend(rel2instance_memory[sampled_relation])

                    if opt.mini_batch_split == 'Y':
                        for one_batch_instance in instance_list:
                            scores, loss = feed_samples(inner_model, one_batch_instance, loss_function,
                                                        relation_numbers, device, all_seen_relations)
                            optimizer.step()
                    else:
                        scores, loss = feed_samples(inner_model, instance_list, loss_function, relation_numbers, device,
                                                    all_seen_relations)
                        optimizer.step()

                scores, loss = feed_samples(inner_model, batch_train_data, loss_function, relation_numbers, device,
                                            all_seen_relations)
                optimizer.step()
                total_loss += loss

            # valid test
            valid_acc = evaluate_model(inner_model, current_valid_data, opt.batch_size, relation_numbers, device)
            # checkpoint
            checkpoint = {'net_state': inner_model.state_dict(), 'optimizer': optimizer.state_dict()}
            if valid_acc > best_valid_acc:
                best_checkpoint = '%s/checkpoint_task%d_epoch%d.pth.tar' % (checkpoint_dir, task_ix + 1, epoch)
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                torch.save(checkpoint, best_checkpoint)
                best_valid_acc = valid_acc
                early_stop = 0
            else:
                early_stop += 1

            # print()
            t.set_description('Task %i Epoch %i' % (task_ix + 1, epoch + 1))
            t.set_postfix(loss=total_loss.item(), valid_acc=valid_acc, early_stop=early_stop,
                          best_checkpoint=best_checkpoint)
            t.update(1)

            if early_stop >= opt.early_stop and task_ix != 0:
                # convergence
                break

            if task_ix == 0 and early_stop >= 40:
                break
        t.close()
        print('Load best check point from %s' % best_checkpoint)
        checkpoint = torch.load(best_checkpoint)

        weights_after = checkpoint['net_state']
        if opt.outer_step_formula == 'fixed':
            outer_step_size = opt.step_size
        elif opt.outer_step_formula == 'linear':
            outer_step_size = opt.step_size * (1 - task_ix / opt.task_num)
        elif opt.outer_step_formula == 'square_root':
            outer_step_size = math.sqrt(opt.step_size * (1 - task_ix / opt.task_num))
        # outer_step_size = 0.4
        inner_model.load_state_dict(
            {name: weights_before[name] + (weights_after[name] - weights_before[name]) * outer_step_size
             for name in weights_before})

        results = [evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device)
                   for test_data in current_test_data]

        # sample memory from current_train_data
        if opt.memory_select_method == 'select_for_relation':
            # sample instance for one relation
            for rel in train_relations:
                rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset)
                rel_memo = select_data(inner_model, rel_items, int(opt.sampled_instance_num),
                                       relation_numbers, opt.batch_size, device)
                rel2instance_memory[rel] = rel_memo

        if opt.memory_select_method == 'select_for_task':
            # sample instance for one Task
            rel_instance_num = math.ceil(opt.sampled_instance_num_total / len(train_relations))
            for rel in train_relations:
                rel_items = remove_unseen_relation(train_data_dict[rel], seen_relations, dataset=opt.dataset)
                rel_memo = select_data(inner_model, rel_items, rel_instance_num,
                                       relation_numbers, opt.batch_size, device)
                rel2instance_memory[rel] = rel_memo

        if opt.task_memory_size > 0:
            # sample memory from current_train_data
            if opt.memory_select_method == 'random':
                memory_data.append(random_select_data(current_train_data, int(opt.task_memory_size)))
            elif opt.memory_select_method == 'vec_cluster':
                selected_memo = select_data(inner_model, current_train_data, int(opt.task_memory_size),
                                            relation_numbers, opt.batch_size, device)
                memory_data.append(selected_memo)  # memory_data: list of per-task memory lists
                memory_pool.extend(selected_memo)
            elif opt.memory_select_method == 'difficulty':
                memory_data.append()  # 'difficulty' selection is not implemented

        print_list(results)
        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([results[i] * test_set_size[i] for i in range(len(current_test_data))]) / sum(test_set_size)
        print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))

    print('test_all:')
    result_total_for_avg = []
    result_total_for_whole = []
    for epoch in range(10):
        current_test_data = []
        for previous_task_id in range(opt.task_num):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id], seen_relations, dataset=opt.dataset))

        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        optimizer.zero_grad()
        for one_batch_memory in memory_data:
            scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device,
                                        all_seen_relations)
            optimizer.step()
        results = [evaluate_model(inner_model, test_data, opt.batch_size, relation_numbers, device)
                   for test_data in current_test_data]
        print(results)

        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([results[i] * test_set_size[i] for i in range(len(current_test_data))]) / sum(test_set_size)

        print('test_set_size: [%s]' % ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))
        result_total_for_avg.append(results)
        result_total_for_whole.append(whole_result)

    # clean saved parameters
    files = os.listdir(checkpoint_dir)
    for weigths_file in files:
        os.remove(os.path.join(checkpoint_dir, weigths_file))
    os.removedirs(checkpoint_dir)

    # -----------------------------------------------------------
    # output the results
    avg_total = np.mean(np.array(result_total_for_avg), 1)
    avg_mean, avg_interval = interval(avg_total)
    whole_mean, whole_interval = interval(np.array(result_total_for_whole))
    result_total = {"avg_acc": result_total_for_avg, "whole_acc": result_total_for_whole,
                    "avg_mean": avg_mean, "avg_interval": avg_interval.tolist(),
                    "whole_mean": whole_mean, "whole_interval": whole_interval.tolist()}
    print(result_total)

    with open(opt.result_file, "w") as file_out:
        json.dump(result_total, file_out)
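The final evaluation block reports a mean accuracy plus a confidence interval via an interval() helper that is not shown here. A plausible sketch, assuming it returns the sample mean and a 95% half-width as NumPy scalars (which would explain the .tolist() calls above); this is an assumption, not the repository's code.

import numpy as np

def interval(values, z=1.96):
    # Assumed: mean and 95% confidence half-width under a normal approximation.
    values = np.asarray(values, dtype=float)
    half_width = z * values.std(ddof=1) / np.sqrt(len(values))
    return values.mean(), half_width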
Example #4
def main(opt):

    print(opt)
    # print('linear outer step formula, 0.6 step size, 50 clustered memos per task')
    random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    np.random.RandomState(opt.random_seed)
    start_time = time.time()
    checkpoint_dir = os.path.join(opt.checkpoint_dir, '%.f' % start_time)

    device = torch.device((
        'cuda:%d' % opt.cuda_id
    ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu')

    # load and split the data
    split_train_data, train_data_dict, split_test_data, split_valid_data, relation_numbers, rel_features, \
    split_train_relations, vocabulary, embedding = \
        load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file,
                  opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num,
                  opt.train_instance_num, opt.dataset)
    print('\n'.join([
        'Task %d\t%s' %
        (index, ', '.join(['%d' % rel
                           for rel in split_train_relations[index]]))
        for index in range(len(split_train_relations))
    ]))

    # offset tasks
    # split_train_data = offset_list(split_train_data, opt.task_offset)
    # split_test_data = offset_list(split_test_data, opt.task_offset)
    # split_valid_data = offset_list(split_valid_data, opt.task_offset)
    # task_sq = [None] * len(split_train_relations)
    # for i in range(len(split_train_relations)):
    #     task_sq[(i + opt.task_offset) % len(split_train_relations)] = i
    # print('[%s]' % ', '.join(['Task %d' % i for i in task_sq]))

    # insert 6th-task
    # task_index = [[6, 0, 1, 2, 3, 4, 5, 7, 8, 9],
    #               [0, 6, 1, 2, 3, 4, 5, 7, 8, 9],
    #               [0, 1, 6, 2, 3, 4, 5, 7, 8, 9],
    #               [0, 1, 2, 6, 3, 4, 5, 7, 8, 9],
    #               [0, 1, 2, 3, 6, 4, 5, 7, 8, 9],
    #               [0, 1, 2, 3, 4, 6, 5, 7, 8, 9],
    #               [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    #               [0, 1, 2, 3, 4, 5, 7, 6, 8, 9],
    #               [0, 1, 2, 3, 4, 5, 7, 8, 6, 9],
    #               [0, 1, 2, 3, 4, 5, 7, 8, 9, 6]]

    # task_sequence = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    #               [9, 0, 1, 2, 3, 4, 5, 6, 7, 8],
    #               [8, 9, 0, 1, 2, 3, 4, 5, 6, 7],
    #               [7, 8, 9, 0, 1, 2, 3, 4, 5, 6],
    #               [6, 7, 8, 9, 0, 1, 2, 3, 4, 5],
    #               [5, 6, 7, 8, 9, 0, 1, 2, 3, 4],
    #               [4, 5, 6, 7, 8, 9, 0, 1, 2, 3],
    #               [3, 4, 5, 6, 7, 8, 9, 0, 1, 2],
    #               [2, 3, 4, 5, 6, 7, 8, 9, 0, 1],
    #               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]
    task_sequence = list(range(opt.task_num))
    if opt.random_idx:
        for i in range(opt.random_times):
            random.shuffle(task_sequence)

    offset_seq = task_sequence[
        -opt.sequence_index:] + task_sequence[:-opt.sequence_index]
    # task_sequence = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    #                  [9, 0, 1, 2, 3, 4, 5, 6, 7, 8],
    #                  [8, 9, 0, 1, 2, 3, 4, 5, 6, 7],
    #                  [7, 8, 9, 0, 1, 2, 3, 4, 5, 6],
    #                  [6, 7, 8, 9, 0, 1, 2, 3, 4, 5],
    #                  [5, 6, 7, 8, 9, 0, 1, 2, 3, 4],
    #                  [4, 5, 6, 7, 8, 9, 0, 1, 2, 3],
    #                  [3, 4, 5, 6, 7, 8, 9, 0, 1, 2],
    #                  [2, 3, 4, 5, 6, 7, 8, 9, 0, 1],
    #                  [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]
    #
    # if opt.random_idx:
    #     random_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    #     for i in range(opt.random_times):
    #         random.shuffle(random_idx)
    #     print(str(random_idx))
    #
    #     for i in range(len(task_sequence)):
    #         new_sq = [0] * len(task_sequence[i])
    #         for j in range(len(task_sequence[i])):
    #             new_sq[j] = random_idx[task_sequence[i][j]]
    #         task_sequence[i] = new_sq

    split_train_data = resort_list(split_train_data, offset_seq)
    split_test_data = resort_list(split_test_data, offset_seq)
    split_valid_data = resort_list(split_valid_data, offset_seq)
    split_train_relations = resort_list(split_train_relations, offset_seq)
    print('[%s]' % ', '.join(['Task %d' % idx for idx in offset_seq]))

    relid2embedidx = {}
    embedidx2relid = {}
    if opt.similarity == 'kl_similarity':
        kl_dist_ht = read_json(opt.kl_dist_file)

        # tmp = [[0, 1, 2, 3], [1, 0, 4, 6], [2, 4, 0, 5], [3, 6, 5, 0]]
        sorted_similarity_index = np.argsort(np.asarray(kl_dist_ht),
                                             axis=1) + 1
    elif opt.similarity == 'glove_similarity':
        glove_embedding = []

        embed_id = 0
        for rel_id in rel_features:
            glove_embedding.append(rel_features[rel_id])
            relid2embedidx[rel_id] = embed_id
            embedidx2relid[embed_id] = rel_id
            embed_id += 1

        glove_similarity = cosine_similarity(np.asarray(glove_embedding))
        glove_dist = np.sqrt(1 - np.power(
            np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2))
        sorted_embed_index = np.argsort(np.asarray(glove_dist), axis=1)
        sorted_similarity_index = np.zeros(sorted_embed_index.shape)
        for i in range(sorted_embed_index.shape[0]):
            for j in range(sorted_embed_index.shape[1]):
                sorted_similarity_index[i][j] = embedidx2relid[
                    sorted_embed_index[i][j]]
        # print()
        # for i in range(1, len(rel_features) + 1):
        #     rel_embed = rel_features[i]  # 1-80
        #     glove_embedding.append(rel_embed)  # index 0-79, total 80
        #
        # glove_similarity = cosine_similarity(np.asarray(glove_embedding))
        # glove_dist = np.sqrt(1 - np.power(np.where(glove_similarity > 1.0, 1.0, glove_similarity), 2))
        # sorted_similarity_index = np.argsort(np.asarray(glove_dist), axis=1) + 1
    else:
        raise Exception('similarity method not implemented')

    # prepare model
    inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim,
                                  len(vocabulary), np.array(embedding), 1,
                                  device)

    memory_data = []
    memory_pool = []
    memory_question_embed = []
    memory_relation_embed = []
    sequence_results = []
    result_whole_test = []
    seen_relations = []
    all_seen_relations = []
    rel2instance_memory = {}
    memory_index = 0
    seen_task_relations = []
    rel_embeddings = []
    for task_ix in range(opt.task_num):  # outside loop
        # reptile start model parameters pi
        weights_before = deepcopy(inner_model.state_dict())

        train_task = split_train_data[task_ix]
        test_task = split_test_data[task_ix]
        valid_task = split_valid_data[task_ix]
        train_relations = split_train_relations[task_ix]
        seen_task_relations.append(train_relations)

        # collect seen relations
        for data_item in train_task:
            if data_item[0] not in seen_relations:
                seen_relations.append(data_item[0])

        # remove unseen relations
        current_train_data = remove_unseen_relation(train_task,
                                                    seen_relations,
                                                    dataset=opt.dataset)
        current_valid_data = remove_unseen_relation(valid_task,
                                                    seen_relations,
                                                    dataset=opt.dataset)
        current_test_data = []
        for previous_task_id in range(task_ix + 1):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id],
                                       seen_relations,
                                       dataset=opt.dataset))

        for this_sample in current_train_data:
            if this_sample[0] not in all_seen_relations:
                all_seen_relations.append(this_sample[0])

        update_rel_cands(memory_data, all_seen_relations, opt.num_cands)

        # train inner_model
        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        inner_model = inner_model.to(device)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        t = tqdm(range(opt.outside_epoch))
        best_valid_acc = 0.0
        early_stop = 0
        best_checkpoint = ''

        #
        resorted_memory_pool = []
        for epoch in t:
            batch_num = (len(current_train_data) - 1) // opt.batch_size + 1
            total_loss = 0.0
            target_rel = -1
            for batch in range(batch_num):

                batch_train_data = current_train_data[
                    batch * opt.batch_size:(batch + 1) * opt.batch_size]

                if len(memory_data) > 0:
                    # curriculum select and organize memory
                    if target_rel == -1 or len(resorted_memory_pool) == 0:
                        target_rel = batch_train_data[0][0]
                        if opt.similarity == 'kl_similarity':
                            target_rel_sorted_index = sorted_similarity_index[
                                target_rel - 1]
                        else:
                            target_rel_sorted_index = sorted_similarity_index[
                                relid2embedidx[target_rel]]
                        resorted_memory_pool = resort_memory(
                            memory_pool, target_rel_sorted_index)

                    if len(resorted_memory_pool) >= opt.task_memory_size:
                        current_memory = resorted_memory_pool[:opt.task_memory_size]
                        resorted_memory_pool = resorted_memory_pool[
                            opt.task_memory_size + 1:]  # update the remaining memory
                        batch_train_data.extend(current_memory)
                    else:
                        current_memory = resorted_memory_pool[:]
                        resorted_memory_pool = []  # update the remaining memory
                        batch_train_data.extend(current_memory)

                    # discarded approach
                    # if len(resorted_memory_pool) != 0:
                    #     current_memory = resorted_memory_pool[:opt.task_memory_size]
                    #     resorted_memory_pool = resorted_memory_pool[opt.task_memory_size + 1:]  # update the remaining memory
                    #     batch_train_data.extend(current_memory)
                    # else:
                    #     target_rel = batch_train_data[0][0]
                    #     target_rel_sorted_index = sorted_similarity_index[target_rel - 1]
                    #     resorted_memory_pool = resort_memory(memory_pool, target_rel_sorted_index)

                    # the MLLRE approach
                    # all_seen_data = []
                    # for one_batch_memory in memory_data:
                    #     all_seen_data += one_batch_memory
                    #
                    # memory_batch = memory_data[memory_index]
                    # batch_train_data.extend(memory_batch)
                    # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device)
                    # optimizer.step()
                    # memory_index = (memory_index+1) % len(memory_data)

                # random.shuffle(batch_train_data)
                if len(rel2instance_memory) > 0:  # from the second task, this will not be empty
                    if opt.is_curriculum_train == 'Y':
                        current_train_rel = batch_train_data[0][0]
                        current_rel_similarity_sorted_index = sorted_similarity_index[
                            current_train_rel + 1]
                        seen_relation_sorted_index = []
                        for rel in current_rel_similarity_sorted_index:
                            if rel in rel2instance_memory.keys():
                                seen_relation_sorted_index.append(rel)

                        curriculum_rel_list = []
                        if opt.sampled_rel_num >= len(
                                seen_relation_sorted_index):
                            curriculum_rel_list = seen_relation_sorted_index[:]
                        else:
                            step = len(seen_relation_sorted_index) // opt.sampled_rel_num
                            for i in range(0, len(seen_relation_sorted_index),
                                           step):
                                curriculum_rel_list.append(
                                    seen_relation_sorted_index[i])

                        # curriculum select relation
                        instance_list = []
                        for sampled_relation in curriculum_rel_list:
                            if opt.mini_batch_split == 'Y':
                                instance_list.append(
                                    rel2instance_memory[sampled_relation])
                            else:
                                instance_list.extend(
                                    rel2instance_memory[sampled_relation])
                    else:
                        # randomly select relation
                        instance_list = []
                        random_relation_list = random.sample(
                            list(rel2instance_memory.keys()),
                            min(opt.sampled_rel_num, len(rel2instance_memory)))
                        for sampled_relation in random_relation_list:
                            if opt.mini_batch_split == 'Y':
                                instance_list.append(
                                    rel2instance_memory[sampled_relation])
                            else:
                                instance_list.extend(
                                    rel2instance_memory[sampled_relation])

                    if opt.mini_batch_split == 'Y':
                        for one_batch_instance in instance_list:
                            # curriculum_instance_list = remove_unseen_relation(curriculum_instance_list, seen_relations)
                            scores, loss = feed_samples(
                                inner_model, one_batch_instance, loss_function,
                                relation_numbers, device)
                            optimizer.step()
                    else:
                        # curriculum_instance_list = remove_unseen_relation(curriculum_instance_list, seen_relations)
                        scores, loss = feed_samples(inner_model, instance_list,
                                                    loss_function,
                                                    relation_numbers, device)
                        optimizer.step()

                scores, loss = feed_samples(inner_model, batch_train_data,
                                            loss_function, relation_numbers,
                                            device)
                optimizer.step()
                total_loss += loss

            # valid test
            valid_acc = evaluate_model(inner_model, current_valid_data,
                                       opt.batch_size, relation_numbers,
                                       device)
            # checkpoint
            checkpoint = {
                'net_state': inner_model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if valid_acc > best_valid_acc:
                best_checkpoint = '%s/checkpoint_task%d_epoch%d.pth.tar' % (
                    checkpoint_dir, task_ix + 1, epoch)
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                torch.save(checkpoint, best_checkpoint)
                best_valid_acc = valid_acc
                early_stop = 0
            else:
                early_stop += 1

            # print()
            t.set_description('Task %i Epoch %i' % (task_ix + 1, epoch + 1))
            t.set_postfix(loss=total_loss.item(),
                          valid_acc=valid_acc,
                          early_stop=early_stop,
                          best_checkpoint=best_checkpoint)
            t.update(1)

            if early_stop >= opt.early_stop and task_ix != 0:
                # sufficiently trained; stop early
                break

            if task_ix == 0 and early_stop >= 40:  # keep a large first task from being under-trained in the first round
                break
        t.close()
        print('Load best check point from %s' % best_checkpoint)
        checkpoint = torch.load(best_checkpoint)

        weights_after = checkpoint['net_state']

        # weights_after = inner_model.state_dict()  # weights after inner_epoch rounds of gradient updates
        # outer_step_size = opt.step_size * (1 - task_index / opt.task_num)  # linear schedule
        # if outer_step_size < opt.step_size * 0.5:
        #     outer_step_size = opt.step_size * 0.5

        if opt.outer_step_formula == 'fixed':
            outer_step_size = opt.step_size
        elif opt.outer_step_formula == 'linear':
            outer_step_size = opt.step_size * (1 - task_ix / opt.task_num)
        elif opt.outer_step_formula == 'square_root':
            outer_step_size = math.sqrt(opt.step_size *
                                        (1 - task_ix / opt.task_num))
        # outer_step_size = 0.4
        inner_model.load_state_dict({
            name: weights_before[name] +
            (weights_after[name] - weights_before[name]) * outer_step_size
            for name in weights_before
        })

        # train with the memory:
        # for i in range(5):
        #     for one_batch_memory in memory_data:
        #         scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device)
        #         optimizer.step()

        results = [
            evaluate_model(inner_model, test_data, opt.batch_size,
                           relation_numbers, device)
            for test_data in current_test_data
        ]  # predict on the test data with the current model and the alignment model

        # sample memory from current_train_data
        if opt.memory_select_method == 'select_for_relation':
            # sample k instances for each relation
            for rel in train_relations:
                rel_items = remove_unseen_relation(train_data_dict[rel],
                                                   seen_relations,
                                                   dataset=opt.dataset)
                rel_memo = select_data(inner_model, rel_items,
                                       int(opt.sampled_instance_num),
                                       relation_numbers, opt.batch_size,
                                       device)
                rel2instance_memory[rel] = rel_memo

        if opt.memory_select_method == 'select_for_task':
            # sample k instances for each task
            rel_instance_num = math.ceil(opt.sampled_instance_num_total /
                                         len(train_relations))
            for rel in train_relations:
                rel_items = remove_unseen_relation(train_data_dict[rel],
                                                   seen_relations,
                                                   dataset=opt.dataset)
                rel_memo = select_data(inner_model, rel_items,
                                       rel_instance_num, relation_numbers,
                                       opt.batch_size, device)
                rel2instance_memory[rel] = rel_memo

        if opt.task_memory_size > 0:
            # sample memory from current_train_data
            if opt.memory_select_method == 'random':
                memory_data.append(
                    random_select_data(current_train_data,
                                       int(opt.task_memory_size)))
            elif opt.memory_select_method == 'vec_cluster':
                selected_memo = select_data(inner_model, current_train_data,
                                            int(opt.task_memory_size),
                                            relation_numbers, opt.batch_size,
                                            device)
                memory_data.append(
                    selected_memo
                )  # memory_data is a list; each element is a list of selected_num samples
                memory_pool.extend(selected_memo)
            elif opt.memory_select_method == 'difficulty':
                memory_data.append()  # 'difficulty' selection is not implemented

        print_list(results)
        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([
            results[i] * test_set_size[i]
            for i in range(len(current_test_data))
        ]) / sum(test_set_size)
        print('test_set_size: [%s]' %
              ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))

        # end of each task, get embeddings of all
        # if len(all_seen_relations) > 1:
        #     rel_embed = tsne_relations(inner_model, seen_task_relations, relation_numbers, device, task_sequence[opt.sequence_index])
        #     rel_embeddings.append(rel_embed)

    print('test_all:')
    for epoch in range(10):
        current_test_data = []
        for previous_task_id in range(opt.task_num):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id],
                                       seen_relations,
                                       dataset=opt.dataset))

        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        optimizer.zero_grad()
        for one_batch_memory in memory_data:
            scores, loss = feed_samples(inner_model, one_batch_memory,
                                        loss_function, relation_numbers,
                                        device)
            optimizer.step()
        results = [
            evaluate_model(inner_model, test_data, opt.batch_size,
                           relation_numbers, device)
            for test_data in current_test_data
        ]
        print(results)
        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([
            results[i] * test_set_size[i]
            for i in range(len(current_test_data))
        ]) / sum(test_set_size)
        print('test_set_size: [%s]' %
              ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))
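Both this example and the previous one reorder the memory pool with resort_memory so that samples whose relation is most similar to the current target relation are replayed first. A minimal sketch of that idea, assuming target_rel_sorted_index lists relation ids from most to least similar and each memory sample stores its relation id at position 0; the actual implementation may differ.

def resort_memory(memory_pool, target_rel_sorted_index):
    # Assumed semantics: sort memory samples by how early their relation id
    # appears in the similarity-sorted index (most similar relations first).
    rank = {int(rel_id): pos for pos, rel_id in enumerate(target_rel_sorted_index)}
    return sorted(memory_pool, key=lambda sample: rank.get(sample[0], len(rank)))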
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda_id',
                        default=0,
                        type=int,
                        help='cuda device index, -1 means use cpu')
    parser.add_argument('--train_file',
                        default='dataset/training_data.txt',
                        help='train file')
    parser.add_argument('--valid_file',
                        default='dataset/val_data.txt',
                        help='valid file')
    parser.add_argument('--test_file',
                        default='dataset/val_data.txt',
                        help='test file')
    parser.add_argument('--relation_file',
                        default='dataset/relation_name.txt',
                        help='relation name file')
    parser.add_argument('--glove_file',
                        default='dataset/glove.6B.50d.txt',
                        help='glove embedding file')
    parser.add_argument('--embedding_dim',
                        default=50,
                        type=int,
                        help='word embedding dimension')
    parser.add_argument('--hidden_dim',
                        default=200,
                        type=int,
                        help='BiLSTM hidden dimension')
    parser.add_argument(
        '--task_arrange',
        default='random',
        help='task arrangement method, e.g. cluster_by_glove_embedding, random'
    )
    parser.add_argument('--rel_encode',
                        default='glove',
                        help='relation encode method')
    parser.add_argument(
        '--meta_method',
        default='reptile',
        help='meta learning method; maml or reptile can be chosen')
    parser.add_argument('--batch_size',
                        default=50,
                        type=float,
                        help='Reptile inner loop batch size')
    parser.add_argument('--task_num',
                        default=10,
                        type=int,
                        help='number of tasks')
    parser.add_argument(
        '--train_instance_num',
        default=200,
        type=int,
        help='number of instances for one relation, -1 means all.')
    parser.add_argument('--loss_margin',
                        default=0.5,
                        type=float,
                        help='loss margin setting')
    parser.add_argument('--outside_epoch',
                        default=200,
                        type=int,
                        help='number of epochs per task')
    parser.add_argument('--early_stop',
                        default=20,
                        type=int,
                        help='early stopping patience')
    parser.add_argument('--step_size',
                        default=0.4,
                        type=float,
                        help='step size Epsilon')
    parser.add_argument('--learning_rate',
                        default=2e-2,
                        type=float,
                        help='learning rate')
    parser.add_argument('--random_seed',
                        default=317,
                        type=int,
                        help='random seed')
    parser.add_argument('--task_memory_size',
                        default=50,
                        type=int,
                        help='number of samples for each task')
    parser.add_argument(
        '--memory_select_method',
        default='vec_cluster',
        help=
        'method for sampling memory data, e.g. vec_cluster, random, difficulty'
    )

    opt = parser.parse_args()
    print(opt)
    random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    np.random.RandomState(opt.random_seed)

    device = torch.device((
        'cuda:%d' % opt.cuda_id
    ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu')

    # load and split the data
    split_train_data, train_data_dict, split_test_data, test_data_dict, split_valid_data, valid_data_dict, \
    relation_numbers, rel_features, vocabulary, embedding  = \
        load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file,
                  opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num,
                  opt.train_instance_num)
    # prepare model
    inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim,
                                  len(vocabulary), np.array(embedding), 1,
                                  device)

    memory_data = []  # memory buffer B
    seen_relations = []
    for task_index in range(opt.task_num):  # outside loop
        # reptile start model parameters pi
        weights_before = deepcopy(inner_model.state_dict())
        train_task = split_train_data[task_index]
        # test_task = split_test_data[task_index]
        valid_task = split_valid_data[task_index]
        # collect seen relations
        for data_item in train_task:
            if data_item[0] not in seen_relations:
                seen_relations.append(data_item[0])

        # remove unseen relations
        current_train_data = remove_unseen_relation(train_task, seen_relations)
        current_valid_data = remove_unseen_relation(valid_task, seen_relations)
        current_test_data = []
        for previous_task_id in range(task_index + 1):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id],
                                       seen_relations))

        # train inner_model
        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        inner_model = inner_model.to(device)

        t = tqdm(range(opt.outside_epoch))
        best_valid_acc = 0.0
        early_stop = 0
        best_checkpoint = ''
        for epoch in t:
            weights_task = deepcopy(weights_before)
            # optimizer.zero_grad()
            # inner_model.load_state_dict({name: weights_before[name] for name in weights_before})
            batch_num = (len(current_train_data) -
                         1) // opt.train_instance_num + 1
            total_loss = 0.0
            weights_list = [None] * (batch_num + len(memory_data))
            for batch in range(batch_num):
                # one relation's train data
                batch_train_data = current_train_data[
                    batch * opt.train_instance_num:(batch + 1) * opt.train_instance_num]
                inner_model.load_state_dict(weights_task)
                optimizer = optim.SGD(inner_model.parameters(),
                                      lr=opt.learning_rate)
                optimizer.zero_grad()
                scores, loss = feed_samples(inner_model, batch_train_data,
                                            loss_function, relation_numbers,
                                            device)

                loss.backward()  # compute gradients via backpropagation
                # update the parameters
                # for f in inner_model.parameters():
                #     f.data.sub_(f.grad.data * opt.learning_rate)
                optimizer.step()  # update the parameters
                total_loss += loss
                weights_list[batch] = deepcopy(
                    inner_model.state_dict())  # save theta_t^i

            if len(memory_data) > 0:
                for i in range(len(memory_data)):
                    one_batch_memory = memory_data[i]
                    inner_model.load_state_dict(weights_task)
                    optimizer = optim.SGD(inner_model.parameters(),
                                          lr=opt.learning_rate)
                    optimizer.zero_grad()
                    scores, loss = feed_samples(inner_model, one_batch_memory,
                                                loss_function,
                                                relation_numbers, device)

                    loss.backward()
                    # update the parameters
                    # for f in inner_model.parameters():
                    #     f.data.sub_(f.grad.data * opt.learning_rate)
                    optimizer.step()
                    total_loss += loss
                    weights_list[batch_num + i] = deepcopy(
                        inner_model.state_dict())

            outer_step_size = opt.step_size * (1 / len(weights_list))
            for name in weights_before:
                weights_task[name] = weights_before[name] - outer_step_size * sum(
                    [weights[name] - weights_before[name] for weights in weights_list])

            # load state dict of weights_after
            inner_model.load_state_dict(weights_task)
            # weights_before = deepcopy(inner_model.state_dict())

            del weights_list

            valid_acc = evaluate_model(inner_model, current_valid_data,
                                       opt.batch_size, relation_numbers,
                                       device)

            checkpoint = {'net_state': inner_model.state_dict()}
            if valid_acc > best_valid_acc:
                best_checkpoint = './checkpoint/checkpoint_task%d_epoch%d.pth.tar' % (
                    task_index, epoch)
                torch.save(checkpoint, best_checkpoint)
                best_valid_acc = valid_acc
                early_stop = 0
            else:
                early_stop += 1

            # print()
            t.set_description('Task %i Epoch %i' % (task_index + 1, epoch + 1))
            t.set_postfix(loss=total_loss.item(),
                          valid_acc=valid_acc,
                          early_stop=early_stop,
                          best_checkpoint=best_checkpoint)
            t.update(1)

            if early_stop >= opt.early_stop:
                # sufficiently trained; stop early
                break
        t.close()

        # sample memory from current_train_data
        if opt.memory_select_method == 'random':
            memory_data.append(
                random_select_data(current_train_data, opt.task_memory_size))
        elif opt.memory_select_method == 'vec_cluster':
            memory_data.append(
                select_data(inner_model, current_train_data,
                            opt.task_memory_size, relation_numbers,
                            opt.batch_size, device)
            )  # memory_data is a list; each element is a list of selected_num samples
        elif opt.memory_select_method == 'difficulty':
            memory_data.append()  # 'difficulty' selection is not implemented

        results = [
            evaluate_model(inner_model, test_data, opt.batch_size,
                           relation_numbers, device)
            for test_data in current_test_data
        ]  # predict on the test data with the current model and the alignment model

        print(results)
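The memory-selection branch above switches between random_select_data and the cluster-based select_data. A sketch of the random variant, assuming it simply draws up to task_memory_size samples without replacement; this is hypothetical, for illustration only.

import random

def random_select_data(train_data, task_memory_size):
    # Assumed: uniform sampling without replacement, capped at the data size.
    return random.sample(train_data, min(task_memory_size, len(train_data)))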
def train_memory(training_data,
                 valid_data,
                 vocabulary,
                 embedding_dim,
                 hidden_dim,
                 device,
                 batch_size,
                 lr,
                 model_path,
                 embedding,
                 all_relations,
                 model=None,
                 epoch=100,
                 memory_data=[],
                 loss_margin=0.5,
                 past_fisher=None,
                 rel_samples=[],
                 relation_frequences=[],
                 rel_embeds=None,
                 rel_ques_cand=None,
                 rel_acc_diff=None,
                 all_seen_rels=None,
                 update_rel_embed=None,
                 reverse_model=None,
                 memory_que_embed=[],
                 memory_rel_embed=[],
                 to_update_reverse=False):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    reverse_optimiser = torch.optim.Adam(reverse_model.parameters(), lr=lr)
    best_acc = 0
    #acc_pre=evaluate_model(model, valid_data, batch_size, all_relations, device)
    given_pro = None
    memory_index = 0
    memory_data_grads = []
    for epoch_i in range(epoch):
        #print('epoch', epoch_i)
        #training_data = training_data[0:100]
        #for i in range((len(training_data)-1)//batch_size+1):
        for i, samples in enumerate(memory_data):
            #samples = training_data[i*batch_size:(i+1)*batch_size]
            seed_rels = []
            for item in samples:
                if item[0] not in seed_rels:
                    seed_rels.append(item[0])

            #start_time = time.time()
            '''
            if len(memory_data) > 0:
                all_seen_data = []
                for this_memory in memory_data:
                    all_seen_data+=this_memory
                memory_batch = memory_data[memory_index]
                #memory_batch = random.sample(all_seen_data,
                #                             min(batch_size, len(all_seen_data)))
                #print(memory_data)
                scores, loss = feed_samples(model, memory_batch,
                                            loss_function,
                                            all_relations, device, reverse_model)
                if to_update_reverse:
                    reverse_optimiser.step()
                else:
                    optimizer.step()
                memory_index = (memory_index+1)%len(memory_data)
                '''
            scores, loss = feed_samples(model, samples, loss_function,
                                        all_relations, device, reverse_model)
            #memory_que_embed[i], memory_rel_embed[i])
            #end_time = time.time()
            #print('forward time:', end_time - start_time)
            sample_grad = copy_grad_data(model)
            if to_update_reverse:
                reverse_optimiser.step()
            else:
                optimizer.step()
            del scores
            del loss
            '''
        acc=evaluate_model(model, valid_data, batch_size, all_relations, device)
        if acc > best_acc:
            torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    #acc_aft=evaluate_model(model, valid_data, batch_size, all_relations, device)
    #return model, max(0, acc_aft-acc_pre)
    if to_update_reverse:
        return reverse_model, 0
    else:
        return model, 0
def train(training_data,
          valid_data,
          vocabulary,
          embedding_dim,
          hidden_dim,
          device,
          batch_size,
          lr,
          model_path,
          embedding,
          all_relations,
          model=None,
          epoch=100,
          memory_data=[],
          loss_margin=0.5,
          past_fisher=None,
          rel_samples=[],
          relation_frequences=[],
          rel_embeds=None,
          rel_ques_cand=None,
          rel_acc_diff=None,
          all_seen_rels=None,
          update_rel_embed=None,
          reverse_model=None,
          memory_que_embed=[],
          memory_rel_embed=[],
          to_update_reverse=False):
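    """Train on training_data with two continual-learning hooks (as far as can
    be inferred from the helpers used below): optional replay of one memory
    batch per step, a GEM-style gradient projection via project2cone2 when
    memory gradients are available, and an optional Fisher-based gradient
    rescaling via rescale_grad before the optimizer step."""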
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    reverse_optimiser = torch.optim.Adam(reverse_model.parameters(), lr=lr)
    best_acc = 0
    #acc_pre=evaluate_model(model, valid_data, batch_size, all_relations, device)
    given_pro = None
    memory_index = 0
    memory_data_grads = []
    for epoch_i in range(epoch):
        #print('epoch', epoch_i)
        #training_data = training_data[0:100]
        for i in range((len(training_data) - 1) // batch_size + 1):
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            seed_rels = []
            for item in samples:
                if item[0] not in seed_rels:
                    seed_rels.append(item[0])
                '''
                for item_cand in item[1]:
                    if item_cand not in seed_rels:
                        seed_rels.append(item_cand)
                '''

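            # sample_constrains is assumed to rebuild the memory batch around
            # the relations seen in this training batch, using the relation
            # embeddings/frequencies and accuracy differences passed in.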
            if len(rel_samples) > 0:
                memory_data = sample_constrains(rel_samples,
                                                relation_frequences,
                                                rel_embeds, seed_rels,
                                                rel_ques_cand, rel_acc_diff,
                                                given_pro, all_seen_rels)
            '''
            to_train_mem = memory_data
            if len(memory_data) > num_constrain:
                to_train_mem = random.sample(memory_data, num_constrain)
            memory_data_grads = get_grads_memory_data(model, to_train_mem,
                                                      loss_function,
                                                      all_relations,
                                                      device, reverse_model,
                                                      memory_que_embed,
                                                      memory_rel_embed)
                                                      '''
            #print(memory_data_grads)
            #start_time = time.time()
            if len(memory_data) > 0:
                all_seen_data = []
                for this_memory in memory_data:
                    all_seen_data += this_memory
                memory_batch = memory_data[memory_index]
                #memory_batch = random.sample(all_seen_data,
                #                             min(batch_size, len(all_seen_data)))
                #print(memory_data)
                scores, loss = feed_samples(model, memory_batch, loss_function,
                                            all_relations, device,
                                            reverse_model)
                if to_update_reverse:
                    reverse_optimiser.step()
                else:
                    optimizer.step()
                memory_index = (memory_index + 1) % len(memory_data)
            scores, loss = feed_samples(model, samples, loss_function,
                                        all_relations, device, reverse_model)
            #end_time = time.time()
            #print('forward time:', end_time - start_time)
            sample_grad = copy_grad_data(model)
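            # GEM-style constraint (assumed from the helper names): when
            # gradients computed on memory data exist, project the current
            # gradient with project2cone2 and, unless Fisher rescaling is
            # applied below, overwrite the model's gradients with the
            # projected copy.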
            if len(memory_data_grads) > 0:
                #if not check_constrain(memory_data_grads, sample_grad):
                if True:
                    project2cone2(sample_grad, memory_data_grads)
                    if past_fisher is None:
                        grad_params = get_grad_params(model)
                        grad_dims = [
                            param.data.numel() for param in grad_params
                        ]
                        overwrite_grad(grad_params, sample_grad, grad_dims)
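            # EWC-flavoured step (assumed from rescale_grad): scale the gradient
            # by the stored Fisher information of past tasks before writing it
            # back into the model parameters.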
            if past_fisher is not None:
                sample_grad = rescale_grad(sample_grad, past_fisher)
                grad_params = get_grad_params(model)
                grad_dims = [param.data.numel() for param in grad_params]
                overwrite_grad(grad_params, sample_grad, grad_dims)
            if to_update_reverse:
                reverse_optimiser.step()
            else:
                optimizer.step()
            #optimizer.step()
            if (epoch_i % 5 == 0) and len(relation_frequences) > 0 and False:
                update_rel_embed(model, all_seen_rels, all_relations,
                                 rel_embeds)
                samples = list(relation_frequences.keys())
                #return random.sample(samples, min(len(samples), num_samples))
                sample_embeds = torch.from_numpy(
                    np.asarray([rel_embeds[i] for i in samples]))
                #seed_rel_embeds = torch.from_numpy(np.asarray(
                #    [rel_embeds[i] for i in seed_rels])).to(device)
                sample_embeds_np = sample_embeds.cpu().double().numpy()
                #given_pro = kmeans_pro(sample_embeds_np, samples, num_constrain)
            if epoch_i % 5 == 0 and False:
                update_rel_embed(model, all_seen_rels, all_relations,
                                 rel_embeds)
                #update_rel_cands(memory_data, all_seen_rels, rel_embeds)
            del scores
            del loss
            '''
        acc=evaluate_model(model, valid_data, batch_size, all_relations, device)
        if acc > best_acc:
            torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    #acc_aft=evaluate_model(model, valid_data, batch_size, all_relations, device)
    #return model, max(0, acc_aft-acc_pre)
    if to_update_reverse:
        return reverse_model, 0
    else:
        return model, 0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda_id',
                        default=0,
                        type=int,
                        help='cuda device index, -1 means use cpu')
    parser.add_argument('--train_file',
                        default='dataset/training_data.txt',
                        help='train file')
    parser.add_argument('--valid_file',
                        default='dataset/val_data.txt',
                        help='valid file')
    parser.add_argument('--test_file',
                        default='dataset/val_data.txt',
                        help='test file')
    parser.add_argument('--relation_file',
                        default='dataset/relation_name.txt',
                        help='relation name file')
    parser.add_argument('--glove_file',
                        default='dataset/glove.6B.300d.txt',
                        help='glove embedding file')
    parser.add_argument('--kl_dist_file',
                        default='dataset/kl_dist_ht.json',
                        help='KL distance file for head/tail joint distributions')
    parser.add_argument('--embedding_dim',
                        default=300,
                        type=int,
                        help='word embedding dimension')
    parser.add_argument('--hidden_dim',
                        default=200,
                        type=int,
                        help='BiLSTM hidden dimension')
    parser.add_argument(
        '--task_arrange',
        default='cluster_by_glove_embedding',
        help='task arrangement method, e.g. cluster_by_glove_embedding, random'
    )
    parser.add_argument('--rel_encode',
                        default='glove',
                        help='relation encode method')
    parser.add_argument(
        '--meta_method',
        default='reptile',
        help='meta learning method; either maml or reptile can be chosen')
    parser.add_argument('--batch_size',
                        default=50,
                        type=int,
                        help='Reptile inner loop batch size')
    parser.add_argument('--task_num',
                        default=10,
                        type=int,
                        help='number of tasks')
    parser.add_argument(
        '--train_instance_num',
        default=200,
        type=int,
        help='number of instances for one relation, -1 means all.')
    parser.add_argument('--loss_margin',
                        default=0.5,
                        type=float,
                        help='loss margin setting')
    parser.add_argument('--outside_epoch',
                        default=200,
                        type=int,
                        help='task level epoch')
    parser.add_argument('--early_stop',
                        default=20,
                        type=int,
                        help='early stopping patience (epochs without improvement)')
    parser.add_argument('--step_size',
                        default=0.6,
                        type=float,
                        help='step size Epsilon')
    parser.add_argument('--learning_rate',
                        default=2e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('--random_seed',
                        default=226,
                        type=int,
                        help='random seed')
    parser.add_argument('--task_memory_size',
                        default=50,
                        type=int,
                        help='number of samples for each task')
    parser.add_argument(
        '--memory_select_method',
        default='vec_cluster',
        help=
        'the method of sample memory data, e.g. vec_cluster, random, difficulty'
    )
    parser.add_argument(
        '--curriculum_rel_num',
        default=3,
        type=int,
        help='number of relations sampled for curriculum learning around the '
        'current training relation')
    parser.add_argument(
        '--curriculum_instance_num',
        default=5,
        type=int,
        help='number of instances sampled per curriculum relation')

    opt = parser.parse_args()
    print(opt)
    random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    np.random.seed(opt.random_seed)

    device = torch.device((
        'cuda:%d' % opt.cuda_id
    ) if torch.cuda.is_available() and opt.cuda_id >= 0 else 'cpu')

    # do following process
    split_train_data, train_data_dict, split_test_data, test_data_dict, split_valid_data, valid_data_dict, \
    relation_numbers, rel_features, vocabulary, embedding = \
        load_data(opt.train_file, opt.valid_file, opt.test_file, opt.relation_file, opt.glove_file,
                  opt.embedding_dim, opt.task_arrange, opt.rel_encode, opt.task_num,
                  opt.train_instance_num)

    # kl similarity of the joint distribution of head and tail
    kl_dist_ht = read_json(opt.kl_dist_file)

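    # Rank relations by descending head/tail KL similarity; the "+1" below
    # presumably shifts the argsort indices to the 1-based relation ids used
    # elsewhere in the data.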
    # tmp = [[0, 1, 2, 3], [1, 0, 4, 6], [2, 4, 0, 5], [3, 6, 5, 0]]
    sorted_similarity_index = np.argsort(-np.asarray(kl_dist_ht), axis=1) + 1

    # prepare model
    inner_model = SimilarityModel(opt.embedding_dim, opt.hidden_dim,
                                  len(vocabulary), np.array(embedding), 1,
                                  device)

    memory_data = []
    memory_question_embed = []
    memory_relation_embed = []
    sequence_results = []
    result_whole_test = []
    seen_relations = []
    all_seen_relations = []
    memory_index = 0
    for task_index in range(opt.task_num):  # outside loop
        # reptile start model parameters pi
        weights_before = deepcopy(inner_model.state_dict())

        train_task = split_train_data[task_index]
        test_task = split_test_data[task_index]
        valid_task = split_valid_data[task_index]

        # collect seen relations
        for data_item in train_task:
            if data_item[0] not in seen_relations:
                seen_relations.append(data_item[0])

        # remove unseen relations
        current_train_data = remove_unseen_relation(train_task, seen_relations)
        current_valid_data = remove_unseen_relation(valid_task, seen_relations)
        current_test_data = []
        for previous_task_id in range(task_index + 1):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id],
                                       seen_relations))

        # train inner_model
        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        inner_model = inner_model.to(device)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        t = tqdm(range(opt.outside_epoch))
        best_valid_acc = 0.0
        early_stop = 0
        best_checkpoint = ''
        for epoch in t:
            batch_num = (len(current_train_data) - 1) // opt.batch_size + 1
            total_loss = 0.0
            for batch in range(batch_num):
                batch_train_data = current_train_data[
                    batch * opt.batch_size:(batch + 1) * opt.batch_size]

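                # Rehearsal: mix one stored memory batch into the current
                # training batch instead of taking a separate replay step
                # (the separate feed_samples call is left commented out below).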
                if len(memory_data) > 0:
                    all_seen_data = []
                    for one_batch_memory in memory_data:
                        all_seen_data += one_batch_memory

                    memory_batch = memory_data[memory_index]
                    batch_train_data.extend(memory_batch)
                    # scores, loss = feed_samples(inner_model, memory_batch, loss_function, relation_numbers, device)
                    # optimizer.step()
                    memory_index = (memory_index + 1) % len(memory_data)
                # random.shuffle(batch_train_data)

                # curriculum before batch_train
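                # Presumed curriculum step: pick a few seen relations that are
                # most similar (by the KL ranking) to this batch's relation,
                # sample a handful of their training instances, and take one
                # extra gradient step on them before the regular batch.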
                if task_index > 0:
                    current_train_rel = batch_train_data[0][0]
                    current_rel_similarity_sorted_index = sorted_similarity_index[
                        current_train_rel + 1]
                    seen_relation_sorted_index = []
                    for rel in current_rel_similarity_sorted_index:
                        if rel in seen_relations:
                            seen_relation_sorted_index.append(rel)

                    curriculum_rel_list = []
                    if opt.curriculum_rel_num >= len(
                            seen_relation_sorted_index):
                        curriculum_rel_list = seen_relation_sorted_index[:]
                    else:
                        step = len(seen_relation_sorted_index
                                   ) // opt.curriculum_rel_num
                        for i in range(0, len(seen_relation_sorted_index),
                                       step):
                            curriculum_rel_list.append(
                                seen_relation_sorted_index[i])

                    curriculum_instance_list = []
                    for curriculum_rel in curriculum_rel_list:
                        curriculum_instance_list.extend(
                            random.sample(train_data_dict[curriculum_rel],
                                          opt.curriculum_instance_num))

                    curriculum_instance_list = remove_unseen_relation(
                        curriculum_instance_list, seen_relations)
                    # optimizer.zero_grad()
                    scores, loss = feed_samples(inner_model,
                                                curriculum_instance_list,
                                                loss_function,
                                                relation_numbers, device)
                    # loss.backward()
                    optimizer.step()

                scores, loss = feed_samples(inner_model, batch_train_data,
                                            loss_function, relation_numbers,
                                            device)
                optimizer.step()
                total_loss += loss

            # valid test
            valid_acc = evaluate_model(inner_model, current_valid_data,
                                       opt.batch_size, relation_numbers,
                                       device)
            # checkpoint
            checkpoint = {
                'net_state': inner_model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if valid_acc > best_valid_acc:
                best_checkpoint = './checkpoint/checkpoint_task%d_epoch%d.pth.tar' % (
                    task_index, epoch)
                torch.save(checkpoint, best_checkpoint)
                best_valid_acc = valid_acc
                early_stop = 0
            else:
                early_stop += 1

            # print()
            t.set_description('Task %i Epoch %i' % (task_index + 1, epoch + 1))
            t.set_postfix(loss=total_loss.item(),
                          valid_acc=valid_acc,
                          early_stop=early_stop,
                          best_checkpoint=best_checkpoint)
            t.update(1)

            if early_stop >= opt.early_stop:
                # sufficiently trained; stop early
                break
        t.close()
        print('Load best check point from %s' % best_checkpoint)
        checkpoint = torch.load(best_checkpoint)

        weights_after = checkpoint['net_state']

        # weights_after = inner_model.state_dict()  # weights after inner_epoch rounds of gradient updates
        # if task_index == opt.task_num - 1:
        #     outer_step_size = opt.step_size * (1 - 5 / opt.task_num)
        # else:
        outer_step_size = math.sqrt(opt.step_size *
                                    (1 - task_index / opt.task_num))
        # outer_step_size = opt.step_size / opt.task_num
        # outer_step_size = 0.4
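        # Reptile outer update: interpolate from the pre-task weights toward
        # the task-adapted weights, with an outer step size that shrinks as
        # more tasks have been seen.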
        inner_model.load_state_dict({
            name: weights_before[name] +
            (weights_after[name] - weights_before[name]) * outer_step_size
            for name in weights_before
        })

        # train on memory:
        # for i in range(5):
        #     for one_batch_memory in memory_data:
        #         scores, loss = feed_samples(inner_model, one_batch_memory, loss_function, relation_numbers, device)
        #         optimizer.step()

        results = [
            evaluate_model(inner_model, test_data, opt.batch_size,
                           relation_numbers, device)
            for test_data in current_test_data
        ]  # predict on the test data with the current model (and alignment model)

        # sample memory from current_train_data
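        # The per-task memory budget is task_memory_size / accuracy on the
        # newest task, so presumably harder tasks keep more samples.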
        if opt.memory_select_method == 'random':
            memory_data.append(
                random_select_data(current_train_data,
                                   int(opt.task_memory_size / results[-1])))
        elif opt.memory_select_method == 'vec_cluster':
            memory_data.append(
                select_data(inner_model, current_train_data,
                            int(opt.task_memory_size / results[-1]),
                            relation_numbers, opt.batch_size, device)
            )  # memory_data is a list; each element is a list of selected_num samples from one task
        elif opt.memory_select_method == 'difficulty':
            # 'difficulty' selection is not implemented in this snippet
            raise NotImplementedError(
                "memory_select_method 'difficulty' is not implemented")

        print_list(results)
        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([
            results[i] * test_set_size[i]
            for i in range(len(current_test_data))
        ]) / sum(test_set_size)
        print('test_set_size: [%s]' %
              ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))

    print('test_all:')
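    # After all tasks: repeatedly fine-tune on the stored memory batches and
    # re-evaluate on the test splits of every task.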
    for epoch in range(10):
        current_test_data = []
        for previous_task_id in range(opt.task_num):
            current_test_data.append(
                remove_unseen_relation(split_test_data[previous_task_id],
                                       seen_relations))

        loss_function = nn.MarginRankingLoss(opt.loss_margin)
        optimizer = optim.Adam(inner_model.parameters(), lr=opt.learning_rate)
        optimizer.zero_grad()
        for one_batch_memory in memory_data:
            scores, loss = feed_samples(inner_model, one_batch_memory,
                                        loss_function, relation_numbers,
                                        device)
            optimizer.step()
        results = [
            evaluate_model(inner_model, test_data, opt.batch_size,
                           relation_numbers, device)
            for test_data in current_test_data
        ]
        print(results)
        avg_result = sum(results) / len(results)
        test_set_size = [len(testdata) for testdata in current_test_data]
        whole_result = sum([
            results[i] * test_set_size[i]
            for i in range(len(current_test_data))
        ]) / sum(test_set_size)
        print('test_set_size: [%s]' %
              ', '.join([str(size) for size in test_set_size]))
        print('avg_acc: %.3f, whole_acc: %.3f' % (avg_result, whole_result))
Beispiel #9
0
def train(training_data,
          valid_data,
          vocabulary,
          embedding_dim,
          hidden_dim,
          device,
          batch_size,
          lr,
          model_path,
          embedding,
          all_relations,
          model=None,
          epoch=100,
          grad_means=[],
          grad_fishers=[],
          loss_margin=2.0):
    if model is None:
        torch.manual_seed(100)
        model = SimilarityModel(embedding_dim, hidden_dim, len(vocabulary),
                                np.array(embedding), 1, device)
    loss_function = nn.MarginRankingLoss(loss_margin)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    for epoch_i in range(epoch):
        #print('epoch', epoch_i)
        #training_data = training_data[0:100]
        for i in range((len(training_data) - 1) // batch_size + 1):
            samples = training_data[i * batch_size:(i + 1) * batch_size]
            questions, relations, relation_set_lengths = process_samples(
                samples, all_relations, device)
            #print('got data')
            ranked_questions, reverse_question_indexs = \
                ranking_sequence(questions)
            ranked_relations, reverse_relation_indexs =\
                ranking_sequence(relations)
            question_lengths = [len(question) for question in ranked_questions]
            relation_lengths = [len(relation) for relation in ranked_relations]
            #print(ranked_questions)
            pad_questions = torch.nn.utils.rnn.pad_sequence(ranked_questions)
            pad_relations = torch.nn.utils.rnn.pad_sequence(ranked_relations)
            #print(pad_questions)
            pad_questions = pad_questions.to(device)
            pad_relations = pad_relations.to(device)
            #print(pad_questions)

            model.zero_grad()
            model.init_hidden(device, sum(relation_set_lengths))
            all_scores = model(pad_questions, pad_relations, device,
                               reverse_question_indexs,
                               reverse_relation_indexs, question_lengths,
                               relation_lengths)
            all_scores = all_scores.to('cpu')
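            # For each question the first candidate score is the gold relation
            # and the remaining (length - 1) scores are negatives; expand the
            # positive so it pairs with every negative in the margin ranking
            # loss.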
            pos_scores = []
            neg_scores = []
            start_index = 0
            for length in relation_set_lengths:
                pos_scores.append(all_scores[start_index].expand(length - 1))
                neg_scores.append(all_scores[start_index + 1:start_index +
                                             length])
                start_index += length
            pos_scores = torch.cat(pos_scores)
            neg_scores = torch.cat(neg_scores)

            loss = loss_function(
                pos_scores, neg_scores,
                torch.ones(
                    sum(relation_set_lengths) - len(relation_set_lengths)))
            loss = loss.sum()
            #loss.to(device)
            #print(loss)
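            # EWC-style penalty (assumed from param_loss): add a quadratic term
            # weighted by each stored Fisher estimate to discourage drifting
            # from parameters that mattered for previous tasks.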
            for i in range(len(grad_means)):
                grad_mean = grad_means[i]
                grad_fisher = grad_fishers[i]
                #print(param_loss(model, grad_mean, grad_fisher, p_lambda))
                loss += param_loss(model, grad_mean, grad_fisher,
                                   p_lambda).to('cpu')
            loss.backward()
            optimizer.step()
            '''
        acc=evaluate_model(model, valid_data, batch_size, all_relations, device)
        if acc > best_acc:
            torch.save(model, model_path)
    best_model = torch.load(model_path)
    return best_model
    '''
    return model