Example 1
    def train_eval(self):
        train_loader = DataLoader(self.train_set,
                                  batch_size=self.args.batch_size,
                                  shuffle=True)
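        # max_step: total number of optimizer updates over all epochs (one update per update_step batches)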
        self.args.max_step = self.args.epoch * (len(train_loader) //
                                                self.args.update_step)
        print('max step:', self.args.max_step)
        optimizer = Optimizer(
            filter(lambda p: p.requires_grad, self.model.parameters()), self.args)
        best_dev_metric, best_test_metric = dict(), dict()
        patient = 0
        for ep in range(1, 1 + self.args.epoch):
            train_loss = 0.
            self.model.train()
            t1 = time.time()
            train_head_acc, train_rel_acc, train_total_head = 0, 0, 0
            for i, batcher in enumerate(train_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

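                # forward pass: head (arc) scores and dependency-relation scores per token; wd_ids.gt(0) marks non-padding positions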
                head_score, rel_score = self.model(batch.wd_ids, batch.ch_ids,
                                                   batch.tag_ids)
                loss = self.calc_loss(head_score, rel_score, batch.head_ids,
                                      batch.rel_ids, batch.wd_ids.gt(0))
                loss_val = loss.data.item()
                train_loss += loss_val

                head_acc, rel_acc, total_head = self.calc_acc(
                    head_score, rel_score, batch.head_ids, batch.rel_ids)
                train_head_acc += head_acc
                train_rel_acc += rel_acc
                train_total_head += total_head

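                # divide the loss so that gradients accumulated over update_step batches are effectively averaged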
                if self.args.update_step > 1:
                    loss = loss / self.args.update_step

                loss.backward()

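                # apply gradient clipping and an optimizer step only once every update_step batches (gradient accumulation)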
                if (i + 1) % self.args.update_step == 0 or (
                        i == self.args.max_step - 1):
                    nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             max_norm=self.args.grad_clip)
                    optimizer.step()
                    self.model.zero_grad()

                logger.info(
                    '[Epoch %d] Iter%d time cost: %.2fs, lr: %.6f, train loss: %.3f, head acc: %.3f, rel acc: %.3f'
                    % (ep, i + 1, (time.time() - t1), optimizer.get_lr(),
                       loss_val, train_head_acc / train_total_head,
                       train_rel_acc / train_total_head))

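            # evaluate on dev every epoch; run the test set only when the dev UF score improves, and reset patience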
            dev_metric = self.evaluate('dev')
            if dev_metric['uf'] > best_dev_metric.get('uf', 0):
                best_dev_metric = dev_metric
                test_metric = self.evaluate('test')
                if test_metric['uf'] > best_test_metric.get('uf', 0):
                    # check_point = {'model': self.model.state_dict(), 'settings': args}
                    # torch.save(check_point, self.args.model_chkp)
                    best_test_metric = test_metric
                patient = 0
            else:
                patient += 1

            logger.info(
                '[Epoch %d] train loss: %.4f, lr: %f, patient: %d, dev_metric: %s, test_metric: %s'
                % (ep, train_loss, optimizer.get_lr(), patient,
                   best_dev_metric, best_test_metric))

            # if patient == (self.args.patient // 2 + 1):  # dev performance has not improved for several epochs: decay the lr
            #     optimizer.lr_decay(0.95)

            if patient >= self.args.patient:  # early stopping
                break

        logger.info('Final Metric: %s' % best_test_metric)
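
The update_step logic above is gradient accumulation: each batch loss is divided by update_step, gradients from several batches are summed via backward(), and the optimizer steps once per group after clipping. A self-contained toy sketch of the same pattern (the linear model, random data and accum_steps value below are illustrative assumptions, not part of the example):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
data_loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]
accum_steps, grad_clip = 4, 5.0

for i, (x, y) in enumerate(data_loader):
    loss = criterion(model(x), y) / accum_steps  # scale so accumulated gradients average out
    loss.backward()                              # gradients sum across the accumulated batches
    if (i + 1) % accum_steps == 0 or i == len(data_loader) - 1:
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()
        model.zero_grad()
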
Example 2
    def train_eval(self):
        train_loader = DataLoader(
            self.train_set, batch_size=self.args.batch_size, shuffle=True)
        self.args.max_step = self.args.epoch * \
            (len(train_loader) // self.args.update_step)
        print('max step:', self.args.max_step)
        optimizer = Optimizer(
            filter(lambda p: p.requires_grad, self.model.parameters()), self.args)
        best_dev_metric, best_test_metric = dict(), dict()
        patient = 0
        for ep in range(1, 1 + self.args.epoch):
            train_loss = 0.
            self.model.train()
            t1 = time.time()
            train_right, train_pred, train_gold = 0, 0, 0
            for i, batcher in enumerate(train_loader):
                batch = batch_variable(batcher, self.vocabs)
                batch.to_device(self.args.device)

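                # forward pass over word, character, tag and BERT inputs; pred_score holds per-token NER label scores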
                pred_score = self.model(
                    batch.wd_ids, batch.ch_ids, batch.tag_ids, batch.bert_inps)
                loss = self.calc_loss(pred_score, batch.ner_ids)
                loss_val = loss.data.item()
                train_loss += loss_val

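                # sent_lens: number of non-padding tokens (id > 0) per sentence, used to decode gold and predicted entity spans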
                sent_lens = batch.wd_ids.gt(0).sum(dim=1)
                gold_res = self.ner_gold(
                    batch.ner_ids, sent_lens, self.vocabs['ner'])
                pred_res = self.ner_pred(
                    pred_score, sent_lens, self.vocabs['ner'])
                nb_right, nb_pred, nb_gold = self.calc_acc(
                    pred_res, gold_res, return_prf=False)
                train_right += nb_right
                train_pred += nb_pred
                train_gold += nb_gold
                train_p, train_r, train_f = self.calc_prf(
                    train_right, train_pred, train_gold)

                if self.args.update_step > 1:
                    loss = loss / self.args.update_step

                loss.backward()

                if (i + 1) % self.args.update_step == 0 or (i == self.args.max_step - 1):
                    nn_utils.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()),
                                             max_norm=self.args.grad_clip)
                    optimizer.step()
                    self.model.zero_grad()

                logger.info('[Epoch %d] Iter%d time cost: %.2fs, lr: %.6f, train loss: %.3f, P: %.3f, R: %.3f, F: %.3f' % (
                    ep, i + 1, (time.time() - t1), optimizer.get_lr(), loss_val, train_p, train_r, train_f))

            dev_metric = self.evaluate('dev')
            if dev_metric['f'] > best_dev_metric.get('f', 0):
                best_dev_metric = dev_metric
                test_metric = self.evaluate('test')
                if test_metric['f'] > best_test_metric.get('f', 0):
                    # check_point = {'model': self.model.state_dict(), 'settings': args}
                    # torch.save(check_point, self.args.model_chkp)
                    best_test_metric = test_metric
                patient = 0
            else:
                patient += 1

            logger.info('[Epoch %d] train loss: %.4f, lr: %f, patient: %d, dev_metric: %s, test_metric: %s' % (
                ep, train_loss, optimizer.get_lr(), patient, best_dev_metric, best_test_metric))

            # if patient >= (self.args.patient // 2 + 1):  # dev performance has not improved for several epochs: decay the lr
            #     optimizer.lr_decay(0.95)

            if patient >= self.args.patient:  # early stopping
                break

        logger.info('Final Metric: %s' % best_test_metric)
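
Example 2 accumulates the counts nb_right / nb_pred / nb_gold over entity spans and turns them into micro-averaged precision, recall and F1 through calc_prf, whose body is not shown here. A minimal sketch of such a helper, guarded against empty counts (the exact definition in the original project may differ):

def calc_prf(nb_right, nb_pred, nb_gold):
    # micro-averaged precision / recall / F1 over predicted vs. gold spans
    p = nb_right / nb_pred if nb_pred > 0 else 0.
    r = nb_right / nb_gold if nb_gold > 0 else 0.
    f = 2 * p * r / (p + r) if (p + r) > 0 else 0.
    return p, r, f
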
Example 3
def train(model, train_data, dev_data, test_data, args, word_vocab,
          extwd_vocab, lbl_vocab):
    args.max_step = args.epoch * ((len(train_data) + args.batch_size - 1) //
                                  (args.batch_size * args.update_steps))
    optimizer = Optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    best_dev_acc, best_test_acc = 0, 0
    patient = 0
    for ep in range(1, 1 + args.epoch):
        model.train()
        train_loss = 0.
        start_time = time.time()
        for i, batch_data in enumerate(
                batch_iter(train_data, args.batch_size, True)):
            batcher = batch_variable(batch_data, word_vocab, extwd_vocab,
                                     lbl_vocab)
            batcher = (x.to(args.device) for x in batcher)
            sent1, sent2, extsent1, extsent2, gold_lbl = batcher
            pred = model((sent1, sent2), (extsent1, extsent2))
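            # criterion is assumed to be a module-level loss function (e.g. nn.CrossEntropyLoss); it is not defined in this snippet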
            loss = criterion(pred, gold_lbl)
            if args.update_steps > 1:
                loss = loss / args.update_steps

            loss_val = loss.data.item()
            train_loss += loss_val

            loss.backward()

            if (i + 1) % args.update_steps == 0 or (i == args.max_step - 1):
                nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad,
                                                model.parameters()),
                                         max_norm=args.grad_clip)
                optimizer.step()
                model.zero_grad()

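            # per-batch training accuracy: number of correct predictions divided by the batch size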
            train_acc = calc_acc(pred, gold_lbl) / len(batch_data)
            logger.info(
                'Iter%d time cost: %.2fs, lr: %.8f, train loss: %.3f, train acc: %.3f'
                % (i + 1, (time.time() - start_time), optimizer.get_lr(),
                   loss_val, train_acc))

        train_loss /= len(train_data)
        dev_acc = eval(model, dev_data, args, word_vocab, extwd_vocab,
                       lbl_vocab)
        logger.info('[Epoch %d] train loss: %.3f, lr: %f, DEV ACC: %.3f' %
                    (ep, train_loss, optimizer.get_lr(), dev_acc))

        if dev_acc > best_dev_acc:
            patient = 0
            best_dev_acc = dev_acc
            test_acc = eval(model, test_data, args, word_vocab, extwd_vocab,
                            lbl_vocab)
            logger.info('Test ACC: %.3f' % test_acc)
            if test_acc > best_test_acc:
                best_test_acc = test_acc
        else:
            patient += 1

        if patient > args.patient:
            break

    logger.info('Final Test ACC: %.3f' % best_test_acc)
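
All three examples rely on a project-specific Optimizer wrapper that exposes step(), get_lr() and (in the commented-out lines) lr_decay(); its definition is not included in the snippets. A minimal sketch of what such a wrapper could look like, assuming args carries learning_rate and weight_decay fields (these attribute names are assumptions, not taken from the examples):

import torch

class Optimizer(object):
    # Thin wrapper around a torch optimizer with a queryable, decayable learning rate.
    def __init__(self, params, args):
        self.lr = args.learning_rate  # assumed attribute name
        self.optimizer = torch.optim.Adam(params, lr=self.lr,
                                          weight_decay=args.weight_decay)

    def step(self):
        self.optimizer.step()

    def get_lr(self):
        return self.lr

    def lr_decay(self, factor):
        # multiply the learning rate of every parameter group by `factor`
        self.lr *= factor
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr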