Example #1
class Train(object):
    """
        Train
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        # for k, v in kwargs.items():
        #     self.__setattr__(k, v)
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

    @staticmethod
    def _loss(learning_algorithm):
        """
        :param learning_algorithm:
        :return:
        """
        if learning_algorithm == "SGD":
            loss_function = nn.CrossEntropyLoss(reduction="sum")
            return loss_function
        else:
            loss_function = nn.CrossEntropyLoss(reduction="mean")
            return loss_function

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

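    # Hypothetical numbers: with lr_rate_decay=0.5 and max_patience=3, the
    # step schedule below halves the lr at epochs 4, 7, 10, ... until it
    # bottoms out at min_lrate.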
    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

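    # Inverse-time decay, by contrast: with init_lr=0.1 and lr_rate_decay=0.05
    # (hypothetical values), the schedule below yields lr = 0.1, 0.0952,
    # 0.0909, ... for epoch = 0, 1, 2.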
    def _decay_learning_rate(self, config, epoch, init_lr):
        """lr decay 

        Args:
            epoch: int, epoch 
            init_lr:  initial lr
        """
        if config.use_lr_decay:
            lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :return:
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                exit()

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features:  Batch Instance
        :return:
        """
        inst = batch_features.inst
        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        labels = batch_features.label_features
        batch_size = batch_features.batch_length

        return inst, word, mask, sentence_length, labels, batch_size

    def _calculate_loss(self, feats, labels):
        """
        Args:
            feats: size = (batch_size, seq_len, tag_size)
            labels: size = (batch_size, seq_len)
        """
        loss_value = self.loss_function(feats, labels)
        return loss_value

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            # self.optimizer = self._decay_learning_rate(config=self.config, epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                # self.optimizer.zero_grad()
                inst, word, mask, sentence_length, labels, batch_size = self._get_model_args(
                    batch_features)
                logit = self.model(word, sentence_length, train=True)
                loss = self._calculate_loss(logit, labels)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                # self.optimizer.step()
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    accuracy = self.getAcc(logit, labels, batch_size)
                    sys.stdout.write(
                        "\nbatch_count = [{}] , loss is {:.6f}, [accuracy is {:.6f}%]"
                        .format(batch_count + 1, loss.item(), accuracy))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return:
        """
        eval_start_time = time.time()
        print('\nmistakes for dev_iter')
        self.eval_batch(self.dev_iter,
                        model,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))
        print('mistakes for test_iter')
        eval_start_time = time.time()

        self.eval_batch(self.test_iter,
                        model,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self,
                   data_iter,
                   model,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        :param data_iter:  eval batch data iterator
        :param model: eval model
        :param best_score:
        :param epoch:
        :param config: config
        :param test:  whether to test
        :return: None
        """
        model.eval()
        # eval time
        corrects = 0
        size = 0
        loss = 0
        Truelabel = []
        Words = []
        d = []  # predicted labels collected for classification_report

        for batch_features in data_iter:
            inst, word, mask, sentence_length, labels, batch_size = self._get_model_args(
                batch_features)

            logit = self.model(word, sentence_length,
                               train=False)  # needs modification if the pinyin model is added
            loss += self._calculate_loss(logit, labels)
            size += batch_features.batch_length

            t = torch.max(logit, 1)[1].view(labels.size()).data
            p = t.cpu().numpy()
            p = p.tolist()
            for i in p:
                d.append([i])
            for k in inst:
                Truelabel.append(k.label_index)
                Words.append(k.words)

            # torch.max(logit, 1)[1] returns the index of the max value in each row of logit
            corrects += (torch.max(logit, 1)[1].view(
                labels.size()).data == labels.data).sum()

        print("更加详细的评估指标:\n", classification_report(Truelabel, d, digits=5))

        assert size != 0, "Error"  # value check; "is not 0" tested identity, not value
        accuracy = float(corrects) / size * 100.0
        average_loss = float(loss) / size

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = accuracy
            if accuracy >= best_score.best_dev_score:
                best_score.best_dev_score = accuracy
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = accuracy

        print("{} eval: average_loss = {:.6f}, accuracy = {:.6f}%".format(
            test_flag, average_loss, accuracy))
        if test is True:
            print("The Current Best Dev Accuracy: {:.6f}, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
            print("The Current Best Test Accuracy: accuracy = {:.6f}%".format(
                best_score.p))
        if test is True:
            best_score.best_test = False

    @staticmethod
    def getAcc(logit, target, batch_size):
        """
        :param logit:  model predict
        :param target:  gold label
        :param batch_size:  batch size
        :param config:  config
        :return:
        """
        corrects = (torch.max(logit, 1)[1].view(
            target.size()).data == target.data).sum()
        accuracy = float(corrects) / batch_size * 100.0
        return accuracy
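
The class above leans on project-specific helpers (Optimizer, Best_Result, save_model_all, save_best_model) that are not shown. A minimal sketch of how such a Train class is typically wired up, assuming hypothetical Config/Model classes and pre-built batch iterators:

# Hypothetical driver for Example #1; Config, Model and build_iters are
# placeholders for the project's own code, not real APIs.
config = Config()          # carries learning_rate, epochs, clip_max_norm, ...
train_iter, dev_iter, test_iter = build_iters(config)
model = Model(config)

trainer = Train(train_iter=train_iter,
                dev_iter=dev_iter,
                test_iter=test_iter,
                model=model,
                config=config)
trainer.train()            # trains, evals each epoch, may early-stop via exit()
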
Example #2
class Train(object):
    """
        Train
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        # for k, v in kwargs.items():
        #     self.__setattr__(k, v)
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.use_crf = self.config.use_crf
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.label_paddingId,
            use_crf=self.use_crf)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        if use_crf:
            loss_function = self.model.crf_layer.neg_log_likelihood_loss
            return loss_function
        elif learning_algorithm == "SGD":
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="sum")  # replaces deprecated size_average=False
            return loss_function
        else:
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="mean")  # replaces deprecated size_average=True
            return loss_function

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """lr decay 

        Args:
            epoch: int, epoch 
            init_lr:  initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :return:
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                exit()

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features:  Batch Instance
        :return:
        """
        word = batch_features.word_features
        char = batch_features.char_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        # desorted_indices = batch_features.desorted_indices
        tags = batch_features.label_features
        return word, char, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):
        """
        Args:
            feats: size = (batch_size, seq_len, tag_size)
            mask: size = (batch_size, seq_len)
            tags: size = (batch_size, seq_len)
        """
        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        else:
            loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            batch_size = feats.size(0)
            loss_value /= float(batch_size)
        return loss_value

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            # new_lr = self._dynamic_lr(config=self.config, epoch=epoch, new_lr=new_lr)
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                # self.optimizer.zero_grad()
                word, char, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(word, char, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                # self.optimizer.step()
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit,
                                self.config)
                    sys.stdout.write(
                        "\nbatch_count = [{}] , loss is {:.6f}, [TAG-ACC is {:.6f}%]"
                        .format(batch_count + 1, loss.item(),
                                self.train_eval.acc()))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return:
        """
        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        :param data_iter:  eval batch data iterator
        :param model: eval model
        :param eval_instance:
        :param best_score:
        :param epoch:
        :param config: config
        :param test:  whether to test
        :return: None
        """
        model.eval()
        # eval time
        eval_acc = Eval()
        eval_PRF = EvalPRF()
        gold_labels = []
        predict_labels = []
        for batch_features in data_iter:
            word, char, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(word, char, sentence_length, train=False)
            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    predict_label = []
                    for id_word in range(inst.words_size):
                        predict_label.append(
                            config.create_alphabet.label_alphabet.from_id(
                                label_ids[id_word]))
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                path_score, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy(
                    )[:inst.words_size]
                    label = []
                    for i in label_ids:
                        label.append(
                            config.create_alphabet.label_alphabet.from_id(i))
                    predict_labels.append(label)
        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label,
                             gold_labels=g_label,
                             eval=eval_instance)
        if eval_acc.gold_num == 0:
            eval_acc.gold_num = 1
        p, r, f = eval_instance.getFscore()
        # p, r, f = entity_evalPRF_exact(gold_labels=gold_labels, predict_labels=predict_labels)
        # p, r, f = entity_evalPRF_propor(gold_labels=gold_labels, predict_labels=predict_labels)
        # p, r, f = entity_evalPRF_binary(gold_labels=gold_labels, predict_labels=predict_labels)
        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        print(
            "{} eval: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%,  [TAG-ACC = {:.6f}%]"
            .format(test_flag, p, r, f, 0.0000))
        if test is True:
            print("The Current Best Dev F-score: {:.6f}, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
            print(
                "The Current Best Test Result: precision = {:.6f}%  recall = {:.6f}% , f-score = {:.6f}%"
                .format(best_score.p, best_score.r, best_score.f))
        if test is True:
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):
        """
        :param eval_acc:  eval instance
        :param batch_features:  batch data feature
        :param logit:  model output
        :param config:  config
        :return:
        """
        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            predict_label = []
            gold_label = inst.labels
            for id_word in range(inst.words_size):
                predict_label.append(
                    config.create_alphabet.label_alphabet.from_id(
                        label_ids[id_word]))
            assert len(predict_label) == len(gold_label)
            cor = 0
            for p_label, g_label in zip(predict_label, gold_label):
                if p_label == g_label:
                    cor += 1
            eval_acc.correct_num += cor
            eval_acc.gold_num += len(gold_label)
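
The non-CRF branch of _calculate_loss flattens (batch_size, seq_len, tag_size) logits to token level and lets CrossEntropyLoss skip padded positions via ignore_index. A self-contained sketch of that pattern with toy tensors (PAD_ID and the shapes are assumptions, not the project's values):

# Token-level cross-entropy over a padded batch of sequences.
import torch
import torch.nn as nn

PAD_ID = 0                                    # assumed label padding id
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="mean")

batch_size, seq_len, tag_size = 2, 5, 4
feats = torch.randn(batch_size, seq_len, tag_size)         # model scores
tags = torch.randint(1, tag_size, (batch_size, seq_len))    # gold tag ids
tags[1, 3:] = PAD_ID                          # pad the tail of sentence 2

loss = loss_fn(feats.view(batch_size * seq_len, -1), tags.view(-1))
print(loss.item())                            # padded tokens are ignored
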
Example #3
class Train(object):
    """
        Train
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        # for k, v in kwargs.items():
        #     self.__setattr__(k, v)
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.parser = kwargs["model"]
        self.config = kwargs["config"]
        self.device = self.config.device
        self.cuda = False
        if self.device != cpu_device:
            self.cuda = True
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(
            name=self.config.learning_algorithm,
            model=self.parser.model,
            lr=self.config.learning_rate,
            # weight_decay=self.config.weight_decay, grad_clip=self.config.clip_max_norm,
            weight_decay=self.config.weight_decay,
            grad_clip="None",
            betas=(0.9, 0.9),
            eps=1.0e-12)
        if self.config.learning_algorithm == "SGD":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")
        else:
            self.loss_function = nn.CrossEntropyLoss(reduction="mean")
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.parser.model.parameters(),
                                  max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            # print("epoch", epoch)
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """
        Args:
            epoch: int, epoch
            init_lr: initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        # print('learning rate: {0}'.format(lr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :return:
        """
        if backward_count % config.update_batch_size == 0 or backward_count == self.train_iter_len:
            self._clip_model_norm(self.config.clip_max_norm_use,
                                  self.config.clip_max_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                exit()

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            # self.optimizer = self._decay_learning_rate(epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.parser.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            overall_arc_correct, overall_label_correct, overall_total_arcs = 0, 0, 0
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                words, ext_words, tags, masks = batch_features.words, batch_features.ext_words, batch_features.tags, \
                                                batch_features.masks
                heads, rels, lengths = batch_features.heads, batch_features.rels, batch_features.lengths
                sumLength = sum(lengths)
                self.parser.forward(words, ext_words, tags, masks)

                loss = self.parser.compute_loss(heads, rels, lengths)
                loss = loss / self.config.update_batch_size
                loss_value = loss.data.cpu().numpy()
                loss.backward()

                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)

                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    arc_correct, label_correct, total_arcs = self.parser.compute_accuracy(
                        heads, rels)
                    overall_arc_correct += arc_correct
                    overall_label_correct += label_correct
                    overall_total_arcs += total_arcs
                    uas = overall_arc_correct.item(
                    ) * 100.0 / overall_total_arcs
                    las = overall_label_correct.item(
                    ) * 100.0 / overall_total_arcs
                    sys.stdout.write(
                        "\nbatch_count = [{}/{}] , loss is {:.6f}, length: {}, ARC: {:.6f}, REL: {:.6f}"
                        .format(batch_count + 1, self.train_iter_len,
                                float(loss_value), sumLength, float(uas),
                                float(las)))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(parser=self.parser, epoch=epoch, config=self.config)
            self._model2file(model=self.parser.model,
                             config=self.config,
                             epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, parser, epoch, config):
        """
        :param parser:
        :param epoch:
        :param config:
        :return:
        """
        eval_start_time = time.time()
        self._eval_batch(self.dev_iter,
                         parser,
                         self.best_score,
                         epoch,
                         config,
                         test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        eval_start_time = time.time()
        self._eval_batch(self.test_iter,
                         parser,
                         self.best_score,
                         epoch,
                         config,
                         test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    # self.get_one_batch(batch_features.insts)
    def get_one_batch(self, insts):
        """
        :param insts:
        :return:
        """
        batch = []
        for inst in insts:
            batch.append(inst.sentence)
        return batch

    def _eval_batch(self,
                    data_iter,
                    parser,
                    best_score,
                    epoch,
                    config,
                    test=False):
        """
        :param data_iter:
        :param parser:
        :param best_score:
        :param epoch:
        :param config:
        :param test:
        :return:
        """
        parser.model.eval()
        arc_total_test, arc_correct_test, rel_total_test, rel_correct_test = 0, 0, 0, 0
        alphabet = config.alphabet

        for batch_count, batch_features in enumerate(data_iter):
            one_batch = self.get_one_batch(batch_features.insts)
            words, ext_words, tags, masks = batch_features.words, batch_features.ext_words, batch_features.tags, batch_features.masks
            heads, rels, lengths = batch_features.heads, batch_features.rels, batch_features.lengths
            # print()
            # print(heads)
            # print(rels)
            # print(lengths)
            # exit()
            sumLength = sum(lengths)
            count = 0
            arcs_batch, rels_batch = parser.parse(words, ext_words, tags,
                                                  lengths, masks)
            # print(arcs_batch)
            # print(rels_batch)
            # exit()
            for tree in batch_variable_depTree(one_batch, arcs_batch,
                                               rels_batch, lengths, alphabet):
                # printDepTree(output, tree)
                # arc_total, arc_correct, rel_total, rel_correct = evalDepTree(tree, one_batch[count])
                arc_total, arc_correct, rel_total, rel_correct = evalDepTree(
                    one_batch[count], tree)
                arc_total_test += arc_total
                arc_correct_test += arc_correct
                rel_total_test += rel_total
                rel_correct_test += rel_correct
                count += 1

        uas = arc_correct_test * 100.0 / arc_total_test
        las = rel_correct_test * 100.0 / rel_total_test

        f = uas
        # p, r, f = law_p, law_r, law_f

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.f = f
        print("{}:".format(test_flag))
        print("UAS = %d/%d = %.2f, LAS = %d/%d =%.2f" %
              (arc_correct_test, arc_total_test, uas, rel_correct_test,
               rel_total_test, las))

        if test is True:
            print("The Current Best Dev score: {:.6f}, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
        if test is True:
            best_score.best_test = False
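
Example #3 reports UAS and LAS through the project's evalDepTree helper. As a reference for what those two numbers measure, a toy sketch with made-up arrays (not the project's data structures): UAS counts tokens whose predicted head is correct, LAS additionally requires the relation label to match.

gold_heads = [2, 0, 2, 3]
gold_rels = ["nsubj", "root", "obj", "nmod"]
pred_heads = [2, 0, 2, 2]
pred_rels = ["nsubj", "root", "iobj", "nmod"]

total = len(gold_heads)
arc_correct = sum(g == p for g, p in zip(gold_heads, pred_heads))
label_correct = sum(gh == ph and gr == pr
                    for gh, gr, ph, pr in zip(gold_heads, gold_rels,
                                              pred_heads, pred_rels))
print("UAS = %d/%d = %.2f" % (arc_correct, total, 100.0 * arc_correct / total))
print("LAS = %d/%d = %.2f" % (label_correct, total, 100.0 * label_correct / total))
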
Example #4
class Train(object):
    """
        Train
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        # for k, v in kwargs.items():
        #     self.__setattr__(k, v)
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        if self.config.learning_algorithm == "SGD":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")  # replaces deprecated size_average=False
        else:
            self.loss_function = nn.CrossEntropyLoss(reduction="mean")
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval_seg, self.dev_eval_pos, self.test_eval_seg, self.test_eval_pos = Eval(
        ), Eval(), Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            # print("epoch", epoch)
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """衰减学习率

        Args:
            epoch: int, 迭代次数
            init_lr: 初始学习率
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        # print('learning rate: {0}'.format(lr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :return:
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early Stop Train. Best Score Locate on {} Epoch.".format(
                        self.best_score.best_epoch))
                exit()

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            # self.optimizer = self._decay_learning_rate(epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                # self.optimizer.zero_grad()
                maxCharSize = batch_features.char_features.size()[1]
                decoder_out, state = self.model(batch_features, train=True)
                self.cal_train_acc(batch_features, self.train_eval,
                                   batch_count, decoder_out, maxCharSize,
                                   self.config)
                loss = torch.nn.functional.nll_loss(
                    decoder_out, batch_features.gold_features)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                # self.optimizer.step()
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    sys.stdout.write(
                        "\nBatch_count = [{}/{}] , Loss is {:.6f} , (Correct/Total_num) = Accuracy ({}/{})"
                        " = {:.6f}%".format(batch_count + 1,
                                            self.train_iter_len, loss.item(),
                                            self.train_eval.correct_num,
                                            self.train_eval.gold_num,
                                            self.train_eval.acc() * 100))
            end_time = time.time()
            # print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            print("\nTrain Time {:.4f}".format(end_time - start_time))
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return:
        """
        self.dev_eval_pos.clear()
        self.dev_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval_seg,
                        self.dev_eval_pos,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        print("Dev Time {:.4f}".format(eval_end_time - eval_start_time))

        self.test_eval_pos.clear()
        self.test_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval_seg,
                        self.test_eval_pos,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        print("Test Time {:.4f}".format(eval_end_time - eval_start_time))

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_seg,
                   eval_pos,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        :param data_iter:  eval data iterator
        :param model:  nn model
        :param eval_seg:  seg eval
        :param eval_pos:  pos eval
        :param best_score:  best score
        :param epoch:  current epoch
        :param config:  config
        :param test:  test
        :return:
        """
        model.eval()
        for batch_features in data_iter:
            decoder_out, state = model(batch_features, train=False)
            for i in range(batch_features.batch_length):
                self.jointPRF_Batch(batch_features.inst[i], state.words[i],
                                    state.pos_labels[i], eval_seg, eval_pos)

        # calculate the F-Score
        seg_p, seg_r, seg_f = eval_seg.getFscore()
        pos_p, pos_r, pos_f = eval_pos.getFscore()

        test_flag = "Test"
        if test is False:
            # print()
            test_flag = "Dev"
            best_score.current_dev_score = pos_f
            if pos_f >= best_score.best_dev_score:
                best_score.best_dev_score = pos_f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = pos_p
            best_score.r = pos_r
            best_score.f = pos_f

        print(test_flag + " ---->")
        print("seg: precision = {:.4f}%  recall = {:.4f}% , f-score = {:.4f}%".
              format(seg_p, seg_r, seg_f))
        print("pos: precision = {:.4f}%  recall = {:.4f}% , f-score = {:.4f}%".
              format(pos_p, pos_r, pos_f))

        if test is True:
            print("The Current Best Dev F-score: {:.4f}%, Locate on {} Epoch.".
                  format(best_score.best_dev_score, best_score.best_epoch))
        if test is True:
            best_score.best_test = False

    @staticmethod
    def jointPRF_Batch(inst, state_words, state_posLabel, seg_eval, pos_eval):
        """
        :param inst:
        :param state_words:
        :param state_posLabel:
        :param seg_eval:
        :param pos_eval:
        :return:
        """
        words = state_words
        posLabels = state_posLabel
        count = 0
        predict_seg = []
        predict_pos = []

        for idx in range(len(words)):
            w = words[idx]
            posLabel = posLabels[idx]
            predict_seg.append('[' + str(count) + ',' + str(count + len(w)) +
                               ']')
            predict_pos.append('[' + str(count) + ',' + str(count + len(w)) +
                               ']' + posLabel)
            count += len(w)

        seg_eval.gold_num += len(inst.gold_seg)
        seg_eval.predict_num += len(predict_seg)
        for p in predict_seg:
            if p in inst.gold_seg:
                seg_eval.correct_num += 1

        pos_eval.gold_num += len(inst.gold_pos)
        pos_eval.predict_num += len(predict_pos)
        for p in predict_pos:
            if p in inst.gold_pos:
                pos_eval.correct_num += 1

    def cal_train_acc(self, batch_features, train_eval, batch_count,
                      decoder_out, maxCharSize, args):
        """
        :param batch_features:
        :param train_eval:
        :param batch_count:
        :param decoder_out:
        :param maxCharSize:
        :param args:
        :return:
        """
        train_eval.clear()
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            for id_char in range(inst.chars_size):
                actionID = self.getMaxindex(
                    decoder_out[id_batch * maxCharSize + id_char], args)
                if actionID == inst.gold_index[id_char]:
                    train_eval.correct_num += 1
            train_eval.gold_num += inst.chars_size

    @staticmethod
    def getMaxindex(decode_out_acc, config):
        """
        :param decode_out_acc:
        :param config:
        :return:
        """
        max_value = decode_out_acc.data[0]  # avoid shadowing the built-in max
        maxIndex = 0
        for idx in range(1, config.label_size):
            if decode_out_acc.data[idx] > max_value:
                max_value = decode_out_acc.data[idx]
                maxIndex = idx
        return maxIndex
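
getMaxindex above scans the score vector element by element. On any recent PyTorch the same index comes from a single argmax call; a toy equivalence check (tensor values are made up):

import torch

decode_out_acc = torch.tensor([0.1, 2.3, -0.7, 1.5])
label_size = 4

max_index = torch.argmax(decode_out_acc[:label_size]).item()
print(max_index)  # 1, the same answer as the hand-written loop
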
Example #5
class Train(object):
    def __init__(self, **kwargs):
        self.config = kwargs["config"]
        self.config.logger.info("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.use_crf = self.config.use_crf
        self.target = kwargs["target"]
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.arg_paddingId,
            use_crf=self.use_crf)
        self.config.logger.info(self.optimizer)
        self.config.logger.info(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        """
        :param learning_algorithm:
        :param label_paddingId:
        :param use_crf:
        :return:
        """
        if use_crf:
            loss_function = self.model.crf_layer.neg_log_likelihood_loss
            return loss_function
        elif learning_algorithm == "SGD":
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="sum")
            return loss_function
        else:
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="mean")
            return loss_function

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """lr decay 

        Args:
            epoch: int, epoch 
            init_lr:  initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :param config:
        :param backward_count:
        :return:
        """
        if backward_count % config.backward_batch_size == 0:  # or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch, config):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            self.config.logger.info("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                self.end_of_epoch = epoch
                self.config.logger.info(
                    "\n\nEarly Stop Train. Best Score Locate on {} Epoch.".
                    format(self.best_score.best_epoch))
                self.save_training_summary()
                return True  # exit()
            else:
                return False
        else:
            return False

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features:  Batch Instance
        :return:
        """
        elmo_char_seqs = batch_features.elmo_char_seqs
        elmo_word_seqs = batch_features.elmo_word_seqs
        word = batch_features.word_features
        lang = batch_features.lang
        pos = batch_features.pos_features
        prd = batch_features.prd_features
        x_prd_posi = batch_features.prd_posi_features
        mask = batch_features.mask
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd, x_prd_posi, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):
        """
        Args:
            feats: size = (batch_size, seq_len, tag_size)
            mask: size = (batch_size, seq_len)
            tags: size = (batch_size, seq_len)
        """
        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        else:
            loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            batch_size = feats.size(0)
            loss_value /= float(batch_size)
        return loss_value
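
    # Shape walk-through for `_calculate_loss` (hypothetical sizes): in the
    # non-CRF branch, feats of size (batch_size=32, seq_len=50, tag_size=19)
    # is flattened to lstm_feats (1600, 19) and tags to (1600,), so
    # CrossEntropyLoss scores all 1600 token positions at once. In the CRF
    # branch, loss_function is the CRF negative log-likelihood, optionally
    # divided by batch_size when average_batch is set.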

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        self.config.logger.info('\n\n')
        self.config.logger.info('=-' * 50)

        for epoch in range(1, epochs + 1):
            self.train_iter.reset_flag4trainset()
            self.config.logger.info("\n\n### Epoch: {}/{} ###".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            self.config.logger.info("current lr: {}".format(
                self.optimizer.param_groups[0].get("lr")))
            start_time = time.time()
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            self.config.logger.info('=-' * 10)
            batch_count = 0
            for batch_features in tqdm.tqdm(self.train_iter):
                batch_count += 1
                backward_count += 1
                elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd, x_prd_posi, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(elmo_char_seqs,
                                   elmo_word_seqs,
                                   word,
                                   lang,
                                   pos,
                                   prd,
                                   x_prd_posi,
                                   mask,
                                   sentence_length,
                                   train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                steps += 1

            # Note: batch_features and logit here come from the last training
            # batch, so the epoch-level p/r/f/acc below reflect only that batch.
            if self.use_crf is True:
                p, r, f, acc_ = self.getAccCRF(self.train_eval, batch_features,
                                               logit, mask, self.config)
            else:
                p, r, f, acc_ = self.getAcc(self.train_eval, batch_features,
                                            logit, self.config)
            self.config.logger.info(
                "batch_count:{} , loss: {:.4f}, p: {:.4f}%  r: {:.4f}% , f: {:.4f}%, ACC: {:.4f}%"
                .format(batch_count, loss.item(), p, r, f, acc_))

            end_time = time.time()
            self.config.logger.info("Train Time {:.3f}".format(end_time -
                                                               start_time))
            self.config.logger.info('=-' * 10)
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 10)
            if self._early_stop(epoch=epoch, config=self.config):
                return
            self.config.logger.info('=-' * 15)

        self.save_training_summary()

    def save_training_summary(self):
        self.config.logger.info(
            "Copy the last model ckps to {} as backup.".format(
                self.config.save_dir))

        self.config.logger.info(
            "save the training summary at end of the log file.")
        self.config.logger.info("\n")
        self.config.logger.info("*" * 25)

        self.config.logger.info("*" * 10)
        self.config.logger.info("features:")
        if self.config.is_predicate:
            self.config.logger.info("\tpredicate, dim: %d" %
                                    self.config.prd_embed_dim)

        self.config.logger.info("*" * 10)
        self.config.logger.info("model:")
        self.config.logger.info(self.model)

        self.config.logger.info("*" * 10)
        self.config.logger.info("training:")
        self.config.logger.info('\tbatch size: %d' % self.config.batch_size)

        self.config.logger.info("*" * 10)
        self.config.logger.info("best performance:")
        self.config.logger.info("\tbest at epoch: %d" %
                                self.best_score.best_epoch)
        self.config.logger.info("\tdev(%):")
        self.config.logger.info("\t\tprecision, %.5f" %
                                self.best_score.best_dev_p_score)
        self.config.logger.info("\t\trecall, %.5f" %
                                self.best_score.best_dev_r_score)
        self.config.logger.info("\t\tf1, %.5f" %
                                self.best_score.best_dev_f1_score)
        self.config.logger.info("\ttest(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.p)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.r)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.f)

        self.config.logger.info("*" * 25)

    def eval(self, model, epoch, config):
        """
        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return:
        """

        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        self.config.logger.info("Dev Time: {:.3f}".format(eval_end_time -
                                                          eval_start_time))
        self.config.logger.info('=-' * 10)

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        self.config.logger.info("Test Time: {:.3f}".format(eval_end_time -
                                                           eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config, config.save_model_dir,
                           config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config, config.save_model_dir,
                            config.model_name, self.best_score)
        else:
            self.config.logger.info("")  # logger.info() requires a message arg

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """
        :param data_iter:  eval batch data iterator
        :param model: eval model
        :param eval_instance:
        :param best_score:
        :param epoch:
        :param config: config
        :param test:  whether to test
        :return: None
        """
        test_flag = "Test"
        if test is False:
            test_flag = "Dev"

        model.eval()
        gold_labels = []
        predict_labels = []
        all_sentence_length = []
        for batch_features in tqdm.tqdm(data_iter):
            elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd, x_prd_posi, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(elmo_char_seqs,
                          elmo_word_seqs,
                          word,
                          lang,
                          pos,
                          prd,
                          x_prd_posi,
                          mask,
                          sentence_length,
                          train=False)
            all_sentence_length.extend(sentence_length)

            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    predict_label = []
                    for id_word in range(inst.words_size):
                        predict_label.append(
                            config.argvocab.i2c[int(label_ids[id_word])])
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                path_score, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
                    label = []
                    for i in label_ids:
                        label.append(config.argvocab.i2c[int(i)])
                    predict_labels.append(label)

        p, r, f, acc_ = eval_instance.getFscore(predict_labels, gold_labels,
                                                all_sentence_length)

        if test is False:

            best_score.current_dev_score = f
            if f >= best_score.best_dev_f1_score:
                best_score.best_dev_f1_score = f
                best_score.best_dev_p_score = p
                best_score.best_dev_r_score = r
                best_score.best_epoch = epoch
                best_score.best_test = True

        if test is True and best_score.best_test is True:  # test
            best_score.p = p
            best_score.r = r
            best_score.f = f

        self.config.logger.info(
            "{} at current epoch, p: {:.4f}%  r: {:.4f}% , f: {:.4f}%,  ACC: {:.3f}%"
            .format(test_flag, p, r, f, acc_))

        if test is False:
            self.config.logger.info(
                "Till now, The Best Dev Result: p: {:.4f}%  r: {:.4f}% , f: {:.4f}%, Locate on {} Epoch."
                .format(best_score.best_dev_p_score,
                        best_score.best_dev_r_score,
                        best_score.best_dev_f1_score, best_score.best_epoch))
        elif test is True:
            self.config.logger.info(
                "Till now, The Best Test Result: p: {:.4f}%  r: {:.4f}% , f: {:.4f}%, Locate on {} Epoch."
                .format(best_score.p, best_score.r, best_score.f,
                        best_score.best_epoch))
            best_score.best_test = False

    def eval_external_batch(self, data_iter, config, meta_info=''):
        """
        :param data_iter: eval batch data iterator
        :param config: config
        :param meta_info: description of the evaluated data, used in the log line
        :return: None
        """
        eval_instance = Eval()  # avoid shadowing the built-in eval

        self.model.eval()

        gold_labels = []
        predict_labels = []
        all_sentence_length = []
        for batch_features in tqdm.tqdm(data_iter):
            elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd, x_prd_posi, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = self.model(elmo_char_seqs,
                               elmo_word_seqs,
                               word,
                               lang,
                               pos,
                               prd,
                               x_prd_posi,
                               mask,
                               sentence_length,
                               train=False)
            all_sentence_length.extend(sentence_length)

            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    predict_label = []
                    for id_word in range(inst.words_size):
                        predict_label.append(
                            config.argvocab.i2c[int(label_ids[id_word])])
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                path_score, best_paths = self.model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
                    label = []
                    for i in label_ids:
                        label.append(config.argvocab.i2c[int(i)])
                    predict_labels.append(label)

        p, r, f, acc_ = eval_instance.getFscore(predict_labels, gold_labels,
                                                all_sentence_length)

        self.config.logger.info(
            "eval on {}, p: {:.4f}%  r: {:.4f}% , f: {:.4f}%, ACC: {:.4f}%".
            format(meta_info, p, r, f, acc_))

    @staticmethod
    def getAcc(eval_train, batch_features, logit, config):
        """
        :param eval_train:  eval instance
        :param batch_features:  batch data feature
        :param logit:  model output
        :param config:  config
        :return:
        """
        eval_train.clear_PRF()
        predict_ids = torch_max(logit)

        predict_labels = []
        gold_labels = []
        batch_length = []

        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            predict_label = []
            gold_label = inst.labels
            for id_word in range(inst.words_size):
                predict_label.append(config.argvocab.i2c[int(label_ids[id_word])])

            predict_labels.append(predict_label)
            gold_labels.append(gold_label)
            batch_length.append(inst.words_size)

            assert len(predict_label) == len(gold_label)

        p, r, f, acc_ = eval_train.getFscore(predict_labels, gold_labels,
                                             batch_length)
        return p, r, f, acc_

    def getAccCRF(self, eval_train, batch_features, logit, mask, config):
        """CRF variant of getAcc: decode the best paths with the CRF layer,
        then score p/r/f and accuracy on this batch.
        """
        eval_train.clear_PRF()

        predict_labels = []
        gold_labels = []
        batch_length = []

        path_score, best_paths = self.model.crf_layer(logit, mask)
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            gold_labels.append(inst.labels)
            label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
            label = []
            for i in label_ids:
                label.append(config.argvocab.i2c[int(i)])
            predict_labels.append(label)
            batch_length.append(inst.words_size)

            assert len(label) == len(inst.labels)

        p, r, f, acc_ = eval_train.getFscore(predict_labels, gold_labels,
                                             batch_length)
        return p, r, f, acc_
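
# --- Illustrative sketch (an addition, not part of the listing above): the
# same accumulate / clip / step pattern that `_optimizer_batch_step` and
# `_clip_model_norm` implement, written as a self-contained PyTorch loop.
# All names and hyper-parameters here are hypothetical.
def _demo_accumulation_loop(model, batches, optimizer, loss_fn,
                            backward_batch_size=4, max_norm=5.0):
    """batches: a list of (inputs, targets) pairs."""
    import torch.nn.utils as nn_utils
    optimizer.zero_grad()
    for count, (inputs, targets) in enumerate(batches, start=1):
        loss = loss_fn(model(inputs), targets)
        loss.backward()  # gradients accumulate across micro-batches
        if count % backward_batch_size == 0 or count == len(batches):
            nn_utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()
            optimizer.zero_grad()
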
class Train(object):
    """
        Train
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        # for k, v in kwargs.items():
        #     self.__setattr__(k, v)
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.device = self.config.device
        self.cuda = False
        if self.device != cpu_device:
            self.cuda = True
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        if self.config.learning_algorithm == "SGD":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")
        else:
            self.loss_function = nn.CrossEntropyLoss(reduction="mean")
            # self.loss_function = nn.MultiLabelSoftMarginLoss(size_average=True)
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

        # define accu eval
        self.accu_train_eval_micro, self.accu_dev_eval_micro, self.accu_test_eval_micro = \
            Eval(), Eval(), Eval()
        self.accu_train_eval_macro, self.accu_dev_eval_macro, self.accu_test_eval_macro = [], [], []
        for i in range(self.config.accu_class_num):
            self.accu_train_eval_macro.append(Eval())
            self.accu_dev_eval_macro.append(Eval())
            self.accu_test_eval_macro.append(Eval())

        # define law eval
        self.law_train_eval_micro, self.law_dev_eval_micro, self.law_test_eval_micro = \
            Eval(), Eval(), Eval()
        self.law_train_eval_macro, self.law_dev_eval_macro, self.law_test_eval_macro = [], [], []
        for i in range(self.config.law_class_num):
            self.law_train_eval_macro.append(Eval())
            self.law_dev_eval_macro.append(Eval())
            self.law_test_eval_macro.append(Eval())

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use:  whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            if gclip is not None:  # the docstring allows None: skip clipping then
                utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        :param config:  config
        :param epoch:  epoch
        :param new_lr:  learning rate
        :return:
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            # print("epoch", epoch)
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """
        Args:
            epoch: int, epoch
            init_lr: initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        # print('learning rate: {0}'.format(lr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        :return:
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        :param epoch:
        :return:
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print(
                    "Early stopping. Best score was at epoch {}.".format(
                        self.best_score.best_epoch))
                exit()

    def train(self):
        """
        :return:
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            # self.optimizer = self._decay_learning_rate(epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                # self.optimizer.zero_grad()
                accu, law, e_time, d_time = self.model(batch_features)
                accu_logit = accu.view(
                    accu.size(0) * accu.size(1), accu.size(2))
                law_logit = law.view(law.size(0) * law.size(1), law.size(2))
                # print(accu_logit.size())
                # accu_logit = torch_max_one(accu_logit)
                # law_logit = torch_max_one(law_logit)
                # print(batch_features.accu_label_features.size())
                loss_accu = self.loss_function(
                    accu_logit, batch_features.accu_label_features)
                loss_law = self.loss_function(
                    law_logit, batch_features.law_label_features)
                # total_loss = (loss_accu + loss_law)
                total_loss = (loss_accu + loss_law) / 2
                # loss.backward()
                total_loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                # self.optimizer.step()
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.accu_train_eval_micro.clear_PRF()
                    for i in range(self.config.accu_class_num):
                        self.accu_train_eval_macro[i].clear_PRF()
                    F1_measure(accu,
                               batch_features.accu_label_features,
                               self.accu_train_eval_micro,
                               self.accu_train_eval_macro,
                               cuda=self.cuda)
                    (accu_p_avg, accu_r_avg,
                     accu_f_avg), (p_micro, r_micro,
                                   f1_micro), (p_macro_avg, r_macro_avg,
                                               f1_macro_avg) = getFscore_Avg(
                                                   self.accu_train_eval_micro,
                                                   self.accu_train_eval_macro,
                                                   accu.size(1))
                    sys.stdout.write(
                        "\nbatch_count = [{}/{}] , total_loss is {:.6f}, [accu-Micro-F1 is {:.6f}%]"
                        .format(batch_count + 1, self.train_iter_len,
                                total_loss.item(), f1_micro))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        :param model: nn model
        :param epoch:  epoch
        :param config:  config
        :return:
        """
        self.accu_dev_eval_micro.clear_PRF()
        for i in range(self.config.accu_class_num):
            self.accu_dev_eval_macro[i].clear_PRF()
        self.law_dev_eval_micro.clear_PRF()
        for i in range(self.config.law_class_num):
            self.law_dev_eval_macro[i].clear_PRF()
        eval_start_time = time.time()
        self._eval_batch(self.dev_iter,
                         model,
                         self.accu_dev_eval_micro,
                         self.accu_dev_eval_macro,
                         self.law_dev_eval_micro,
                         self.law_dev_eval_macro,
                         self.best_score,
                         epoch,
                         config,
                         test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        self.accu_test_eval_micro.clear_PRF()
        for i in range(self.config.accu_class_num):
            self.accu_test_eval_macro[i].clear_PRF()
        self.law_test_eval_micro.clear_PRF()
        for i in range(self.config.law_class_num):
            self.law_test_eval_macro[i].clear_PRF()
        eval_start_time = time.time()
        self._eval_batch(self.test_iter,
                         model,
                         self.accu_test_eval_micro,
                         self.accu_test_eval_macro,
                         self.law_test_eval_micro,
                         self.law_test_eval_macro,
                         self.best_score,
                         epoch,
                         config,
                         test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        :param model:  nn model
        :param config:  config
        :param epoch:  epoch
        :return:
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def _eval_batch(self,
                    data_iter,
                    model,
                    accu_eval_micro,
                    accu_eval_macro,
                    law_eval_micro,
                    law_eval_macro,
                    best_score,
                    epoch,
                    config,
                    test=False):
        """
        :param data_iter:
        :param model:
        :param accu_eval_micro:
        :param accu_eval_macro:
        :param best_score:
        :param epoch:
        :param config:
        :param test:
        :return:
        """
        model.eval()

        for batch_count, batch_features in enumerate(data_iter):
            accu, law, e_time, d_time = model(batch_features)
            F1_measure(accu,
                       batch_features.accu_label_features,
                       accu_eval_micro,
                       accu_eval_macro,
                       cuda=self.cuda)
            F1_measure(law,
                       batch_features.law_label_features,
                       law_eval_micro,
                       law_eval_macro,
                       cuda=self.cuda)

        # get f-score
        accu_macro_micro_avg, accu_micro, accu_macro = getFscore_Avg(
            accu_eval_micro, accu_eval_macro, accu.size(1))
        law_macro_micro_avg, law_micro, law_macro = getFscore_Avg(
            law_eval_micro, law_eval_macro, law.size(1))

        accu_p, accu_r, accu_f = accu_macro_micro_avg
        accu_p_ma, accu_r_ma, accu_f_ma = accu_macro
        accu_p_mi, accu_r_mi, accu_f_mi = accu_micro
        law_p, law_r, law_f = law_macro_micro_avg
        law_p_ma, law_r_ma, law_f_ma = law_macro
        law_p_mi, law_r_mi, law_f_mi = law_micro

        p, r, f = accu_p, accu_r, accu_f
        # p, r, f = law_p, law_r, law_f

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        print("{}:".format(test_flag))
        print("Macro_Micro_Avg ===>>> ")
        print(
            "Eval: accu    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p, accu_r, accu_f))
        print(
            "Eval:  law    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p, law_r, law_f))
        print("Macro ===>>> ")
        print(
            "Eval: accu    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p_ma, accu_r_ma, accu_f_ma))
        print(
            "Eval:  law    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p_ma, law_r_ma, law_f_ma))
        print("Micro ===>>> ")
        print(
            "Eval: accu    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p_mi, accu_r_mi, accu_f_mi))
        print(
            "Eval:  law    --- Precision = {:.6f}%  Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p_mi, law_r_mi, law_f_mi))

        if test is True:
            print(
                "The Current Best accu Dev F-score: {:.6f}, Locate on {} Epoch."
                .format(best_score.best_dev_score, best_score.best_epoch))
            # print("The Current Best Law Dev F-score: {:.6f}, Locate on {} Epoch.".format(best_score.best_dev_score, best_score.best_epoch))
        if test is True:
            best_score.best_test = False
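
# --- Illustrative sketch (an addition, not part of the listing above): the
# difference between the micro- and macro-averaged F1 scores this Train class
# logs. Micro-F1 pools TP/FP/FN over all classes before computing F1, while
# macro-F1 averages the per-class F1 values. All counts here are hypothetical.
def _demo_micro_macro_f1(per_class_counts):
    """per_class_counts: list of (tp, fp, fn) triples, one per class."""
    def f1(tp, fp, fn):
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        return 2 * p * r / (p + r) if p + r else 0.0
    tp, fp, fn = (sum(col) for col in zip(*per_class_counts))
    micro = f1(tp, fp, fn)
    macro = sum(f1(*counts) for counts in per_class_counts) / len(per_class_counts)
    return micro, macro
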
Example #7
class Train(object):
    def __init__(self, **kwargs):

        self.config = kwargs["config"]
        self.config.logger.info("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.use_crf = self.config.use_crf
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.label_paddingId,
            use_crf=self.use_crf)
        self.config.logger.info(self.optimizer)
        self.config.logger.info(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):

        if use_crf:
            loss_function = self.model.crf_layer.neg_log_likelihood_loss
            return loss_function
        elif learning_algorithm == "SGD":
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="sum")
            return loss_function
        else:
            loss_function = nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                                reduction="mean")
            return loss_function

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):

        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            if gclip is not None:  # clip only when a float max-norm is configured
                utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):

        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch -
                1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):

        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):

        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch, config):

        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            self.config.logger.info("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience,
                self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                self.end_of_epoch = epoch
                self.config.logger.info(
                    "\n\nEarly stopping. Best score was at epoch {}.".format(
                        self.best_score.best_epoch))
                self.save_training_summary()
                exit()

    @staticmethod
    def _get_model_args(batch_features):

        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):

        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        else:
            loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            batch_size = feats.size(0)
            loss_value /= float(batch_size)
        return loss_value

    def train(self):

        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate
        self.config.logger.info('\n\n')
        self.config.logger.info('=-' * 50)
        self.config.logger.info('batch number: %d' % len(self.train_iter))

        for epoch in range(1, epochs + 1):
            self.config.logger.info("\n\n### Epoch: {}/{} ###".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            self.config.logger.info("current lr: {}".format(
                self.optimizer.param_groups[0].get("lr")))
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            self.config.logger.info('=-' * 10)
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                word, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(word, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit,
                                self.config)
                    self.config.logger.info(
                        "batch_count:{} , loss: {:.4f}, [TAG-ACC: {:.4f}%]".
                        format(batch_count + 1, loss.item(),
                               self.train_eval.acc()))
            end_time = time.time()
            self.config.logger.info("Train Time {:.3f}".format(end_time -
                                                               start_time))
            self.config.logger.info('=-' * 10)
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 10)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 15)
        self.save_training_summary()

    def save_training_summary(self):
        self.config.logger.info(
            "Copy the last model ckps to {} as backup.".format(
                self.config.save_dir))
        shutil.copytree(
            self.config.save_model_dir, "/".join(
                [self.config.save_dir, self.config.save_model_dir + "_bak"]))

        self.config.logger.info(
            "save the training summary at end of the log file.")
        self.config.logger.info("\n")
        self.config.logger.info("*" * 25)

        par_path = os.path.dirname(self.config.train_file)
        self.config.logger.info("dataset:\n\t %s" % par_path)
        self.config.logger.info("\ttrain set count: %d" %
                                self.config.train_cnt)
        self.config.logger.info("\tdev set count: %d" % self.config.dev_cnt)
        self.config.logger.info("\ttest set count: %d" % self.config.test_cnt)

        self.config.logger.info("*" * 10)
        self.config.logger.info("model:")
        self.config.logger.info(self.model)

        self.config.logger.info("*" * 10)
        self.config.logger.info("training:")
        self.config.logger.info('\tbatch size: %d' % self.config.batch_size)
        self.config.logger.info('\tbatch count: %d' % len(self.train_iter))

        self.config.logger.info("*" * 10)
        self.config.logger.info("best performance:")
        self.config.logger.info("\tend at epoch: %d" % self.end_of_epoch)
        self.config.logger.info("\tbest at epoch: %d" %
                                self.best_score.best_epoch)
        self.config.logger.info("\tdev(%):")
        self.config.logger.info("\t\tprecision, %.5f" %
                                self.best_score.best_dev_p_score)
        self.config.logger.info("\t\trecall, %.5f" %
                                self.best_score.best_dev_r_score)
        self.config.logger.info("\t\tf1, %.5f" %
                                self.best_score.best_dev_f1_score)
        self.config.logger.info("\ttest(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.p)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.r)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.f)

        self.config.logger.info("*" * 25)

    def eval(self, model, epoch, config):

        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter,
                        model,
                        self.dev_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=False)
        eval_end_time = time.time()
        self.config.logger.info("Dev Time: {:.3f}".format(eval_end_time -
                                                          eval_start_time))
        self.config.logger.info('=-' * 10)

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter,
                        model,
                        self.test_eval,
                        self.best_score,
                        epoch,
                        config,
                        test=True)
        eval_end_time = time.time()
        self.config.logger.info("Test Time: {:.3f}".format(eval_end_time -
                                                           eval_start_time))

    def _model2file(self, model, config, epoch):

        if config.save_model and config.save_all_model:
            save_model_all(model, config, config.save_model_dir,
                           config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config, config.save_model_dir,
                            config.model_name, self.best_score)
        else:
            self.config.logger.info("")  # logger.info() requires a message arg

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):

        test_flag = "Test"
        if test is False:  # dev
            test_flag = "Dev"

        model.eval()  # set flag for pytorch
        eval_PRF = EvalPRF()
        gold_labels = []
        predict_labels = []
        for batch_features in data_iter:
            word, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(word, sentence_length, train=False)

            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    predict_label = []
                    for id_word in range(inst.words_size):
                        predict_label.append(
                            config.create_alphabet.label_alphabet.from_id(
                                label_ids[id_word]))
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                path_score, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
                    label = []
                    for i in label_ids:
                        # self.config.logger.info("\n", i)
                        label.append(
                            config.create_alphabet.label_alphabet.from_id(
                                int(i)))
                    predict_labels.append(label)

        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label,
                             gold_labels=g_label,
                             eval=eval_instance)

        cor = 0
        total_len = sum(
            [len(predict_label) for predict_label in predict_labels])
        for p_label, g_label in zip(predict_labels, gold_labels):
            for p_label_, g_label_ in zip(p_label, g_label):
                if p_label_ == g_label_:
                    cor += 1
        acc_ = cor / total_len * 100

        p, r, f = eval_instance.getFscore()

        if test is False:  # dev
            best_score.current_dev_score = f
            if f >= best_score.best_dev_f1_score:
                best_score.best_dev_f1_score = f
                best_score.best_dev_p_score = p
                best_score.best_dev_r_score = r
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:  # test
            best_score.p = p
            best_score.r = r
            best_score.f = f
        self.config.logger.info(
            "{} at current epoch, precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%,  [TAG-ACC: {:.3f}%]"
            .format(test_flag, p, r, f, acc_))
        if test is False:
            self.config.logger.info(
                "Till now, The Best Dev Result: precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.best_dev_p_score,
                        best_score.best_dev_r_score,
                        best_score.best_dev_f1_score, best_score.best_epoch))
        elif test is True:
            self.config.logger.info(
                "Till now, The Best Test Result: precision: {:.4f}%  recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.p, best_score.r, best_score.f,
                        best_score.best_epoch))
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):

        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            predict_label = []
            gold_label = inst.labels
            for id_word in range(inst.words_size):
                predict_label.append(
                    config.create_alphabet.label_alphabet.from_id(
                        label_ids[id_word]))
            assert len(predict_label) == len(gold_label)
            cor = 0
            for p_label, g_label in zip(predict_label, gold_label):
                if p_label == g_label:
                    cor += 1
            eval_acc.correct_num += cor
            eval_acc.gold_num += len(gold_label)
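
# --- Illustrative sketch (an addition, not part of the listing above): the
# token-level accuracy that `getAcc` accumulates, reduced to plain Python
# lists. The labels below are hypothetical.
def _demo_tag_accuracy(predict_labels, gold_labels):
    correct = sum(p == g
                  for p_seq, g_seq in zip(predict_labels, gold_labels)
                  for p, g in zip(p_seq, g_seq))
    total = sum(len(g_seq) for g_seq in gold_labels)
    return correct / total * 100 if total else 0.0

# e.g. _demo_tag_accuracy([["B-ARG0", "O"]], [["B-ARG0", "B-V"]]) returns 50.0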