Example #1
    def train(self, org_train_dataset, org_dev_datasets, config):
        print('extracting training data')
        train_dataset = self.get_data_items(org_train_dataset,
                                            predict=False,
                                            isTrain=True)
        print('#train docs', len(train_dataset))
        self.init_lr = config['lr']
        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append(
                (dname, self.get_data_items(data, predict=True,
                                            isTrain=False)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        print('creating optimizer')
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=config['lr'])

        for param_name, param in self.model.named_parameters():
            if param.requires_grad:
                print(param_name)

        best_f1 = -1
        not_better_count = 0
        is_counting = False
        eval_after_n_epochs = self.args.eval_after_n_epochs

        order_learning = False
        # order_learning_count = 0

        rl_acc_threshold = 0.7

        # optimize the parameters within the disambiguation module first
        best_aida_A_rlts = []
        best_aida_A_f1 = 0.
        best_aida_B_rlts = []
        best_aida_B_f1 = 0.
        best_ave_rlts = []
        best_ave_f1 = 0.
        self.run_time = []
        for e in range(config['n_epochs']):
            shuffle(train_dataset)

            total_loss = 0
            for dc, batch in enumerate(
                    train_dataset):  # each document is a minibatch
                self.model.train()

                # convert data items to pytorch inputs
                token_ids = [
                    m['context'][0] + m['context'][1]
                    if len(m['context'][0]) + len(m['context'][1]) > 0 else
                    [self.model.word_voca.unk_id] for m in batch
                ]

                ment_ids = [
                    m['ment_ids'] if len(m['ment_ids']) > 0 else
                    [self.model.word_voca.unk_id] for m in batch
                ]

                entity_ids = Variable(
                    torch.LongTensor(
                        [m['selected_cands']['cands'] for m in batch]).cuda())
                true_pos = Variable(
                    torch.LongTensor([
                        m['selected_cands']['true_pos'] for m in batch
                    ]).cuda())
                p_e_m = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['p_e_m'] for m in batch]).cuda())
                entity_mask = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['mask'] for m in batch]).cuda())

                mtype = Variable(
                    torch.FloatTensor([m['mtype'] for m in batch]).cuda())
                etype = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['etype'] for m in batch]).cuda())

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, self.model.word_voca.unk_id)
                token_ids = Variable(torch.LongTensor(token_ids).cuda())
                token_mask = Variable(torch.FloatTensor(token_mask).cuda())

                ment_ids, ment_mask = utils.make_equal_len(
                    ment_ids, self.model.word_voca.unk_id)
                ment_ids = Variable(torch.LongTensor(ment_ids).cuda())
                ment_mask = Variable(torch.FloatTensor(ment_mask).cuda())

                if self.args.method == "SL":
                    optimizer.zero_grad()

                    scores, _ = self.model.forward(
                        token_ids,
                        token_mask,
                        entity_ids,
                        entity_mask,
                        p_e_m,
                        mtype,
                        etype,
                        ment_ids,
                        ment_mask,
                        gold=true_pos.view(-1, 1),
                        method=self.args.method,
                        isTrain=True,
                        isDynamic=config['isDynamic'],
                        isOrderLearning=order_learning,
                        isOrderFixed=True,
                        isSort=self.args.sort)

                    if order_learning:
                        _, targets = self.model.get_order_truth()
                        targets = Variable(torch.LongTensor(targets).cuda())

                        if scores.size(0) != targets.size(0):
                            print("Size mismatch!")
                            break
                        loss = self.model.loss(scores,
                                               targets,
                                               method=self.args.method)
                    else:
                        loss = self.model.loss(scores,
                                               true_pos,
                                               method=self.args.method)

                    loss.backward()
                    optimizer.step()
                    self.model.regularize(max_norm=4)

                    loss = loss.cpu().data.numpy()
                    total_loss += loss

                elif self.args.method == "RL":
                    action_memory = []
                    early_stop_count = 0

                    # the number of episodes run for one doc is determined by decision accuracy
                    for i_episode in count(1):
                        optimizer.zero_grad()

                        # get the model output
                        scores, actions = self.model.forward(
                            token_ids,
                            token_mask,
                            entity_ids,
                            entity_mask,
                            p_e_m,
                            mtype,
                            etype,
                            ment_ids,
                            ment_mask,
                            gold=true_pos.view(-1, 1),
                            method=self.args.method,
                            isTrain=True,
                            isDynamic=config['isDynamic'],
                            isOrderLearning=order_learning,
                            isOrderFixed=True,
                            isSort=self.args.sort)
                        if order_learning:
                            _, targets = self.model.get_order_truth()
                            targets = Variable(
                                torch.LongTensor(targets).cuda())

                            if scores.size(0) != targets.size(0):
                                print("Size mismatch!")
                                break

                            loss = self.model.loss(scores,
                                                   targets,
                                                   method=self.args.method)
                        else:
                            loss = self.model.loss(scores,
                                                   true_pos,
                                                   method=self.args.method)

                        loss.backward()
                        optimizer.step()

                        loss = loss.cpu().data.numpy()
                        total_loss += loss

                        # compute accuracy
                        correct = 0
                        total = 0.
                        if order_learning:
                            _, targets = self.model.get_order_truth()
                            for i in range(len(actions)):
                                if targets[i] == actions[i]:
                                    correct += 1
                                total += 1
                        else:
                            for i in range(len(actions)):
                                if true_pos.data[i] == actions[i]:
                                    correct += 1
                                total += 1

                        if not config['use_early_stop']:
                            break

                        if i_episode > len(batch) / 2:
                            break

                        if actions == action_memory:
                            early_stop_count += 1
                        else:
                            del action_memory[:]
                            action_memory = copy.deepcopy(actions)
                            early_stop_count = 0

                        if correct / total >= rl_acc_threshold or early_stop_count >= 3:
                            break

            print('epoch',
                  e,
                  'total loss',
                  total_loss,
                  total_loss / len(train_dataset),
                  flush=True)

            if (e + 1) % eval_after_n_epochs == 0:
                dev_f1 = 0.
                test_f1 = 0.
                ave_f1 = 0.
                if rl_acc_threshold < 0.92:
                    rl_acc_threshold += 0.02
                temp_rlt = []
                #self.records[e] = dict()
                for di, (dname, data) in enumerate(dev_datasets):
                    if dname == 'aida-B':
                        self.rt_flag = True
                    else:
                        self.rt_flag = False
                    predictions = self.predict(data, config['isDynamic'],
                                               order_learning)

                    f1 = D.eval(org_dev_datasets[di][1], predictions)

                    print(dname,
                          utils.tokgreen('micro F1: ' + str(f1)),
                          flush=True)

                    with open(self.output_path, 'a') as eval_csv_f1:
                        eval_f1_csv_writer = csv.writer(eval_csv_f1)
                        eval_f1_csv_writer.writerow([dname, e, 0, f1])

                    temp_rlt.append([dname, f1])
                    if dname == 'aida-A':
                        dev_f1 = f1
                    if dname == 'aida-B':
                        test_f1 = f1
                    ave_f1 += f1
                if dev_f1 > best_aida_A_f1:
                    best_aida_A_f1 = dev_f1
                    best_aida_A_rlts = copy.deepcopy(temp_rlt)
                if test_f1 > best_aida_B_f1:
                    best_aida_B_f1 = test_f1
                    best_aida_B_rlts = copy.deepcopy(temp_rlt)
                if ave_f1 > best_ave_f1:
                    best_ave_f1 = ave_f1
                    best_ave_rlts = copy.deepcopy(temp_rlt)

                if not config['isDynamic']:
                    self.record_runtime('DCA')
                else:
                    self.record_runtime('local')

                if (config['lr'] == self.init_lr
                        and dev_f1 >= self.args.dev_f1_change_lr):
                    eval_after_n_epochs = 2
                    is_counting = True
                    best_f1 = dev_f1
                    not_better_count = 0

                    # self.model.switch_order_learning(0)
                    config['lr'] = self.init_lr / 2
                    print('change learning rate to', config['lr'])
                    optimizer = optim.Adam(
                        [p for p in self.model.parameters() if p.requires_grad],
                        lr=config['lr'])

                    for param_name, param in self.model.named_parameters():
                        if param.requires_grad:
                            print(param_name)

                if dev_f1 >= self.args.dev_f1_start_order_learning and self.args.order_learning:
                    order_learning = True

                if is_counting:
                    if dev_f1 < best_f1:
                        not_better_count += 1
                    else:
                        not_better_count = 0
                        best_f1 = dev_f1
                        print('save model to', self.args.model_path)
                        self.model.save(self.args.model_path)

                if not_better_count == self.args.n_not_inc:
                    break

                self.model.print_weight_norm()

        print('best_aida_A_rlts', best_aida_A_rlts)
        print('best_aida_B_rlts', best_aida_B_rlts)
        print('best_ave_rlts', best_ave_rlts)
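
The evaluation block above implements a simple schedule: once dev micro F1 on aida-A reaches args.dev_f1_change_lr, the learning rate is halved a single time, evaluation moves to every 2 epochs, and training stops after args.n_not_inc evaluations without improvement. Below is a minimal standalone sketch of just the LR-drop and early-stop logic (the eval-interval change is omitted); run_epoch and evaluate are placeholder callables, not names from the repo.

def train_with_lr_drop(run_epoch, evaluate, init_lr, n_epochs,
                       dev_f1_change_lr, n_not_inc):
    # Sketch only: mirrors the schedule in Example #1, not the repo's code.
    lr = init_lr
    best_f1 = -1.0
    not_better_count = 0
    is_counting = False
    for epoch in range(n_epochs):
        run_epoch(lr)                      # one pass over the training docs
        dev_f1 = evaluate()                # micro F1 on the dev set (aida-A)
        if lr == init_lr and dev_f1 >= dev_f1_change_lr:
            lr = init_lr / 2               # one-time LR drop
            is_counting = True
            best_f1, not_better_count = dev_f1, 0
        if is_counting:
            if dev_f1 < best_f1:
                not_better_count += 1
            else:
                best_f1, not_better_count = dev_f1, 0
        if not_better_count == n_not_inc:  # early stop
            break
    return best_f1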
Example #2
    def train(self, org_train_dataset, org_dev_datasets, config):
        print('extracting training data')
        train_dataset = self.get_data_items(org_train_dataset, predict=False)
        print('#train docs', len(train_dataset))

        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, self.get_data_items(data,
                                                            predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        print('creating optimizer')
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=config['lr'])
        best_f1 = -1
        not_better_count = 0
        is_counting = False
        eval_after_n_epochs = self.args.eval_after_n_epochs

        for e in range(config['n_epochs']):
            if self.args.method == "SL":
                shuffle(train_dataset)

            total_loss = 0

            for dc, batch in enumerate(
                    train_dataset):  # each document is a minibatch
                self.model.train()

                # convert data items to pytorch inputs
                token_ids = [
                    m['context'][0] + m['context'][1]
                    if len(m['context'][0]) + len(m['context'][1]) > 0 else
                    [self.model.word_voca.unk_id] for m in batch
                ]

                entity_ids = Variable(
                    torch.LongTensor(
                        [m['selected_cands']['cands'] for m in batch]).cuda())
                true_pos = Variable(
                    torch.LongTensor([
                        m['selected_cands']['true_pos'] for m in batch
                    ]).cuda())
                p_e_m = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['p_e_m'] for m in batch]).cuda())
                entity_mask = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['mask'] for m in batch]).cuda())

                mtype = Variable(
                    torch.FloatTensor([m['mtype'] for m in batch]).cuda())
                etype = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['etype'] for m in batch]).cuda())

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, self.model.word_voca.unk_id)
                token_ids = Variable(torch.LongTensor(token_ids).cuda())
                token_mask = Variable(torch.FloatTensor(token_mask).cuda())

                if self.args.method == "SL" or self.args.method == "RL":
                    optimizer.zero_grad()

                    # get the model output
                    scores, _ = self.model.forward(token_ids,
                                                   token_mask,
                                                   entity_ids,
                                                   entity_mask,
                                                   p_e_m,
                                                   mtype,
                                                   etype,
                                                   gold=true_pos.view(-1, 1),
                                                   method=self.args.method,
                                                   isTrain=True)

                    loss = self.model.loss(scores,
                                           true_pos,
                                           method=self.args.method)

                    loss.backward()
                    optimizer.step()
                    self.model.regularize(max_norm=4)

                    loss = loss.cpu().data.numpy()
                    total_loss += loss

#                elif self.args.method == "RL":
#                    action_memory = []
#                    early_stop_count = 0
#
#                    # the actual episode number for one doc is determined by decision accuracy
#                    for i_episode in count(1):
#                        optimizer.zero_grad()
#
#                        # get the model output
#                        scores, actions = self.model.forward(
#                            token_ids, token_mask, entity_ids, entity_mask,
#                            p_e_m, mtype, etype,
#                            gold=true_pos.view(-1, 1),
#                            method=self.args.method, isTrain=True)
#
#                        # compute accuracy
#                        correct = 0
#                        total = 0.
#                        for i in range(len(actions)):
#                            if true_pos.data[i] == actions[i]:
#                                correct += 1
#                            total += 1
#
#                        loss = self.model.loss(scores, true_pos,
#                                               method=self.args.method)
#
#                        loss.backward()
#                        optimizer.step()
#
#                        loss = loss.cpu().data.numpy()
#                        total_loss += loss
#
#                        if i_episode > len(batch):
#                            break
#
#                        if actions == action_memory:
#                            early_stop_count += 1
#                        else:
#                            del action_memory[:]
#                            action_memory = copy.deepcopy(actions)
#                            early_stop_count = 0
#
#                        if correct / total >= 0.8 or early_stop_count >= 5:
#                            break
#
#                # print('epoch', e, "%0.2f%%" % (dc / len(train_dataset) * 100), loss)

            print('epoch',
                  e,
                  'total loss',
                  total_loss,
                  total_loss / len(train_dataset),
                  flush=True)

            if (e + 1) % eval_after_n_epochs == 0:
                dev_f1 = 0
                for di, (dname, data) in enumerate(dev_datasets):
                    predictions = self.predict(data)
                    f1 = D.eval(org_dev_datasets[di][1], predictions)
                    print(dname,
                          utils.tokgreen('micro F1: ' + str(f1)),
                          flush=True)

                    with open(self.output_path, 'a') as eval_csv_f1:
                        eval_f1_csv_writer = csv.writer(eval_csv_f1)
                        eval_f1_csv_writer.writerow([dname, e, f1])

                    if dname == 'aida-A':
                        dev_f1 = f1

                if (config['lr'] == 1e-4
                        and dev_f1 >= self.args.dev_f1_change_lr):
                    eval_after_n_epochs = 2
                    is_counting = True
                    best_f1 = dev_f1
                    not_better_count = 0

                    config['lr'] = 1e-5
                    print('change learning rate to', config['lr'])
                    optimizer = optim.Adam(
                        [p for p in self.model.parameters() if p.requires_grad],
                        lr=config['lr'])

                if is_counting:
                    if dev_f1 < best_f1:
                        not_better_count += 1
                    else:
                        not_better_count = 0
                        best_f1 = dev_f1
                        print('save model to', self.args.model_path)
                        self.model.save(self.args.model_path)

                if not_better_count == self.args.n_not_inc:
                    break

                self.model.print_weight_norm()
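
Both train variants read their hyperparameters from self.args. The following is a hedged sketch of an argparse block that could supply the attributes referenced above; the flag names and default values are assumptions, only the attribute names (method, learning_rate, n_epochs, eval_after_n_epochs, dev_f1_change_lr, n_not_inc, model_path) come from the code. Additional flags used only by Example #1 (e.g. sort, order_learning, isDynamic, use_early_stop) are omitted here.

import argparse

# Hypothetical argument definitions; defaults are illustrative, not from the repo.
parser = argparse.ArgumentParser()
parser.add_argument('--method', default='SL', choices=['SL', 'RL'])
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--n_epochs', type=int, default=100)
parser.add_argument('--eval_after_n_epochs', type=int, default=5)
parser.add_argument('--dev_f1_change_lr', type=float, default=0.9)
parser.add_argument('--n_not_inc', type=int, default=5)
parser.add_argument('--model_path', default='model')
args = parser.parse_args()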
Example #3
File: main.py  Project: zjulins/DCA
                    ('aida-A', conll.testA),
                    ('aida-B', conll.testB),
                    ('msnbc', conll.msnbc),
                    ('aquaint', conll.aquaint),
                    ('ace2004', conll.ace2004),
                    ('clueweb', conll.clueweb),
                    ('wikipedia', conll.wikipedia)
                ]

    with open(F1_CSV_Path, 'w') as f_csv_f1:
        f1_csv_writer = csv.writer(f_csv_f1)
        f1_csv_writer.writerow(['dataset', 'epoch', 'dynamic', 'F1 Score'])

    if args.mode == 'train':
        print('training...')
        config = {
            'lr': args.learning_rate,
            'n_epochs': args.n_epochs,
            'isDynamic': args.isDynamic,
            'use_early_stop': args.use_early_stop,
        }
        # pprint(config)
        ranker.train(conll.train, dev_datasets, config)

    elif args.mode == 'eval':
        org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, ranker.get_data_items(data, predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        for di, (dname, data) in enumerate(dev_datasets):
            predictions = ranker.predict(data)
            print(dname, utils.tokgreen('micro F1: ' + str(D.eval(org_dev_datasets[di][1], predictions))))

Example #4
    with open(F1_CSV_Path, 'w') as f_csv_f1:
        f1_csv_writer = csv.writer(f_csv_f1)
        f1_csv_writer.writerow(['dataset', 'epoch', 'dynamic', 'F1 Score'])

    if args.mode == 'train':
        print('training...')
        config = {
            'lr': args.learning_rate,
            'n_epochs': args.n_epochs,
            'isDynamic': args.isDynamic,
            'use_early_stop': args.use_early_stop,
        }
        # pprint(config)
        ranker.train(conll.train, dev_datasets, config)

    elif args.mode == 'eval':
        org_dev_datasets = dev_datasets  # + [('aida-train', conll.train)]
        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append(
                (dname, ranker.get_data_items(data, predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        for di, (dname, data) in enumerate(dev_datasets):
            predictions = ranker.predict(data)
            print(
                dname,
                utils.tokgreen(
                    'micro F1: ' +
                    str(D.eval(org_dev_datasets[di][1], predictions))))
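
The CSV written during training (header 'dataset', 'epoch', 'dynamic', 'F1 Score', one row per dataset per evaluation) can later be scanned for the best epoch per dataset. A small post-processing sketch, assuming the four-column rows written by Example #1; the F1_CSV_Path value is hypothetical since the excerpt does not show where it is set.

import csv
from collections import defaultdict

F1_CSV_Path = 'f1_scores.csv'              # placeholder; set elsewhere in main.py
best = defaultdict(lambda: (None, -1.0))   # dataset -> (epoch, best F1)
with open(F1_CSV_Path) as f:
    reader = csv.reader(f)
    next(reader)                           # skip the header row
    for dataset, epoch, _dynamic, f1 in reader:
        if float(f1) > best[dataset][1]:
            best[dataset] = (epoch, float(f1))
for dataset, (epoch, f1) in best.items():
    print(dataset, 'best F1', f1, 'at epoch', epoch)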