Example #1
    def train(self, crf2train_dataloader, crf2dev_dataloader,
              dev_dataset_loader, epoch_list, args):
        start_time = time.time()

        for epoch_idx in epoch_list:
            args.start_epoch = epoch_idx
            curr_start_time = time.time()
            # randomly pick one CRF head (and its corpus group) for this epoch
            crf_no = random.randint(0, len(self.crf2corpus) - 1)

            cur_dataset = crf2train_dataloader[crf_no]
            epoch_loss = self.train_epoch(cur_dataset, crf_no, self.crit_ner,
                                          self.optimizer, args)

            # main evaluation: combined dev set in N21, single dev set in N2N
            corpus_name = [
                args.dev_file[i].split("/")[-2]
                for i in self.crf2corpus[crf_no]
            ]
            print(args.dispatch, "Dev Corpus: ", corpus_name)

            dev_f1, dev_pre, dev_rec, dev_acc = self.eval_epoch(
                crf2dev_dataloader[crf_no], crf_no, args)

            # set to False whenever any tracked score improves this epoch;
            # patience is only incremented when nothing improved
            if_add_patience = True
            if dev_f1 > self.best_f1[crf_no] and not args.combine:
                print("Prev Best F1: {:.4f} Curr Best F1: {:.4f}".format(
                    self.best_f1[crf_no], dev_f1))
                self.best_epoch_idx = epoch_idx
                self.patience_count = 0
                self.best_f1[crf_no] = dev_f1
                self.best_pre[crf_no] = dev_pre
                self.best_rec[crf_no] = dev_rec
                self.best_state_dict = deepcopy(self.ner_model.state_dict())

                checkpoint_name = args.checkpoint + "/"
                checkpoint_name += args.dispatch + "_"
                if args.dispatch in ["N2K", "N2N"]:
                    checkpoint_name += args.train_file[
                        self.crf2corpus[crf_no][0]].split("/")[-2] + "_"
                checkpoint_name += "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                    dev_f1, dev_pre, dev_rec, epoch_idx)

                print("NOW SAVING, ", checkpoint_name)
                print()

                self.drop_check_point(checkpoint_name, args)
                self.best_checkpoint_name = checkpoint_name

                if_add_patience = False
            else:
                if args.dispatch == "N2N" or not args.stop_on_single:
                    self.patience_count += 1
            self.track_list.append({
                'loss': epoch_loss,
                'dev_f1': dev_f1,
                'dev_acc': dev_acc
            })

            if epoch_idx == args.epoch - 1:
                last_checkpoint_name = args.checkpoint + "/"
                last_checkpoint_name += args.dispatch + "_"
                if args.dispatch in ["N2K", "N2N"]:
                    last_checkpoint_name += args.train_file[
                        self.crf2corpus[crf_no][0]].split("/")[-2] + "_"
                last_checkpoint_name += "LAST" + "_"
                last_checkpoint_name += "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                    dev_f1, dev_pre, dev_rec, epoch_idx)

                print("NOW SAVING LAST, ", last_checkpoint_name)
                self.drop_check_point(last_checkpoint_name, args)
                print()

                if args.combine:
                    self.best_state_dict = deepcopy(
                        self.ner_model.state_dict())

            # save check point for each corpus
            if args.dispatch in ["N21", "N2K"]:

                # print("Drop the best check point for single corpus")

                for cid in self.crf2corpus[crf_no]:
                    print(args.dev_file[cid])
                    cid_f1, cid_pre, cid_rec, cid_acc = self.eval_epoch(
                        dev_dataset_loader[cid], crf_no, args)
                    # F1
                    if cid_f1 > self.corpus_best_vec[cid][
                            0] and not args.combine:
                        print(
                            "Prev Best F1: {:.4f} Curr Best F1: {:.4f}".format(
                                self.corpus_best_vec[cid][0], cid_f1))
                        self.corpus_best_vec[cid] = [cid_f1, cid_pre, cid_rec]

                        if args.stop_on_single:
                            self.patience_count = 0

                        checkpoint_name = args.checkpoint + "/"
                        checkpoint_name += args.dispatch + "_"
                        checkpoint_name += args.dev_file[cid].split(
                            "/")[-2] + "_"
                        checkpoint_name += "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                            cid_f1, cid_pre, cid_rec, epoch_idx)

                        print("NOW SAVING, ", checkpoint_name)
                        self.drop_check_point(checkpoint_name, args)
                        print()
                        self.corpus_best_checkpoint_name[cid] = checkpoint_name

                        if_add_patience = False
                if if_add_patience and args.stop_on_single:
                    self.patience_count += 1

            operating_time = time.time() - start_time
            h, rem = divmod(operating_time, 3600)
            m, s = divmod(rem, 60)

            print(
                "Epoch: [{:d}/{:d}]\t Patience: {:d}\t Current: {:.2f}s\t Total: {:02d}:{:02d}:{:05.2f}\n"
                .format(args.start_epoch, args.epoch - 1, self.patience_count,
                        time.time() - curr_start_time, int(h), int(m), s))
            if self.patience_count >= args.patience and args.start_epoch >= args.least_iters:
                break

            # update lr: ReduceLROnPlateau on dev F1, or 1/t decay otherwise
            if self.plateau:
                self.scheduler.step(dev_f1)
            else:
                utils.adjust_learning_rate(
                    self.optimizer,
                    args.lr / (1 + (args.start_epoch + 1) * args.lr_decay))

        print("Sample Frequence")
        for crf, corpus_idx in self.crf2corpus.items():
            corpus_name = [
                args.train_file[i].split("/")[-2] for i in corpus_idx
            ]
            print(crf, corpus_name, self.sample_cnter[crf])
        print()
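
Both LR-update branches above delegate to utils.adjust_learning_rate, whose body is not shown in these examples. A minimal sketch of what such a helper typically does (an assumption about the utils module, not the verified source) is:

    def adjust_learning_rate(optimizer, lr):
        # assumed helper: apply the same decayed rate to every parameter group
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr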
Example #2
            f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, SCRF_labels, \
                mask_SCRF_labels, cnn_features = packer.repack(
                    f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v,
                    SCRF_labels, mask_SCRF_labels, cnn_features, test=False)
            optimizer.zero_grad()

            loss = model(f_f, f_p, b_f, b_p, w_f, cnn_features, tg_v, mask_v,
                         mask_v.long().sum(0), SCRF_labels, mask_SCRF_labels, onlycrf=False)

            epoch_loss += utils.to_scalar(loss)
            loss.backward()
            # clip_grad_norm was renamed clip_grad_norm_ (in-place) in modern PyTorch
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optimizer.step()

        epoch_loss /= tot_length
        print('epoch_loss: ', epoch_loss)

        utils.adjust_learning_rate(optimizer, args.lr / (1 + (args.start_epoch + 1) * args.lr_decay))


        # score CRF-only, SCRF-only, and joint decoding on the dev set
        dev_f1_crf, dev_pre_crf, dev_rec_crf, dev_acc_crf, \
            dev_f1_scrf, dev_pre_scrf, dev_rec_scrf, dev_acc_scrf, \
            dev_f1_jnt, dev_pre_jnt, dev_rec_jnt, dev_acc_jnt = \
            evaluator.calc_score(model, dev_dataset_loader)

        if dev_f1_jnt > best_dev_f1_jnt:
            early_stop_epochs = 0
            test_f1_crf, test_pre_crf, test_rec_crf, test_acc_crf, \
                test_f1_scrf, test_pre_scrf, test_rec_scrf, test_acc_scrf, \
                test_f1_jnt, test_pre_jnt, test_rec_jnt, test_acc_jnt = \
                evaluator.calc_score(model, test_dataset_loader)

            best_test_f1_crf = test_f1_crf
            best_test_f1_scrf = test_f1_scrf

            best_dev_f1_jnt = dev_f1_jnt
            best_test_f1_jnt = test_f1_jnt
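
Examples #1 through #3 all accumulate the loss through utils.to_scalar, which is likewise not shown. A plausible minimal version, assuming the loss is a 0-dim tensor (on modern PyTorch this reduces to .item()):

    def to_scalar(var):
        # assumed helper: extract a Python number from a 0-dim tensor
        return var.item()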
Example #3
    def train(self, data, *args, **kwargs):
        tot_length = sum(len(t) for t in self.dataset_loader)
        loss_list = []
        acc_list = []
        # per-corpus best scores, initialised to -inf
        best_f1 = [float('-inf')] * self.file_num
        best_pre = [float('-inf')] * self.file_num
        best_rec = [float('-inf')] * self.file_num

        start_time = time.time()
        epoch_list = range(self.args.start_epoch,
                           self.args.start_epoch + self.args.epoch)
        patience_count = 0
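        # NOTE: the unusual loop target below writes each epoch number into
        # self.args.start_epoch, so the LR schedule, logging, and the early
        # stopping check all read the current epoch from args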
        for epoch_idx, self.args.start_epoch in enumerate(epoch_list):

            sample_num = 1  # one randomly chosen corpus is trained per epoch

            epoch_loss = 0
            self.ner_model.train()

            for sample_id in tqdm(range(sample_num),
                                  mininterval=2,
                                  desc=' - Tot it %d (epoch %d)' %
                                  (tot_length, self.args.start_epoch),
                                  leave=False,
                                  file=sys.stdout):

                self.file_no = random.randint(0, self.file_num - 1)
                cur_dataset = self.dataset_loader[self.file_no]

                for f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v in itertools.chain.from_iterable(
                        cur_dataset):

                    f_f, f_p, b_f, b_p, w_f, tg_v, mask_v = self.packer.repack_vb(
                        f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v)

                    self.ner_model.zero_grad()
                    scores = self.ner_model(f_f, f_p, b_f, b_p, w_f,
                                            self.file_no)
                    loss = self.crit_ner(scores, tg_v, mask_v)

                    epoch_loss += utils.to_scalar(loss)
                    if self.args.co_train:
                        # language-model co-training: shift the word sequence
                        # to predict the next word (forward LM) and previous
                        # word (backward LM), each weighted by lambda0
                        cf_p = f_p[0:-1, :].contiguous()
                        cb_p = b_p[1:, :].contiguous()
                        cf_y = w_f[1:, :].contiguous()
                        cb_y = w_f[0:-1, :].contiguous()
                        cfs, _ = self.ner_model.word_pre_train_forward(
                            f_f, cf_p)
                        loss = loss + self.args.lambda0 * self.crit_lm(
                            cfs, cf_y.view(-1))
                        cbs, _ = self.ner_model.word_pre_train_backward(
                            b_f, cb_p)
                        loss = loss + self.args.lambda0 * self.crit_lm(
                            cbs, cb_y.view(-1))
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.ner_model.parameters(),
                                             self.args.clip_grad)
                    self.optimizer.step()

            epoch_loss /= tot_length

            # update lr
            utils.adjust_learning_rate(
                self.optimizer, self.args.lr /
                (1 + (self.args.start_epoch + 1) * self.args.lr_decay))

            # eval & save check_point
            if 'f' in self.args.eva_matrix:
                dev_f1, dev_pre, dev_rec, dev_acc = self.evaluate(
                    None, None, self.dev_dataset_loader[self.file_no],
                    self.file_no)
                loss_list.append(epoch_loss)
                acc_list.append(dev_acc)
                if dev_f1 > best_f1[self.file_no]:
                    patience_count = 0
                    best_f1[self.file_no] = dev_f1
                    best_pre[self.file_no] = dev_pre
                    best_rec[self.file_no] = dev_rec
                    try:
                        self.save_model(None)
                    except Exception as inst:
                        print(inst)
                else:
                    patience_count += 1
                # identical logging and bookkeeping for both branches
                print(
                    '(loss: %.4f, epoch: %d, dataset: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f)'
                    % (epoch_loss, self.args.start_epoch, self.file_no,
                       dev_f1, dev_pre, dev_rec))
                self.track_list.append({
                    'loss': epoch_loss,
                    'dev_f1': dev_f1,
                    'dev_acc': dev_acc
                })

            print('epoch: {:d} of {:d}\t elapsed: {:.2f} s'.format(
                self.args.start_epoch, self.args.epoch,
                time.time() - start_time))

            if patience_count >= self.args.patience and self.args.start_epoch >= self.args.least_iters:
                break
        return loss_list, acc_list
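
All three examples implement the same patience-based early stopping: reset a counter whenever dev F1 improves, increment it otherwise, and break once the counter reaches the patience budget (but never before least_iters epochs). A condensed standalone sketch of that pattern, with hypothetical names (num_epochs, evaluate, dev_loader, patience, least_iters):

    patience_count, best_dev_f1 = 0, float('-inf')
    for epoch in range(num_epochs):
        dev_f1 = evaluate(model, dev_loader)  # assumed evaluation hook
        if dev_f1 > best_dev_f1:
            best_dev_f1, patience_count = dev_f1, 0  # reset on improvement
        else:
            patience_count += 1
        if patience_count >= patience and epoch >= least_iters:
            break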