Exemplo n.º 1
0
def eval_batch(data_iter,
               model,
               eval_instance,
               best_fscore,
               epoch,
               config,
               test=False):
    """Run ``model`` over ``data_iter``, score the predictions and print a report.

    :param data_iter: iterable of batches; each batch exposes ``batch_length``
        and ``inst`` (per-sentence instances with ``words_size`` and ``labels``)
    :param model: network, switched to eval mode and called as
        ``model(batch_features)``
    :param eval_instance: accumulator passed to ``EvalPRF.evalPRF`` and then
        queried through ``getFscore()``
    :param best_fscore: mutable record of the best dev F-score, its epoch and
        the test scores captured for that epoch
    :param epoch: current epoch, stored when the dev F-score is matched or beaten
    :param config: config; ``config.create_alphabet.label_alphabet`` converts
        label ids to labels
    :param test: ``False`` for a dev pass, ``True`` for a test pass
    :return: None
    """
    model.eval()
    tag_acc = Eval()
    prf_scorer = EvalPRF()

    # Decode every sentence first; each entry is (predicted labels, gold labels).
    scored_pairs = []
    for batch in data_iter:
        scores = model(batch)
        for row in range(batch.batch_length):
            sentence = batch.inst[row]
            best_ids = getMaxindex_batch(scores[row])
            decoded = [
                config.create_alphabet.label_alphabet.from_id(best_ids[pos])
                for pos in range(sentence.words_size)
            ]
            scored_pairs.append((decoded, sentence.labels))

    for decoded, gold in scored_pairs:
        prf_scorer.evalPRF(predict_labels=decoded,
                           gold_labels=gold,
                           eval=eval_instance)

    # NOTE(review): nothing in this function feeds tag_acc, so the TAG-ACC
    # figure printed below comes from a freshly created counter -- confirm
    # whether that is intended.
    if tag_acc.gold_num == 0:
        tag_acc.gold_num = 1

    p, r, f = eval_instance.getFscore()
    dev_pass = test is False
    test_pass = test is True

    if dev_pass:
        print()
        if f >= best_fscore.best_dev_fscore:
            best_fscore.best_dev_fscore = f
            best_fscore.best_epoch = epoch
            # Ask the next test pass to record its scores as the best ones.
            best_fscore.best_test = True
    if test_pass and best_fscore.best_test is True:
        best_fscore.p = p
        best_fscore.r = r
        best_fscore.f = f

    print(
        "{} eval: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%,  [TAG-ACC = {:.6f}%]"
        .format("Dev" if dev_pass else "Test", p, r, f, tag_acc.acc()))

    if test_pass:
        print(
            "The Current Best Dev F-score: {:.6f}, Locate on {} Epoch.".format(
                best_fscore.best_dev_fscore, best_fscore.best_epoch))
        print(
            "The Current Best Test Result: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%"
            .format(best_fscore.p, best_fscore.r, best_fscore.f))
        best_fscore.best_test = False
class Train(object):
    """
        Train

        Training driver for a sequence-labelling model whose output is scored
        either by cross-entropy or by the model's CRF layer.  It owns the
        optimizer, the loss function, the per-epoch learning-rate decay,
        gradient accumulation, dev/test evaluation, model saving and early
        stopping.
    """
    def __init__(self, **kwargs):
        """
        Build the optimizer, the loss function and the evaluation state.

        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.use_crf = self.config.use_crf
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.label_paddingId,
            use_crf=self.use_crf)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()
        # Used to flush accumulated gradients on the last batch of an epoch.
        self.train_iter_len = len(self.train_iter)

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        """
        Select the loss callable.

        :param learning_algorithm:  optimizer name; "SGD" sums the per-token
            losses, any other value averages them
        :param label_paddingId:  label id ignored by the cross-entropy loss
        :param use_crf:  when true, use the CRF layer's negative log-likelihood
        :return: loss callable
        """
        if use_crf:
            return self.model.crf_layer.neg_log_likelihood_loss
        return nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                   size_average=learning_algorithm != "SGD")

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        Clip the gradient norm of the model parameters.

        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]; ``None``
            or the string "None" means "no limit" and skips clipping
        :return: None
        """
        if clip_max_norm_use is not True:
            return
        # The "no limit" sentinel used to trip an assertion here; treat it as
        # "nothing to clip" instead of crashing the training loop.
        if clip_max_norm is None or clip_max_norm == "None":
            return
        # presumably `utils` is torch.nn.utils -- confirm against the imports
        utils.clip_grad_norm(self.model.parameters(),
                             max_norm=float(clip_max_norm))

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        Multiply the learning rate by ``config.lr_rate_decay`` every
        ``config.max_patience`` epochs, never going below ``config.min_lrate``.

        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return: the (possibly reduced) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """lr decay

        Sets every parameter group to ``init_lr / (1 + lr_rate_decay * epoch)``.

        Args:
            epoch: int, epoch
            init_lr:  initial lr
        Returns:
            the optimizer, with the updated learning rate
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        Apply the accumulated gradients every ``config.backward_batch_size``
        backward passes, and also on the last batch of the epoch.

        :param config:  config
        :param backward_count:  number of backward passes done in this epoch
        :return: None
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        Count an epoch without a new best dev score and stop the process once
        ``early_max_patience`` such epochs have been counted.

        :param epoch:  current epoch
        :return: None
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                # sys.exit() rather than the interactive-only exit() helper.
                sys.exit()

    @staticmethod
    def _get_model_args(batch_features):
        """
        Unpack the model inputs from a batch.

        :param batch_features:  Batch Instance
        :return: (word, char, mask, sentence_length, tags), where mask is
            ``word > 0``
        """
        word = batch_features.word_features
        char = batch_features.char_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, char, mask, sentence_length, tags

    @staticmethod
    def _loss_value(loss):
        """
        Return the loss as a plain number for logging.

        :param loss:  loss returned by the loss function
        :return: ``loss.item()`` when available, otherwise the legacy
            ``loss.data[0]`` (which raises on 0-dim tensors in newer PyTorch)
        """
        if hasattr(loss, "item"):
            return loss.item()
        return loss.data[0]

    def _calculate_loss(self, feats, mask, tags):
        """
        Args:
            feats: size = (batch_size, seq_len, tag_size)
            mask: size = (batch_size, seq_len)
            tags: size = (batch_size, seq_len)
        Returns:
            the loss; the CRF loss is divided by the batch size when
            ``average_batch`` is set
        """
        if not self.use_crf:
            # Flatten to (batch_size * seq_len, tag_size) for cross-entropy.
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            return self.loss_function(lstm_feats, tags.view(-1))
        loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            loss_value /= float(feats.size(0))
        return loss_value

    def train(self):
        """
        Run the training loop: per epoch, decay the learning rate, shuffle the
        training batches in place, accumulate gradients batch by batch, then
        evaluate on dev/test, save the model and check early stopping.

        :return: None
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            self.optimizer.zero_grad()
            # batch_number is 1-based: it is both the backward-pass count and
            # the number shown in the progress line.
            for batch_number, batch_features in enumerate(self.train_iter, 1):
                word, char, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(word, char, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=batch_number)
                if batch_number % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit,
                                self.config)
                    sys.stdout.write(
                        "\nbatch_count = [{}] , loss is {:.6f}, [TAG-ACC is {:.6f}%]"
                        .format(batch_number, self._loss_value(loss),
                                self.train_eval.acc()))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        Evaluate on the dev set and then on the test set, timing each pass.

        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return: None
        """
        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        Save the model according to the config flags; "save all" takes
        precedence over "save best".

        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return: None
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        Decode every batch, score the predictions and print a report.

        On a dev pass the best dev score and epoch are updated in
        ``best_score``; on the test pass that follows a new best dev score,
        the test precision/recall/F-score are stored as the best test result.
        The TAG-ACC field of the report is always printed as 0.

        :param data_iter:  eval batch data iterator
        :param model: eval model
        :param eval_instance:  accumulator passed to ``EvalPRF.evalPRF`` and
            queried through ``getFscore()``
        :param best_score:  mutable record of the best results so far
        :param epoch:  current epoch
        :param config: config
        :param test:  whether to test
        :return: None
        """
        model.eval()
        eval_PRF = EvalPRF()
        label_alphabet = config.create_alphabet.label_alphabet
        gold_labels = []
        predict_labels = []
        for batch_features in data_iter:
            word, char, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(word, char, sentence_length, train=False)
            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    # Only the first words_size positions belong to the sentence.
                    predict_labels.append([
                        label_alphabet.from_id(label_ids[id_word])
                        for id_word in range(inst.words_size)
                    ])
                    gold_labels.append(inst.labels)
            else:
                _, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy(
                    )[:inst.words_size]
                    predict_labels.append(
                        [label_alphabet.from_id(i) for i in label_ids])
        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label,
                             gold_labels=g_label,
                             eval=eval_instance)
        p, r, f = eval_instance.getFscore()
        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                # Ask the next test pass to record its scores as the best ones.
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        print(
            "{} eval: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%,  [TAG-ACC = {:.6f}%]"
            .format(test_flag, p, r, f, 0.0000))
        if test is True:
            print("The Current Best Dev F-score: {:.6f}, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
            print(
                "The Current Best Test Result: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%"
                .format(best_score.p, best_score.r, best_score.f))
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):
        """
        Reset ``eval_acc`` and fill it with the tag accuracy counts of one batch.

        :param eval_acc:  eval instance
        :param batch_features:  batch data feature
        :param logit:  model output
        :param config:  config
        :return: None
        """
        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        label_alphabet = config.create_alphabet.label_alphabet
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            gold_label = inst.labels
            predict_label = [
                label_alphabet.from_id(label_ids[id_word])
                for id_word in range(inst.words_size)
            ]
            assert len(predict_label) == len(gold_label)
            eval_acc.correct_num += sum(
                1 for p_label, g_label in zip(predict_label, gold_label)
                if p_label == g_label)
            eval_acc.gold_num += len(gold_label)
Exemplo n.º 3
0
def train(train_iter, dev_iter, test_iter, model, config):
    """Train ``model`` with Adam or SGD and evaluate on dev/test after every epoch.

    :param train_iter: list of training batches; shuffled in place every epoch
    :param dev_iter: dev batch data iterator
    :param test_iter: test batch data iterator
    :param model: nn model, called as ``model(batch_features)``
    :param config: config; ``adam`` / ``sgd`` select the optimizer (SGD wins
        when both are set) and ``embed_finetune`` decides whether parameters
        with ``requires_grad == False`` are handed to it
    :raises ValueError: when neither ``config.adam`` nor ``config.sgd`` is True
    :return: None
    """
    if config.use_cuda:
        model.cuda()

    optimizer = None
    optimizer_choices = (
        (config.adam, "Adam Training......", torch.optim.Adam),
        (config.sgd, "SGD Training......", torch.optim.SGD),
    )
    for enabled, banner, optimizer_cls in optimizer_choices:
        if enabled is not True:
            continue
        print(banner)
        if config.embed_finetune is True:
            params = model.parameters()
        else:
            # Skip parameters that are frozen (requires_grad is False).
            params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optimizer_cls(params,
                                  lr=config.learning_rate,
                                  weight_decay=config.weight_decay)
    if optimizer is None:
        # Fail here instead of with an AttributeError on None inside the loop.
        raise ValueError(
            "No optimizer selected: set config.adam or config.sgd to True.")

    best_fscore = Best_Result()

    model.train()
    train_eval = Eval()
    dev_eval = Eval()
    test_eval = Eval()
    for epoch in range(1, config.epochs + 1):
        print("\n## The {} Epoch, All {} Epochs ! ##".format(
            epoch, config.epochs))
        print("now lr is {}".format(optimizer.param_groups[0].get("lr")))
        start_time = time.time()
        random.shuffle(train_iter)
        model.train()
        steps = 0
        for batch_count, batch_features in enumerate(train_iter):
            model.zero_grad()
            optimizer.zero_grad()
            logit = model(batch_features)
            # Flatten the first two logit dimensions so every position is
            # scored against its label; padding labels are ignored.
            loss = F.cross_entropy(logit.view(
                logit.size(0) * logit.size(1), -1),
                                   batch_features.label_features,
                                   ignore_index=config.label_paddingId)
            loss.backward()
            optimizer.step()
            steps += 1
            if steps % config.log_interval == 0:
                getAcc(train_eval, batch_features, logit, config)
                # loss.data[0] raises on 0-dim tensors in newer PyTorch.
                loss_value = loss.item() if hasattr(loss,
                                                    "item") else loss.data[0]
                sys.stdout.write(
                    "\nbatch_count = [{}] , loss is {:.6f}, [TAG-ACC is {:.6f}%]"
                    .format(batch_count + 1, loss_value, train_eval.acc()))

        end_time = time.time()
        print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
        # Evaluate only when at least one batch was trained in this epoch.
        if steps != 0:
            dev_eval.clear_PRF()
            eval_start_time = time.time()
            eval_batch(dev_iter,
                       model,
                       dev_eval,
                       best_fscore,
                       epoch,
                       config,
                       test=False)
            eval_end_time = time.time()
            print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

            test_eval.clear_PRF()
            eval_start_time = time.time()
            eval_batch(test_iter,
                       model,
                       test_eval,
                       best_fscore,
                       epoch,
                       config,
                       test=True)
            eval_end_time = time.time()
            print("Test Time {:.3f}".format(eval_end_time - eval_start_time))
Exemplo n.º 4
0
class Train(object):
    """
        Train

        Training driver for a joint segmentation / POS-tagging model.  It owns
        the optimizer, learning-rate decay, gradient accumulation, dev/test
        evaluation (separate seg and pos scores), model saving and early
        stopping.
    """
    def __init__(self, **kwargs):
        """
        Build the optimizer, the loss function and the evaluation state.

        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        # "SGD" sums the per-item losses; every other optimizer averages them.
        self.loss_function = nn.CrossEntropyLoss(
            size_average=self.config.learning_algorithm != "SGD")
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval_seg, self.dev_eval_pos, self.test_eval_seg, self.test_eval_pos = Eval(
        ), Eval(), Eval(), Eval(), Eval()
        # Used to flush accumulated gradients on the last batch of an epoch.
        self.train_iter_len = len(self.train_iter)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        Clip the gradient norm of the model parameters.

        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]; ``None``
            or the string "None" means "no limit" and skips clipping
        :return: None
        """
        if clip_max_norm_use is not True:
            return
        # The "no limit" sentinel used to trip an assertion here; treat it as
        # "nothing to clip" instead of crashing the training loop.
        if clip_max_norm is None or clip_max_norm == "None":
            return
        # presumably `utils` is torch.nn.utils -- confirm against the imports
        utils.clip_grad_norm(self.model.parameters(),
                             max_norm=float(clip_max_norm))

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        Multiply the learning rate by ``config.lr_rate_decay`` every
        ``config.max_patience`` epochs, never going below ``config.min_lrate``.

        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return: the (possibly reduced) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Decay the learning rate.

        Sets every parameter group to ``init_lr / (1 + lr_rate_decay * epoch)``.

        Args:
            epoch: int, number of epochs completed
            init_lr: initial learning rate
        Returns:
            the optimizer, with the updated learning rate
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        Apply the accumulated gradients every ``config.backward_batch_size``
        backward passes, and also on the last batch of the epoch.

        :param config:  config
        :param backward_count:  number of backward passes done in this epoch
        :return: None
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        Count an epoch without a new best dev score and stop the process once
        ``early_max_patience`` such epochs have been counted.

        :param epoch:  current epoch
        :return: None
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                # sys.exit() rather than the interactive-only exit() helper.
                sys.exit()

    def _model2file(self, model, config, epoch):
        """
        Save the model according to the config flags; "save all" takes
        precedence over "save best".

        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return: None
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    @staticmethod
    def _loss_value(loss):
        """
        Return the loss as a plain number for logging.

        :param loss:  loss returned by the loss function
        :return: ``loss.item()`` when available, otherwise the legacy
            ``loss.data[0]`` (which raises on 0-dim tensors in newer PyTorch)
        """
        if hasattr(loss, "item"):
            return loss.item()
        return loss.data[0]

    def train(self):
        """
        Run the training loop: per epoch, apply the step-wise learning-rate
        decay, shuffle the training batches in place, accumulate gradients
        batch by batch, then evaluate on dev/test, save the model and check
        early stopping.

        :return: None
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                # 1-based count of backward passes done in this epoch.
                backward_count = batch_count + 1
                maxCharSize = batch_features.char_features.size()[1]
                decoder_out, state = self.model(batch_features, train=True)
                self.cal_train_acc(batch_features, self.train_eval,
                                   batch_count, decoder_out, maxCharSize,
                                   self.config)
                loss = torch.nn.functional.nll_loss(
                    decoder_out, batch_features.gold_features)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                if backward_count % self.config.log_interval == 0:
                    sys.stdout.write(
                        "\nBatch_count = [{}/{}] , Loss is {:.6f} , (Correct/Total_num) = Accuracy ({}/{})"
                        " = {:.6f}%".format(backward_count,
                                            self.train_iter_len,
                                            self._loss_value(loss),
                                            self.train_eval.correct_num,
                                            self.train_eval.gold_num,
                                            self.train_eval.acc() * 100))
            end_time = time.time()
            print("\nTrain Time {:.4f}".format(end_time - start_time))
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        Evaluate on the dev set and then on the test set, timing each pass.

        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return: None
        """
        self.dev_eval_pos.clear()
        self.dev_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval_seg,
                        self.dev_eval_pos,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        print("Dev Time {:.4f}".format(eval_end_time - eval_start_time))

        self.test_eval_pos.clear()
        self.test_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval_seg,
                        self.test_eval_pos,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        print("Test Time {:.4f}".format(eval_end_time - eval_start_time))

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_seg,
                   eval_pos,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        Decode every batch, accumulate seg/pos counts and print both reports.

        The pos F-score drives model selection: on a dev pass it updates the
        best dev score and epoch in ``best_score``; on the test pass that
        follows a new best dev score, the pos precision/recall/F-score are
        stored as the best test result.

        :param data_iter:  eval data iterator
        :param model:  nn model
        :param eval_seg:  seg eval
        :param eval_pos:  pos eval
        :param best_score:  best score
        :param epoch:  current epoch
        :param config:  config
        :param test:  test
        :return: None
        """
        model.eval()
        for batch_features in data_iter:
            _, state = model(batch_features, train=False)
            for i in range(batch_features.batch_length):
                self.jointPRF_Batch(batch_features.inst[i], state.words[i],
                                    state.pos_labels[i], eval_seg, eval_pos)

        # calculate the F-Score
        seg_p, seg_r, seg_f = eval_seg.getFscore()
        pos_p, pos_r, pos_f = eval_pos.getFscore()

        test_flag = "Test"
        if test is False:
            test_flag = "Dev"
            best_score.current_dev_score = pos_f
            if pos_f >= best_score.best_dev_score:
                best_score.best_dev_score = pos_f
                best_score.best_epoch = epoch
                # Ask the next test pass to record its scores as the best ones.
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = pos_p
            best_score.r = pos_r
            best_score.f = pos_f

        print(test_flag + " ---->")
        print("seg: precision = {:.4f}%  recall = {:.4f}% , f-score = {:.4f}%".
              format(seg_p, seg_r, seg_f))
        print("pos: precision = {:.4f}%  recall = {:.4f}% , f-score = {:.4f}%".
              format(pos_p, pos_r, pos_f))

        if test is True:
            print("The Current Best Dev F-score: {:.4f}%, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
            best_score.best_test = False

    @staticmethod
    def jointPRF_Batch(inst, state_words, state_posLabel, seg_eval, pos_eval):
        """
        Add one sentence's segmentation and POS counts to the accumulators.

        Every predicted word becomes a ``[start,end]`` character span (plus
        its POS label for the pos score) and counts as correct when that exact
        string appears in ``inst.gold_seg`` / ``inst.gold_pos``.

        :param inst:  gold instance exposing ``gold_seg`` and ``gold_pos``
        :param state_words:  predicted words, in sentence order
        :param state_posLabel:  predicted POS label for each word
        :param seg_eval:  segmentation accumulator (gold/predict/correct counts)
        :param pos_eval:  POS accumulator (gold/predict/correct counts)
        :return: None
        """
        start = 0
        predict_seg = []
        predict_pos = []
        for idx, word in enumerate(state_words):
            end = start + len(word)
            span = '[' + str(start) + ',' + str(end) + ']'
            predict_seg.append(span)
            predict_pos.append(span + state_posLabel[idx])
            start = end

        # Build each gold set once so every membership test is O(1).
        gold_seg = set(inst.gold_seg)
        seg_eval.gold_num += len(inst.gold_seg)
        seg_eval.predict_num += len(predict_seg)
        seg_eval.correct_num += sum(1 for p in predict_seg if p in gold_seg)

        gold_pos = set(inst.gold_pos)
        pos_eval.gold_num += len(inst.gold_pos)
        pos_eval.predict_num += len(predict_pos)
        pos_eval.correct_num += sum(1 for p in predict_pos if p in gold_pos)

    def cal_train_acc(self, batch_features, train_eval, batch_count,
                      decoder_out, maxCharSize, args):
        """
        Reset ``train_eval`` and fill it with the action accuracy of one batch.

        :param batch_features:  batch data feature
        :param train_eval:  accumulator, cleared before counting
        :param batch_count:  batch index (not used by the computation)
        :param decoder_out:  decoder scores, indexed as
            ``id_batch * maxCharSize + id_char``
        :param maxCharSize:  row stride of ``decoder_out`` per sentence
        :param args:  config (provides ``label_size``)
        :return: None
        """
        train_eval.clear()
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            for id_char in range(inst.chars_size):
                actionID = self.getMaxindex(
                    decoder_out[id_batch * maxCharSize + id_char], args)
                if actionID == inst.gold_index[id_char]:
                    train_eval.correct_num += 1
            train_eval.gold_num += inst.chars_size

    @staticmethod
    def getMaxindex(decode_out_acc, config):
        """
        Return the index of the largest of the first ``config.label_size``
        entries of ``decode_out_acc.data``; the first one wins on ties.

        :param decode_out_acc:  scores for one position
        :param config:  config (provides ``label_size``)
        :return: index of the maximum score
        """
        best_value = decode_out_acc.data[0]
        maxIndex = 0
        for idx in range(1, config.label_size):
            if decode_out_acc.data[idx] > best_value:
                best_value = decode_out_acc.data[idx]
                maxIndex = idx
        return maxIndex
Exemplo n.º 5
0
class Train(object):
    def __init__(self, **kwargs):
        """
        Collect everything the training loop needs.

        Required keyword arguments:
            config: experiment configuration; also supplies the logger
            train_iter: training batches (its ``len`` is cached as ``train_iter_len``)
            dev_iter: development batches
            test_iter: test batches
            model: the model to train
        """
        config = kwargs["config"]
        self.config = config
        config.logger.info("Training Start......")

        # Data splits and the model under training.
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]

        # Switches copied off the config because they are read on every batch/epoch.
        self.use_crf = config.use_crf
        self.average_batch = config.average_batch
        self.early_max_patience = config.early_max_patience

        self.optimizer = Optimizer(name=config.learning_algorithm,
                                   model=self.model,
                                   lr=config.learning_rate,
                                   weight_decay=config.weight_decay,
                                   grad_clip=config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=config.learning_algorithm,
            label_paddingId=config.label_paddingId,
            use_crf=self.use_crf)
        config.logger.info(self.optimizer)
        config.logger.info(self.loss_function)

        # Score bookkeeping: best results so far plus one evaluator per split.
        self.best_score = Best_Result()
        self.train_eval = Eval()
        self.dev_eval = Eval()
        self.test_eval = Eval()
        self.train_iter_len = len(self.train_iter)

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):

        if use_crf:
            loss_function = self.model.crf_layer.neg_log_likelihood_loss
            return loss_function
        elif learning_algorithm == "SGD":
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="sum")
            return loss_function
        else:
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="mean")
            return loss_function

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):

        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):

        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):

        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):

        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch, config):

        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            self.config.logger.info("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                self.end_of_epoch = epoch
                self.config.logger.info(
                    "\n\nEarly Stop Train. Best Score Locate on {} Epoch.".
                    format(self.best_score.best_epoch))
                self.save_training_summary()
                exit()

    @staticmethod
    def _get_model_args(batch_features):

        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):

        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        else:
            loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            batch_size = feats.size(0)
            loss_value /= float(batch_size)
        return loss_value

    def train(self):

        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate
        self.config.logger.info('\n\n')
        self.config.logger.info('=-' * 50)
        self.config.logger.info('batch number: %d' % len(self.train_iter))

        for epoch in range(1, epochs + 1):
            self.config.logger.info("\n\n### Epoch: {}/{} ###".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            self.config.logger.info("current lr: {}".format(
                self.optimizer.param_groups[0].get("lr")))
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            self.config.logger.info('=-' * 10)
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                word, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(word, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit,
                                self.config)
                    self.config.logger.info(
                        "batch_count:{} , loss: {:.4f}, [TAG-ACC: {:.4f}%]".
                        format(batch_count + 1, loss.item(),
                               self.train_eval.acc()))
            end_time = time.time()
            self.config.logger.info("Train Time {:.3f}".format(end_time -
                                                               start_time))
            self.config.logger.info('=-' * 10)
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 10)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 15)
        self.save_training_summary()

    def save_training_summary(self):
        self.config.logger.info(
            "Copy the last model ckps to {} as backup.".format(
                self.config.save_dir))
        shutil.copytree(
            self.config.save_model_dir, "/".join(
                [self.config.save_dir, self.config.save_model_dir + "_bak"]))

        self.config.logger.info(
            "save the training summary at end of the log file.")
        self.config.logger.info("\n")
        self.config.logger.info("*" * 25)

        par_path = os.path.dirname(self.config.train_file)
        self.config.logger.info("dataset:\n\t %s" % par_path)
        self.config.logger.info("\ttrain set count: %d" %
                                self.config.train_cnt)
        self.config.logger.info("\tdev set count: %d" % self.config.dev_cnt)
        self.config.logger.info("\ttest set count: %d" % self.config.test_cnt)

        self.config.logger.info("*" * 10)
        self.config.logger.info("model:")
        self.config.logger.info(self.model)

        self.config.logger.info("*" * 10)
        self.config.logger.info("training:")
        self.config.logger.info('\tbatch size: %d' % self.config.batch_size)
        self.config.logger.info('\tbatch count: %d' % len(self.train_iter))

        self.config.logger.info("*" * 10)
        self.config.logger.info("best performance:")
        self.config.logger.info("\tend at epoch: %d" % self.end_of_epoch)
        self.config.logger.info("\tbest at epoch: %d" %
                                self.best_score.best_epoch)
        self.config.logger.info("\tdev(%):")
        self.config.logger.info("\t\tprecision, %.5f" %
                                self.best_score.best_dev_p_score)
        self.config.logger.info("\t\trecall, %.5f" %
                                self.best_score.best_dev_r_score)
        self.config.logger.info("\t\tf1, %.5f" %
                                self.best_score.best_dev_f1_score)
        self.config.logger.info("\ttest(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.p)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.r)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.f)

        self.config.logger.info("*" * 25)

    def eval(self, model, epoch, config):

        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        self.config.logger.info("Dev Time: {:.3f}".format(eval_end_time -
                                                          eval_start_time))
        self.config.logger.info('=-' * 10)

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        self.config.logger.info("Test Time: {:.3f}".format(eval_end_time -
                                                           eval_start_time))

    def _model2file(self, model, config, epoch):

        if config.save_model and config.save_all_model:
            save_model_all(model, config, config.save_model_dir,
                           config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config, config.save_model_dir,
                            config.model_name, self.best_score)
        else:
            self.config.logger.info()

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        Evaluate ``model`` on one split and update the best-score record.

        Predictions are decoded per instance (through the CRF layer when
        ``self.use_crf`` is set, otherwise via ``torch_max``), converted to
        label strings, scored with ``EvalPRF`` into ``eval_instance`` and
        summarised as precision / recall / F-score plus tag accuracy.

        On the dev split (``test is False``) a new best F-score updates
        ``best_score`` and raises ``best_score.best_test``.  On the test split
        (``test is True``) that flag decides whether the test scores are
        stored, and it is cleared afterwards.

        :param data_iter: iterable of batches for the split
        :param model: model to evaluate; switched to eval mode here
        :param eval_instance: accumulator passed to ``EvalPRF.evalPRF`` and
            queried with ``getFscore()``
        :param best_score: record of the best dev/test results, updated in place
        :param epoch: current epoch number
        :param config: provides ``create_alphabet.label_alphabet``
        :param test: ``True`` for the test split, ``False`` for dev
        :return: None
        """
        test_flag = "Test"
        if test is False:  # dev
            test_flag = "Dev"

        model.eval()  # set flag for pytorch
        eval_PRF = EvalPRF()
        gold_labels = []
        predict_labels = []
        for batch_features in data_iter:
            word, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(word, sentence_length, train=False)

            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    # Only the first ``words_size`` positions are decoded.
                    predict_label = [
                        config.create_alphabet.label_alphabet.from_id(
                            label_ids[id_word])
                        for id_word in range(inst.words_size)
                    ]
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                _, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    # Truncate the decoded path to the instance's word count.
                    label_ids = best_paths[id_batch].cpu().data.numpy(
                    )[:inst.words_size]
                    predict_labels.append([
                        config.create_alphabet.label_alphabet.from_id(int(i))
                        for i in label_ids
                    ])

        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label,
                             gold_labels=g_label,
                             eval=eval_instance)

        # Token-level tag accuracy over all predicted positions.
        total_length = sum(len(labels) for labels in predict_labels)
        correct = sum(1
                      for predicted, gold in zip(predict_labels, gold_labels)
                      for p_tag, g_tag in zip(predicted, gold)
                      if p_tag == g_tag)
        # An empty split produces no predictions; report 0% rather than
        # dividing by zero.
        acc_ = correct / total_length * 100 if total_length else 0.0

        p, r, f = eval_instance.getFscore()

        if test is False:  # dev
            best_score.current_dev_score = f
            if f >= best_score.best_dev_f1_score:
                best_score.best_dev_f1_score = f
                best_score.best_dev_p_score = p
                best_score.best_dev_r_score = r
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:  # test
            best_score.p = p
            best_score.r = r
            best_score.f = f
        self.config.logger.info(
            "{} at current epoch, precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%,  [TAG-ACC: {:.3f}%]"
            .format(test_flag, p, r, f, acc_))
        if test is False:
            self.config.logger.info(
                "Till now, The Best Dev Result: precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.best_dev_p_score,
                        best_score.best_dev_r_score,
                        best_score.best_dev_f1_score, best_score.best_epoch))
        elif test is True:
            self.config.logger.info(
                "Till now, The Best Test Result: precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.p, best_score.r, best_score.f,
                        best_score.best_epoch))
            # The flag has been consumed; the next dev improvement sets it again.
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):
        """
        Recompute tag-accuracy counters on ``eval_acc`` for a single batch.

        ``eval_acc`` is cleared, predictions are taken from
        ``torch_max(logit)``, mapped to label strings and compared with each
        instance's gold labels position by position.

        :param eval_acc: counter object with ``clear_PRF()``, ``correct_num``
            and ``gold_num``
        :param batch_features: batch exposing ``batch_length`` and ``inst``
        :param logit: model output for the batch
        :param config: provides ``create_alphabet.label_alphabet``
        :return: None; results are accumulated on ``eval_acc``
        """
        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        for batch_pos in range(batch_features.batch_length):
            instance = batch_features.inst[batch_pos]
            row_ids = predict_ids[batch_pos]
            gold_tags = instance.labels
            predicted_tags = [
                config.create_alphabet.label_alphabet.from_id(row_ids[word_pos])
                for word_pos in range(instance.words_size)
            ]
            assert len(predicted_tags) == len(gold_tags)
            matches = sum(1 for predicted, gold in zip(predicted_tags, gold_tags)
                          if predicted == gold)
            eval_acc.correct_num += matches
            eval_acc.gold_num += len(gold_tags)