Code example #1
    def train_iter(self, train_data, args, optimizer, *vocab):
        self.parser_model.train()
        train_loss = 0
        all_arc_acc, all_rel_acc, all_arcs = 0, 0, 0
        start_time = time.time()
        for i, batch_data in enumerate(
                batch_iter(train_data, args.batch_size, True)):
            batcher = batch_variable(batch_data, *vocab, args.device)
            # batcher = (x.to(args.device) for x in batcher)
            ngram_idxs, extngram_idxs, true_tags, true_heads, true_rels, non_pad_mask = batcher

            tag_score, arc_score, rel_score = self.parser_model(
                ngram_idxs, extngram_idxs, mask=non_pad_mask)

            tag_loss = self.parser_model.tag_loss(tag_score, true_tags,
                                                  non_pad_mask)
            # tag_loss = self.calc_tag_loss(tag_score, true_tags, non_pad_mask)
            dep_loss = self.calc_dep_loss(arc_score, rel_score, true_heads,
                                          true_rels, non_pad_mask)
            loss = tag_loss + dep_loss
            if args.update_steps > 1:
                loss = loss / args.update_steps
            loss_val = loss.float().item()
            train_loss += loss_val
            loss.backward()

            arc_acc, rel_acc, nb_arcs = self.calc_acc(arc_score, rel_score,
                                                      true_heads, true_rels,
                                                      non_pad_mask)
            all_arc_acc += arc_acc
            all_rel_acc += rel_acc
            all_arcs += nb_arcs
            ARC = all_arc_acc * 100. / all_arcs
            REL = all_rel_acc * 100. / all_arcs

            # Gradient accumulation over several mini-batches: effectively a larger batch size at lower memory cost
            if (i + 1) % args.update_steps == 0 or (i == args.max_step - 1):
                nn.utils.clip_grad_norm_(filter(
                    lambda p: p.requires_grad, self.parser_model.parameters()),
                                         max_norm=args.grad_clip)
                optimizer.step()  # update the parameters using the accumulated gradients
                self.parser_model.zero_grad()  # clear the accumulated gradients

            logger.info('Iter%d ARC: %.2f%%, REL: %.2f%%' % (i + 1, ARC, REL))
            logger.info(
                'time cost: %.2fs, lr: %.8f, train loss: %.2f' %
                ((time.time() - start_time), optimizer.get_lr(), loss_val))

        train_loss /= len(train_data)
        ARC = all_arc_acc * 100. / all_arcs
        REL = all_rel_acc * 100. / all_arcs

        return train_loss, ARC, REL
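
The loop above scales every loss by 1/update_steps and only steps the optimizer every update_steps mini-batches, so gradients accumulate across several batches before a single parameter update. Below is a minimal, self-contained sketch of that gradient-accumulation pattern; the toy model, the random data and the update_steps value are illustrative placeholders, not parts of the repository.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                         # toy stand-in for the parser model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
update_steps = 4                                 # accumulate gradients over 4 mini-batches
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]

for i, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    loss = loss / update_steps                   # scale so the summed gradient matches one large batch
    loss.backward()                              # gradients accumulate in .grad across iterations

    if (i + 1) % update_steps == 0 or i == len(batches) - 1:
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()                         # one parameter update per accumulation window
        model.zero_grad()                        # clear gradients before the next window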
Code example #2
File: parser.py Project: LindgeW/BiaffineParser
    def train_iter(self, train_data, args, vocab, optimizer):
        self.parser_model.train()

        train_loss = 0
        all_arc_acc, all_rel_acc, all_arcs = 0, 0, 0
        start_time = time.time()
        nb_batch = int(np.ceil(len(train_data) / args.batch_size))
        batch_size = args.batch_size // args.update_steps
        for i, batcher in enumerate(
                batch_iter(train_data, batch_size, vocab, True)):
            batcher = (x.to(args.device) for x in batcher)
            wd_idx, extwd_idx, tag_idx, true_head_idx, true_rel_idx, non_pad_mask, _ = batcher

            pred_arc_score, pred_rel_score = self.parser_model(
                wd_idx, extwd_idx, tag_idx, non_pad_mask)

            loss = self.calc_loss(pred_arc_score, pred_rel_score,
                                  true_head_idx, true_rel_idx, non_pad_mask)
            if args.update_steps > 1:
                loss = loss / args.update_steps
            loss_val = loss.data.item()
            train_loss += loss_val
            loss.backward()

            arc_acc, rel_acc, total_arcs = self.calc_acc(
                pred_arc_score, pred_rel_score, true_head_idx, true_rel_idx,
                non_pad_mask)
            all_arc_acc += arc_acc
            all_rel_acc += rel_acc
            all_arcs += total_arcs

            ARC = all_arc_acc * 100. / all_arcs
            REL = all_rel_acc * 100. / all_arcs
            logger.info('Iter%d ARC: %.3f%%, REL: %.3f%%' % (i + 1, ARC, REL))
            logger.info('time cost: %.2fs, train loss: %.2f' %
                        ((time.time() - start_time), loss_val))

            # Gradient accumulation: effectively a larger batch size at lower memory cost
            if (i + 1) % args.update_steps == 0 or (i == nb_batch - 1):
                nn.utils.clip_grad_norm_(filter(
                    lambda p: p.requires_grad, self.parser_model.parameters()),
                                         max_norm=5.)
                optimizer.step()
                self.parser_model.zero_grad()

        train_loss /= len(train_data)
        ARC = all_arc_acc * 100. / all_arcs
        REL = all_rel_acc * 100. / all_arcs

        return train_loss, ARC, REL
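
Both training loops delegate the UAS/LAS bookkeeping to a calc_acc helper that is not shown on this page. The sketch below is one plausible way such a helper could count correct heads and labels under the padding mask; the tensor shapes and the counting convention are assumptions, not the repository's implementation.

def calc_acc_sketch(pred_arc_score, pred_rel_score, true_heads, true_rels, non_pad_mask):
    # Hypothetical helper, not the repository's calc_acc. Assumed shapes:
    #   pred_arc_score: (batch, seq_len, seq_len)          head score for each dependent
    #   pred_rel_score: (batch, seq_len, seq_len, n_rels)  relation score per (dependent, head)
    #   true_heads, true_rels: (batch, seq_len)            gold indices
    #   non_pad_mask: (batch, seq_len)                     mask of real tokens
    mask = non_pad_mask.bool()
    pred_heads = pred_arc_score.argmax(dim=-1)                  # (batch, seq_len)

    # gather the relation scores at the gold head of every dependent, then take the best relation
    idx = true_heads.clamp(min=0).unsqueeze(-1).unsqueeze(-1)   # (batch, seq_len, 1, 1)
    idx = idx.expand(-1, -1, 1, pred_rel_score.size(-1))        # (batch, seq_len, 1, n_rels)
    pred_rels = pred_rel_score.gather(2, idx).squeeze(2).argmax(dim=-1)

    arc_hit = (pred_heads == true_heads) & mask
    arc_acc = arc_hit.sum().item()
    rel_acc = (arc_hit & (pred_rels == true_rels)).sum().item()
    total_arcs = mask.sum().item()
    return arc_acc, rel_acc, total_arcs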
Code example #3
File: parser.py Project: LindgeW/BiaffineParser
    def train(self, train_data, dev_data, test_data, args, vocab):
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.parser_model.parameters()),
                               lr=args.learning_rate,
                               betas=(args.beta1, args.beta2))
        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                   lambda t: 0.75**(t / 5000))
        # When the validation metric stops improving, lowering the learning rate can help the network improve further
        # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5, verbose=True, min_lr=1e-5)
        # lr_scheduler.step(val_loss)
        best_uas = 0
        test_best_uas, test_best_las = 0, 0
        for ep in range(1, 1 + args.epoch):
            train_loss, arc, rel = self.train_iter(train_data, args, vocab,
                                                   optimizer)
            dev_uas, dev_las = self.evaluate(dev_data, args, vocab)
            lr_scheduler.step()  # adjust the lr once per epoch
            logger.info(
                '[Epoch %d] train loss: %.3f, lr: %f, ARC: %.3f%%, REL: %.3f%%'
                % (ep, train_loss, lr_scheduler.get_lr()[0], arc, rel))
            logger.info(
                'Dev data -- UAS: %.3f%%, LAS: %.3f%%, best_UAS: %.3f%%' %
                (dev_uas, dev_las, best_uas))
            if dev_uas > best_uas:
                best_uas = dev_uas
                test_uas, test_las = self.evaluate(test_data, args, vocab)
                if test_best_uas < test_uas:
                    test_best_uas = test_uas
                if test_best_las < test_las:
                    test_best_las = test_las

                logger.info('Test data -- UAS: %.3f%%, LAS: %.3f%%' %
                            (test_uas, test_las))

        logger.info('Final test performance -- UAS: %.3f%%, LAS: %.3f%%' %
                    (test_best_uas, test_best_las))
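
The commented-out alternative above, ReduceLROnPlateau, lowers the learning rate when the monitored dev metric stops improving instead of decaying it on a fixed schedule. A minimal sketch of how it would be driven from a loop like this one, reusing the commented-out hyperparameters (the model and the dev loss are placeholders):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)                        # placeholder model
optimizer = optim.Adam(model.parameters(), lr=2e-3)
# shrink the lr by 10% after 5 epochs without improvement, never below 1e-5
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.9, patience=5, min_lr=1e-5)

for epoch in range(20):
    dev_loss = 0.5                              # stand-in for a dev-set loss that has plateaued
    lr_scheduler.step(dev_loss)                 # unlike LambdaLR, step() receives the monitored metric
    print(epoch, optimizer.param_groups[0]['lr'])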
Code example #4
File: parser.py Project: LindgeW/BiaffineParser
    def summary(self):
        logger.info(self.parser_model)
Code example #5
    def train(self, train_data, dev_data, test_data, args, *vocab):
        args.max_step = args.epoch * (
            (len(train_data) + args.batch_size - 1) //
            (args.batch_size * args.update_steps))
        args.warmup_step = args.max_step // 2
        print('max step:', args.max_step)
        optimizer = Optimizer(
            filter(lambda p: p.requires_grad, self.parser_model.parameters()),
            args)
        best_udep_f1 = 0
        test_best_uas, test_best_las = 0, 0
        test_best_tag_f1, test_best_seg_f1, test_best_udep_f1, test_best_ldep_f1 = 0, 0, 0, 0
        for ep in range(1, 1 + args.epoch):
            train_loss, arc, rel = self.train_iter(train_data, args, optimizer,
                                                   *vocab)
            dev_uas, dev_las, tag_f1, seg_f1, udep_f1, ldep_f1 = self.evaluate(
                dev_data, args, *vocab)
            logger.info(
                '[Epoch %d] train loss: %.3f, lr: %f, ARC: %.2f%%, REL: %.2f%%'
                % (ep, train_loss, optimizer.get_lr(), arc, rel))
            logger.info(
                'Dev data -- UAS: %.2f%%, LAS: %.2f%%, best_UAS: %.2f%%' %
                (dev_uas, dev_las, best_udep_f1))
            logger.info(
                'Dev data -- TAG: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%'
                % (tag_f1, seg_f1, udep_f1, ldep_f1))

            if udep_f1 > best_udep_f1:
                best_udep_f1 = udep_f1
                test_uas, test_las, test_tag_f1, test_seg_f1, test_udep_f1, test_ldep_f1 = self.evaluate(
                    test_data, args, *vocab)
                if test_best_uas < test_uas:
                    test_best_uas = test_uas
                if test_best_las < test_las:
                    test_best_las = test_las
                if test_best_tag_f1 < test_tag_f1:
                    test_best_tag_f1 = test_tag_f1
                if test_best_seg_f1 < test_seg_f1:
                    test_best_seg_f1 = test_seg_f1
                if test_best_udep_f1 < test_udep_f1:
                    test_best_udep_f1 = test_udep_f1
                if test_best_ldep_f1 < test_ldep_f1:
                    test_best_ldep_f1 = test_ldep_f1

                logger.info('Test data -- UAS: %.2f%%, LAS: %.2f%%' %
                            (test_uas, test_las))
                logger.info(
                    'Test data -- Tag F1: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%'
                    % (test_tag_f1, test_seg_f1, test_udep_f1, test_ldep_f1))
                print('tag scale: ', self.parser_model.scale_tag.get_params())
                print('dep scale: ', self.parser_model.scale_dep.get_params())

        logger.info('Final test performance -- UAS: %.2f%%, LAS: %.2f%%' %
                    (test_best_uas, test_best_las))
        logger.info(
            'Final test performance -- Tag F1: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%'
            % (test_best_tag_f1, test_best_seg_f1, test_best_udep_f1,
               test_best_ldep_f1))
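
This variant hands learning-rate scheduling to a custom Optimizer wrapper whose code is not shown on this page; the max_step/warmup_step bookkeeping suggests a warmup-then-decay schedule. The sketch below expresses one such schedule with plain PyTorch LambdaLR; it is an assumption about what the wrapper might do, not its actual implementation, and the step counts are illustrative.

import torch.nn as nn
import torch.optim as optim

max_step, warmup_step = 1000, 500               # illustrative stand-ins for args.max_step, args.warmup_step

def warmup_then_linear_decay(step):
    # ramp the lr factor from 0 to 1 during warmup, then decay it linearly back to 0
    if step < warmup_step:
        return step / max(1, warmup_step)
    return max(0.0, (max_step - step) / max(1, max_step - warmup_step))

model = nn.Linear(10, 1)                        # placeholder model
optimizer = optim.Adam(model.parameters(), lr=2e-3)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_then_linear_decay)

for step in range(max_step):
    # loss.backward() would run here in a real training loop
    optimizer.step()
    scheduler.step()                            # advance the schedule once per parameter update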
Code example #6
    def train(self, train_data, dev_data, test_data, args, vocab):
        args.max_step = args.epoch * ((len(train_data) + args.batch_size - 1) // (args.batch_size*args.update_steps))
        # args.warmup_step = args.max_step // 2
        print('max step:', args.max_step)
        optimizer = Optimizer(filter(lambda p: p.requires_grad, self.parser_model.model.parameters()), args)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_bert_parameters = [
            {'params': [p for n, p in self.parser_model.bert.named_parameters()
                        if not any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.parser_model.bert.named_parameters()
                        if any(nd in n for nd in no_decay) and p.requires_grad],
             'weight_decay': 0.0}
        ]
        optimizer_bert = AdamW(optimizer_bert_parameters, lr=5e-5, eps=1e-8)
        scheduler_bert = WarmupLinearSchedule(optimizer_bert, warmup_steps=0, t_total=args.max_step)
        all_params = [p for p in self.parser_model.model.parameters() if p.requires_grad]
        for group in optimizer_bert_parameters:
            for p in group['params']:
                all_params.append(p)

        test_best_uas, test_best_las = 0, 0
        test_best_tag_f1, test_best_seg_f1, test_best_udep_f1, test_best_ldep_f1 = 0, 0, 0, 0
        for ep in range(1, 1+args.epoch):
            self.parser_model.model.train()
            self.parser_model.bert.train()
            train_loss = 0
            all_arc_acc, all_rel_acc, all_arcs = 0, 0, 0
            start_time = time.time()
            for i, batch_data in enumerate(batch_iter(train_data, args.batch_size, True)):
                batcher = batch_variable(batch_data, vocab, args.device)
                # batcher = (x.to(args.device) for x in batcher)
                (bert_ids, bert_lens, bert_mask), true_tags, true_heads, true_rels = batcher
                tag_score, arc_score, rel_score = self.parser_model(bert_ids, bert_lens, bert_mask)

                tag_loss = self.calc_tag_loss(tag_score, true_tags, bert_lens.gt(0))
                dep_loss = self.calc_dep_loss(arc_score, rel_score, true_heads, true_rels, bert_lens.gt(0))
                loss = tag_loss + dep_loss
                if args.update_steps > 1:
                    loss = loss / args.update_steps
                loss_val = loss.data.item()
                train_loss += loss_val
                loss.backward()  # backpropagate to compute gradients for the current batch

                arc_acc, rel_acc, nb_arcs = self.calc_acc(arc_score, rel_score, true_heads, true_rels, bert_lens.gt(0))
                all_arc_acc += arc_acc
                all_rel_acc += rel_acc
                all_arcs += nb_arcs
                ARC = all_arc_acc * 100. / all_arcs
                REL = all_rel_acc * 100. / all_arcs

                if (i + 1) % args.update_steps == 0 or (i == args.max_step - 1):
                    nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, all_params), max_norm=args.grad_clip)
                    optimizer.step()  # update the parameters using the accumulated gradients
                    optimizer_bert.step()
                    scheduler_bert.step()
                    self.parser_model.model.zero_grad()  # clear the accumulated gradients
                    self.parser_model.bert.zero_grad()  # clear the accumulated gradients

                logger.info('Iter%d ARC: %.2f%%, REL: %.2f%%' % (i + 1, ARC, REL))
                logger.info('time cost: %.2fs, lr: %f, train loss: %.2f' % (
                    (time.time() - start_time), optimizer.get_lr(), loss_val))

            train_loss /= len(train_data)
            arc = all_arc_acc * 100. / all_arcs
            rel = all_rel_acc * 100. / all_arcs

            dev_uas, dev_las, tag_f1, seg_f1, udep_f1, ldep_f1 = self.evaluate(dev_data, args, vocab)
            logger.info('[Epoch %d] train loss: %.3f, lr: %f, ARC: %.2f%%, REL: %.2f%%' % (ep, train_loss, optimizer.get_lr(), arc, rel))
            logger.info('Dev data -- UAS: %.2f%%, LAS: %.2f%%' % (100.*dev_uas, 100.*dev_las))
            logger.info('Dev data -- TAG: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%' % (100.*tag_f1, 100.*seg_f1, 100.*udep_f1, 100.*ldep_f1))

            # with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
            #     torch.save(self.parser_model, f)
            test_uas, test_las, test_tag_f1, test_seg_f1, test_udep_f1, test_ldep_f1 = self.evaluate(test_data, args, vocab)
            if test_best_uas < test_uas:
                test_best_uas = test_uas
            if test_best_las < test_las:
                test_best_las = test_las
            if test_best_tag_f1 < test_tag_f1:
                test_best_tag_f1 = test_tag_f1
            if test_best_seg_f1 < test_seg_f1:
                test_best_seg_f1 = test_seg_f1
            if test_best_udep_f1 < test_udep_f1:
                test_best_udep_f1 = test_udep_f1
            if test_best_ldep_f1 < test_ldep_f1:
                test_best_ldep_f1 = test_ldep_f1

            logger.info('Test data -- UAS: %.2f%%, LAS: %.2f%%' % (100.*test_uas, 100.*test_las))
            logger.info('Test data -- Tag F1: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%' % (100.*test_tag_f1, 100.*test_seg_f1, 100.*test_udep_f1, 100.*test_ldep_f1))

        logger.info('Final test performance -- UAS: %.2f%%, LAS: %.2f%%' % (100.*test_best_uas, 100.*test_best_las))
        logger.info('Final test performance -- Tag F1: %.2f%%, Seg F1: %.2f%%, UDEP F1: %.2f%%, LDEP F1: %.2f%%' % (100.*test_best_tag_f1, 100.*test_best_seg_f1, 100.*test_best_udep_f1, 100.*test_best_ldep_f1))
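
Example #6 splits the BERT parameters into two groups so that biases and LayerNorm weights are exempt from weight decay, and pairs the BERT optimizer with a linear warmup schedule (WarmupLinearSchedule from the older pytorch_transformers package). A sketch of the same grouping written against the current transformers API (get_linear_schedule_with_warmup) might look as follows; the checkpoint name and the step count are placeholders.

from torch.optim import AdamW
from transformers import AutoModel, get_linear_schedule_with_warmup

bert = AutoModel.from_pretrained('bert-base-chinese')    # placeholder checkpoint
no_decay = ('bias', 'LayerNorm.weight')
grouped_params = [
    {'params': [p for n, p in bert.named_parameters()
                if p.requires_grad and not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in bert.named_parameters()
                if p.requires_grad and any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer_bert = AdamW(grouped_params, lr=5e-5, eps=1e-8)
max_step = 10000                                          # illustrative total number of updates
scheduler_bert = get_linear_schedule_with_warmup(
    optimizer_bert, num_warmup_steps=0, num_training_steps=max_step)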