Example 1
    def train_iter(self, ep, train_set, optimizer):
        t1 = time.time()
        train_acc, train_loss = 0., 0.
        train_loader = DataLoader(train_set,
                                  batch_size=self.args.batch_size,
                                  shuffle=True)
        self.model.train()
        for i, batcher in enumerate(train_loader):
            batch = batch_variable(batcher, self.vocabs)
            batch.to_device(self.args.device)
            pred = self.model(batch.x, batch.nx, batch.ew)
            loss = F.nll_loss(pred, batch.y)
            loss.backward()
            nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad,
                                            self.model.parameters()),
                                     max_norm=self.args.grad_clip)
            optimizer.step()
            self.model.zero_grad()

            loss_val = loss.data.item()
            train_loss += loss_val
            train_acc += (pred.data.argmax(dim=-1) == batch.y).sum().item()

            logger.info(
                '[Epoch %d] Iter%d time cost: %.2fs, lr: %.6f, train acc: %.4f, train loss: %.4f'
                % (ep, i + 1, (time.time() - t1), optimizer.get_lr(),
                   train_acc / len(train_set), loss_val))

        return train_loss / len(train_set), train_acc / len(train_set)
Example 2
    def evaluate(self, mode='test'):
        if mode == 'dev':
            test_loader = DataLoader(
                self.val_set, batch_size=self.args.test_batch_size)
        elif mode == 'test':
            test_loader = DataLoader(
                self.test_set, batch_size=self.args.test_batch_size)
        else:
            raise ValueError('Invalid Mode!!!')

        self.model.eval()
        nb_right_all, nb_pred_all, nb_gold_all = 0, 0, 0
        with torch.no_grad():
            for i, batcher in enumerate(test_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

                pred_score = self.model(
                    batch.wd_ids, batch.ch_ids, batch.tag_ids, batch.bert_inps)
                sent_lens = batch.wd_ids.gt(0).sum(dim=1)
                gold_res = self.ner_gold(
                    batch.ner_ids, sent_lens, self.vocabs['ner'])
                pred_res = self.ner_pred(
                    pred_score, sent_lens, self.vocabs['ner'])
                nb_right, nb_pred, nb_gold = self.calc_acc(
                    pred_res, gold_res, return_prf=False)
                nb_right_all += nb_right
                nb_pred_all += nb_pred
                nb_gold_all += nb_gold
        p, r, f = self.calc_prf(nb_right_all, nb_pred_all, nb_gold_all)
        return dict(p=p, r=r, f=f)
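
Examples 2, 6, and 7 all accumulate right/pred/gold counts and hand them to self.calc_prf, which is not included in these snippets. A minimal sketch of such a helper, assuming the standard precision/recall/F1 definitions with zero-division guards (the name and call signature follow the usage above; the body is an assumption):

def calc_prf(nb_right, nb_pred, nb_gold):
    # precision: correct predictions over everything predicted
    p = nb_right / nb_pred if nb_pred > 0 else 0.
    # recall: correct predictions over everything in the gold annotation
    r = nb_right / nb_gold if nb_gold > 0 else 0.
    # harmonic mean of precision and recall, guarded against p + r == 0
    f = 2 * p * r / (p + r) if (p + r) > 0 else 0.
    return p, r, f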
Example 3
 def eval(self, test_set):
     nb_correct, nb_total = 0, 0
     test_loader = DataLoader(test_set,
                              batch_size=self.args.test_batch_size)
     self.model.eval()
     with torch.no_grad():
         for i, batcher in enumerate(test_loader):
             batch = batch_variable(batcher, self.vocabs)
             batch.to_device(self.args.device)
             pred = self.model(batch.x, batch.nx, batch.ew)
             nb_correct += (pred.data.argmax(
                 dim=-1) == batch.y).sum().item()
             nb_total += len(batch.y)
     return nb_correct / nb_total
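
Example 1's train_iter and Example 3's eval call the model with the same (batch.x, batch.nx, batch.ew) inputs, so they plausibly belong to the same trainer class. A sketch of an outer loop that could tie them together; the attribute names (self.args.epoch, self.train_set, self.dev_set) and the Optimizer wrapper are assumptions, not taken from the source:

def train(self):
    # Hypothetical driver loop; Optimizer, self.train_set, self.dev_set and
    # self.args.epoch are assumed names, not shown in the original snippets.
    optimizer = Optimizer(
        filter(lambda p: p.requires_grad, self.model.parameters()), self.args)
    best_dev_acc = 0.
    for ep in range(1, 1 + self.args.epoch):
        train_loss, train_acc = self.train_iter(ep, self.train_set, optimizer)
        dev_acc = self.eval(self.dev_set)
        best_dev_acc = max(best_dev_acc, dev_acc)
        logger.info('[Epoch %d] train loss: %.4f, train acc: %.4f, dev acc: %.4f'
                    % (ep, train_loss, train_acc, dev_acc))
    return best_dev_acc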
Example 4
 def eval(self, task_id, test_data):
     print(f'evaluating {get_task(task_id)} task ...')
     nb_correct, nb_total = 0, 0
     self.model.eval()
     test_loader = DataLoader(test_data,
                              batch_size=self.args.test_batch_size)
     with torch.no_grad():
         for i, batcher in enumerate(test_loader):
             batch = batch_variable(batcher, self.wd_vocab)
             batch.to_device(self.args.device)
             task_logits, share_logits, _ = self.model(
                 task_id, batch.wd_ids)
             nb_correct += (task_logits.data.argmax(
                 dim=-1) == batch.lbl_ids).sum().item()
             nb_total += len(batch.lbl_ids)
     acc = nb_correct / nb_total
     # err = 1 - acc
     return acc
Example 5
    def train_iter(self, ep, task_id, train_data, optimizer):
        t1 = time.time()
        train_acc, train_loss = 0., 0.
        self.model.train()
        train_loader = DataLoader(train_data,
                                  batch_size=self.args.batch_size,
                                  shuffle=True)
        total_step = 200 * len(train_loader)
        step = 0
        for i, batcher in enumerate(train_loader):
            batch = batch_variable(batcher, self.wd_vocab)
            batch.to_device(self.args.device)
            adv_lmbd = self.lambda_(step, total_step)
            task_logits, share_logits, diff_loss = self.model(
                task_id, batch.wd_ids, adv_lmbd)
            loss_task = F.cross_entropy(task_logits, batch.lbl_ids)
            loss_share = F.cross_entropy(share_logits, batch.task_ids)
            loss = loss_task + self.args.adv_loss_w * loss_share + self.args.diff_loss_w * diff_loss
            loss.backward()
            nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad,
                                            self.model.parameters()),
                                     max_norm=self.args.grad_clip)
            optimizer.step()
            self.model.zero_grad()

            loss_val = loss.data.item()
            train_loss += loss_val
            train_acc += (task_logits.data.argmax(
                dim=-1) == batch.lbl_ids).sum().item()
            logger.info(
                '[Epoch %d][Task %s] Iter%d time cost: %.2fs, lr: %.6f, train acc: %.4f, train loss: %.4f'
                % (ep, get_task(task_id), i + 1, (time.time() - t1),
                   optimizer.get_lr(), train_acc / len(train_data), loss_val))

            step += 1

        return train_loss / len(train_data), train_acc / len(train_data)
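
Example 5 passes adv_lmbd = self.lambda_(step, total_step) into the model, which points at an adversarial shared-private setup (a task loss, an adversarial loss on the shared encoder, and a diff loss). The schedule itself is not shown; a common DANN-style choice, sketched here purely as an assumption about what self.lambda_ might compute, ramps the coefficient smoothly from 0 to 1 over training:

import math

def lambda_(step, total_step, gamma=10.):
    # Assumed DANN-style schedule: p is the fraction of training completed;
    # the returned weight grows from 0 towards 1 as p goes from 0 to 1.
    p = step / max(1, total_step)
    return 2. / (1. + math.exp(-gamma * p)) - 1.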
Example 6
    def train_eval(self):
        train_loader = DataLoader(
            self.train_set, batch_size=self.args.batch_size, shuffle=True)
        self.args.max_step = self.args.epoch * \
            (len(train_loader) // self.args.update_step)
        print('max step:', self.args.max_step)
        optimizer = Optimizer(
            filter(lambda p: p.requires_grad, self.model.parameters()), self.args)
        best_dev_metric, best_test_metric = dict(), dict()
        patient = 0
        for ep in range(1, 1 + self.args.epoch):
            train_loss = 0.
            self.model.train()
            t1 = time.time()
            train_right, train_pred, train_gold = 0, 0, 0
            for i, batcher in enumerate(train_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

                pred_score = self.model(
                    batch.wd_ids, batch.ch_ids, batch.tag_ids, batch.bert_inps)
                loss = self.calc_loss(pred_score, batch.ner_ids)
                loss_val = loss.data.item()
                train_loss += loss_val

                sent_lens = batch.wd_ids.gt(0).sum(dim=1)
                gold_res = self.ner_gold(
                    batch.ner_ids, sent_lens, self.vocabs['ner'])
                pred_res = self.ner_pred(
                    pred_score, sent_lens, self.vocabs['ner'])
                nb_right, nb_pred, nb_gold = self.calc_acc(
                    pred_res, gold_res, return_prf=False)
                train_right += nb_right
                train_pred += nb_pred
                train_gold += nb_gold
                train_p, train_r, train_f = self.calc_prf(
                    train_right, train_pred, train_gold)

                if self.args.update_step > 1:
                    loss = loss / self.args.update_step

                loss.backward()

                if (i + 1) % self.args.update_step == 0 or (i + 1 == len(train_loader)):
                    nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()),
                                             max_norm=self.args.grad_clip)
                    optimizer.step()
                    self.model.zero_grad()

                logger.info('[Epoch %d] Iter%d time cost: %.2fs, lr: %.6f, train loss: %.3f, P: %.3f, R: %.3f, F: %.3f' % (
                    ep, i + 1, (time.time() - t1), optimizer.get_lr(), loss_val, train_p, train_r, train_f))

            dev_metric = self.evaluate('dev')
            if dev_metric['f'] > best_dev_metric.get('f', 0):
                best_dev_metric = dev_metric
                test_metric = self.evaluate('test')
                if test_metric['f'] > best_test_metric.get('f', 0):
                    # check_point = {'model': self.model.state_dict(), 'settings': args}
                    # torch.save(check_point, self.args.model_chkp)
                    best_test_metric = test_metric
                patient = 0
            else:
                patient += 1

            logger.info('[Epoch %d] train loss: %.4f, lr: %f, patient: %d, dev_metric: %s, test_metric: %s' % (
                ep, train_loss, optimizer.get_lr(), patient, best_dev_metric, best_test_metric))

            # if patient >= (self.args.patient // 2 + 1):  # decay lr if dev performance has not improved for several epochs
            #     optimizer.lr_decay(0.95)

            if patient >= self.args.patient:  # early stopping
                break

        logger.info('Final Metric: %s' % best_test_metric)
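
The NER trainer above leans on self.ner_gold, self.ner_pred, and self.calc_acc(..., return_prf=False) to turn tag sequences into comparable entity sets and count matches, none of which are shown here. A rough sketch of the gold-side extraction, assuming a BIO tag scheme and a vocab object with an idx2inst lookup (both assumptions); with entities represented as (start, end, type) sets, calc_acc would reduce to counting set intersections and sizes:

def ner_gold(ner_ids, sent_lens, ner_vocab):
    # Assumed helper: convert gold tag ids into one set of
    # (start, end, entity_type) spans per sentence, BIO scheme.
    all_spans = []
    for ids, n in zip(ner_ids.tolist(), sent_lens.tolist()):
        tags = [ner_vocab.idx2inst(t) for t in ids[:n]]  # hypothetical vocab API
        spans, start, etype = set(), None, None
        for j, tag in enumerate(tags + ['O']):  # sentinel 'O' closes a trailing span
            if tag == 'O' or tag.startswith('B-'):
                if start is not None:
                    spans.add((start, j - 1, etype))
                    start, etype = None, None
                if tag.startswith('B-'):
                    start, etype = j, tag[2:]
        all_spans.append(spans)
    return all_spans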
Example 7
    def evaluate(self, mode='test'):
        if mode == 'dev':
            test_loader = DataLoader(self.val_set,
                                     batch_size=self.args.test_batch_size)
        elif mode == 'test':
            test_loader = DataLoader(self.test_set,
                                     batch_size=self.args.test_batch_size)
        else:
            raise ValueError('Invalid Mode!!!')

        self.model.eval()
        rel_vocab = self.vocabs['rel']
        nb_head_gold, nb_head_pred, nb_head_correct = 0, 0, 0
        nb_rel_gold, nb_rel_pred, nb_rel_correct = 0, 0, 0
        with torch.no_grad():
            for i, batcher in enumerate(test_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

                head_score, rel_score = self.model(batch.wd_ids, batch.ch_ids,
                                                   batch.tag_ids)
                mask = batch.wd_ids.gt(0)
                lens = mask.sum(dim=1)
                graph_pred = self.model.graph_decode(head_score, rel_score,
                                                     mask)

                pred_deps = self.parse_pred_graph(graph_pred, lens, rel_vocab)
                gold_deps = self.parse_gold_graph(batch.rel_ids, lens,
                                                  rel_vocab)
                assert len(pred_deps) == len(gold_deps)
                # for deps_p, deps_g in zip(pred_deps, gold_deps):
                #     nb_head_gold += len(deps_g)
                #     nb_rel_gold += len(deps_g)
                #
                #     nb_head_pred += len(deps_p)
                #     nb_rel_pred += len(deps_p)
                #     for dg in deps_g:
                #         for dp in deps_p:
                #             if dg[:-1] == dp[:-1]:
                #                 nb_head_correct += 1
                #                 if dg == dp:
                #                     nb_rel_correct += 1
                #                 break

                for pdeps, gdeps in zip(pred_deps, gold_deps):  # sentence
                    assert len(pdeps) == len(gdeps)
                    for pdep, gdep in zip(pdeps, gdeps):  # word
                        nb_head_pred += len(pdep)
                        nb_rel_pred += len(pdep)

                        nb_head_gold += len(gdep)
                        nb_rel_gold += len(gdep)
                        for gd in gdep:  # (head_id, rel_id)
                            for pd in pdep:
                                if pd[0] == gd[0]:
                                    nb_head_correct += 1
                                    if pd == gd:
                                        nb_rel_correct += 1
                                    break

        up, ur, uf = self.calc_prf(nb_head_correct, nb_head_pred, nb_head_gold)
        lp, lr, lf = self.calc_prf(nb_rel_correct, nb_rel_pred, nb_rel_gold)
        return dict(up=up, ur=ur, uf=uf, lp=lp, lr=lr, lf=lf)
Example 8
    def train_eval(self):
        train_loader = DataLoader(self.train_set,
                                  batch_size=self.args.batch_size,
                                  shuffle=True)
        self.args.max_step = self.args.epoch * (len(train_loader) //
                                                self.args.update_step)
        print('max step:', self.args.max_step)
        optimizer = Optimizer(
            filter(lambda p: p.requires_grad, self.model.parameters()), self.args)
        best_dev_metric, best_test_metric = dict(), dict()
        patient = 0
        for ep in range(1, 1 + self.args.epoch):
            train_loss = 0.
            self.model.train()
            t1 = time.time()
            train_head_acc, train_rel_acc, train_total_head = 0, 0, 0
            for i, batcher in enumerate(train_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

                head_score, rel_score = self.model(batch.wd_ids, batch.ch_ids,
                                                   batch.tag_ids)
                loss = self.calc_loss(head_score, rel_score, batch.head_ids,
                                      batch.rel_ids, batch.wd_ids.gt(0))
                loss_val = loss.data.item()
                train_loss += loss_val

                head_acc, rel_acc, total_head = self.calc_acc(
                    head_score, rel_score, batch.head_ids, batch.rel_ids)
                train_head_acc += head_acc
                train_rel_acc += rel_acc
                train_total_head += total_head

                if self.args.update_step > 1:
                    loss = loss / self.args.update_step

                loss.backward()

                if (i + 1) % self.args.update_step == 0 or (
                        i + 1 == len(train_loader)):
                    nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             max_norm=self.args.grad_clip)
                    optimizer.step()
                    self.model.zero_grad()

                logger.info(
                    '[Epoch %d] Iter%d time cost: %.2fs, lr: %.6f, train loss: %.3f, head acc: %.3f, rel acc: %.3f'
                    % (ep, i + 1, (time.time() - t1), optimizer.get_lr(),
                       loss_val, train_head_acc / train_total_head,
                       train_rel_acc / train_total_head))

            dev_metric = self.evaluate('dev')
            if dev_metric['uf'] > best_dev_metric.get('uf', 0):
                best_dev_metric = dev_metric
                test_metric = self.evaluate('test')
                if test_metric['uf'] > best_test_metric.get('uf', 0):
                    # check_point = {'model': self.model.state_dict(), 'settings': args}
                    # torch.save(check_point, self.args.model_chkp)
                    best_test_metric = test_metric
                patient = 0
            else:
                patient += 1

            logger.info(
                '[Epoch %d] train loss: %.4f, lr: %f, patient: %d, dev_metric: %s, test_metric: %s'
                % (ep, train_loss, optimizer.get_lr(), patient,
                   best_dev_metric, best_test_metric))

            # if patient == (self.args.patient // 2 + 1):  # decay lr if dev performance has not improved for several epochs
            #     optimizer.lr_decay(0.95)

            if patient >= self.args.patient:  # early stopping
                break

        logger.info('Final Metric: %s' % best_test_metric)
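
Both train_eval methods construct a custom Optimizer wrapper and call optimizer.get_lr() (and, in the commented-out lines, optimizer.lr_decay(0.95)), neither of which exists on the stock torch.optim optimizers. A minimal sketch of a wrapper with that interface, assuming Adam underneath and a simple multiplicative decay; the underlying optimizer and the args.learning_rate field are assumptions about the original class:

import torch

class Optimizer(object):
    # Assumed wrapper exposing the interface used above:
    # step(), zero_grad(), get_lr(), lr_decay().
    def __init__(self, params, args):
        self.optim = torch.optim.Adam(params, lr=args.learning_rate)  # lr field name assumed

    def step(self):
        self.optim.step()

    def zero_grad(self):
        self.optim.zero_grad()

    def get_lr(self):
        return self.optim.param_groups[0]['lr']

    def lr_decay(self, factor):
        for group in self.optim.param_groups:
            group['lr'] *= factor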