class Train(object):
    """Train a sentence-classification model and evaluate it on dev/test data after every epoch."""

    def __init__(self, **kwargs):
        """
        :param kwargs:
            Args of data:
                train_iter : train batch data iterator
                dev_iter : dev batch data iterator
                test_iter : test batch data iterator
            Args of train:
                model : nn model
                config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

    @staticmethod
    def _loss(learning_algorithm):
        """Build the cross-entropy criterion.

        :param learning_algorithm: optimizer name from the config; "SGD" sums the
            per-example losses, any other value averages them.
        :return: nn.CrossEntropyLoss instance
        """
        reduction = "sum" if learning_algorithm == "SGD" else "mean"
        return nn.CrossEntropyLoss(reduction=reduction)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use: whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            # Clipping was requested, so the configured max norm must be a real number.
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Multiply the learning rate by ``lr_rate_decay`` every ``max_patience`` epochs.

        The rate never goes below ``min_lrate``.

        :param config: config
        :param epoch: epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, config, epoch, init_lr):
        """Apply inverse-time decay ``init_lr / (1 + lr_rate_decay * epoch)`` to all param groups.

        Args:
            epoch: int, epoch
            init_lr: initial lr
        """
        if config.use_lr_decay:
            lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step the optimizer every ``backward_batch_size`` batches and on the last batch.

        Gradients accumulate across the intermediate backward passes.
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """Stop once dev has not improved for ``early_max_patience`` consecutive epochs.

        :param epoch: current epoch (1-based)
        """
        if epoch > self.best_score.best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print("Early Stop Train. Best Score Locate on {} Epoch.".format(
                    self.best_score.best_epoch))
                sys.exit()
        else:
            # Dev improved this epoch: restart the count of consecutive stagnant epochs.
            self.best_score.early_current_patience = 0

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features: Batch Instance
        :return: (inst, word, mask, sentence_length, labels, batch_size)
        """
        inst = batch_features.inst
        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        labels = batch_features.label_features
        batch_size = batch_features.batch_length
        return inst, word, mask, sentence_length, labels, batch_size

    def _calculate_loss(self, feats, labels):
        """
        Args:
            feats: model output logits
            labels: gold label ids
        """
        return self.loss_function(feats, labels)

    def train(self):
        """Run the training loop over all epochs.

        Each epoch covers LR scheduling, gradient-accumulated updates and periodic logging.
        It then evaluates on dev/test, saves a checkpoint and checks early stopping.
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config, epoch=epoch, new_lr=new_lr)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")), end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                inst, word, mask, sentence_length, labels, batch_size = self._get_model_args(
                    batch_features)
                logit = self.model(word, sentence_length, train=True)
                loss = self._calculate_loss(logit, labels)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config, backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    accuracy = self.getAcc(logit, labels, batch_size)
                    sys.stdout.write(
                        "\nbatch_count = [{}] , loss is {:.6f}, [accuracy is {:.6f}%]"
                        .format(batch_count + 1, loss.item(), accuracy))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """Evaluate ``model`` on the dev set, then on the test set, printing timings.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        """
        eval_start_time = time.time()
        print('\nmistakes for dev_iter')
        self.eval_batch(self.dev_iter, model, self.best_score, epoch, config, test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        print('mistakes for test_iter')
        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.best_score, epoch, config, test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """Save every epoch's model or only the best one, depending on the config flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path, config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self, data_iter, model, best_score, epoch, config, test=False):
        """Compute accuracy and average loss over ``data_iter`` and update ``best_score``.

        :param data_iter: eval batch data iterator
        :param model: eval model
        :param best_score: Best_Result tracker, updated in place
        :param epoch: current epoch
        :param config: config
        :param test: whether this is the test set (False means dev)
        :return: None
        :raises ValueError: if ``data_iter`` yields no examples
        """
        model.eval()
        corrects = 0
        size = 0
        loss = 0.0
        Truelabel = []
        d = []
        # No gradients are needed here; without no_grad the summed loss would keep
        # every batch's autograd graph alive until the end of evaluation.
        with torch.no_grad():
            for batch_features in data_iter:
                inst, word, mask, sentence_length, labels, batch_size = self._get_model_args(
                    batch_features)
                # NOTE: must be changed when a pinyin model is added
                logit = model(word, sentence_length, train=False)
                loss += self._calculate_loss(logit, labels).item()
                size += batch_features.batch_length
                # torch.max(logit, 1)[1] is the index of the largest value in each row of logit
                predicted = torch.max(logit, 1)[1].view(labels.size())
                d.extend(predicted.cpu().numpy().tolist())
                Truelabel.extend(k.label_index for k in inst)
                corrects += (predicted == labels).sum().item()
        if size == 0:
            raise ValueError("Error: data_iter yielded no examples to evaluate")
        print("更加详细的评估指标:\n", classification_report(Truelabel, d, digits=5))
        accuracy = float(corrects) / size * 100.0
        average_loss = float(loss) / size
        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = accuracy
            if accuracy >= best_score.best_dev_score:
                best_score.best_dev_score = accuracy
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = accuracy
        print("{} eval: average_loss = {:.6f}, accuracy = {:.6f}%".format(
            test_flag, average_loss, accuracy))
        if test is True:
            print("The Current Best Dev Accuracy: {:.6f}, Locate on {} Epoch.".format(
                best_score.best_dev_score, best_score.best_epoch))
            print("The Current Best Test Accuracy: accuracy = {:.6f}%".format(best_score.p))
            best_score.best_test = False

    @staticmethod
    def getAcc(logit, target, batch_size):
        """Return the percentage of rows whose argmax matches ``target``.

        :param logit: model predict
        :param target: gold label
        :param batch_size: batch size
        :return: accuracy in percent
        """
        corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
        accuracy = float(corrects) / batch_size * 100.0
        return accuracy
class Train(object):
    """Train a sequence-labelling model (optionally with a CRF layer).

    Precision/recall/F-score on dev and test are reported after every epoch.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs:
            Args of data:
                train_iter : train batch data iterator
                dev_iter : dev batch data iterator
                test_iter : test batch data iterator
            Args of train:
                model : nn model
                config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.use_crf = self.config.use_crf
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(learning_algorithm=self.config.learning_algorithm,
                                        label_paddingId=self.config.label_paddingId,
                                        use_crf=self.use_crf)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        """Pick the loss.

        With ``use_crf`` this is the CRF layer's negative log-likelihood.
        Otherwise it is cross-entropy that ignores padding labels.
        For "SGD" the cross-entropy is summed; otherwise it is averaged.
        """
        if use_crf:
            return self.model.crf_layer.neg_log_likelihood_loss
        # `size_average` is deprecated in PyTorch; `reduction` is the equivalent spelling.
        reduction = "sum" if learning_algorithm == "SGD" else "mean"
        return nn.CrossEntropyLoss(ignore_index=label_paddingId, reduction=reduction)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use: whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            # Clipping was requested, so the configured max norm must be a real number.
            assert isinstance(gclip, float)
            # clip_grad_norm (no underscore) is deprecated; clip_grad_norm_ is the in-place API.
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Multiply the learning rate by ``lr_rate_decay`` every ``max_patience`` epochs.

        The rate never goes below ``min_lrate``.

        :param config: config
        :param epoch: epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Apply inverse-time decay ``init_lr / (1 + lr_rate_decay * epoch)`` to all param groups.

        Args:
            epoch: int, epoch
            init_lr: initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step the optimizer every ``backward_batch_size`` batches and on the last batch."""
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """Stop once dev has not improved for ``early_max_patience`` consecutive epochs.

        :param epoch: current epoch (1-based)
        """
        if epoch > self.best_score.best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print("Early Stop Train. Best Score Locate on {} Epoch.".format(
                    self.best_score.best_epoch))
                sys.exit()
        else:
            # Dev improved this epoch: restart the count of consecutive stagnant epochs.
            self.best_score.early_current_patience = 0

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features: Batch Instance
        :return: (word, char, mask, sentence_length, tags)
        """
        word = batch_features.word_features
        char = batch_features.char_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, char, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):
        """
        Args:
            feats: size = (batch_size, seq_len, tag_size)
            mask: size = (batch_size, seq_len)
            tags: size = (batch_size, seq_len)
        """
        if not self.use_crf:
            # Flatten to (batch * seq_len, tag_size) vs (batch * seq_len,) for cross-entropy.
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            loss_value /= float(feats.size(0))
        return loss_value

    def train(self):
        """Run the training loop over all epochs.

        Each epoch covers LR decay, gradient-accumulated updates and periodic logging.
        It then evaluates on dev/test, saves a checkpoint and checks early stopping.
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")), end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                word, char, mask, sentence_length, tags = self._get_model_args(batch_features)
                logit = self.model(word, char, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config, backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit, self.config)
                    # loss is a 0-dim tensor; indexing it (loss.data[0]) fails on modern PyTorch.
                    sys.stdout.write(
                        "\nbatch_count = [{}] , loss is {:.6f}, [TAG-ACC is {:.6f}%]"
                        .format(batch_count + 1, loss.item(), self.train_eval.acc()))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """Evaluate ``model`` on the dev set, then on the test set, printing timings.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        """
        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter, model, self.dev_eval, self.best_score, epoch, config, test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.test_eval, self.best_score, epoch, config, test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """Save every epoch's model or only the best one, depending on the config flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path, config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self, data_iter, model, eval_instance, best_score, epoch, config, test=False):
        """Decode ``data_iter``, score it into ``eval_instance`` and update ``best_score``.

        :param data_iter: eval batch data iterator
        :param model: eval model
        :param eval_instance: Eval accumulator that receives the P/R/F counts
        :param best_score: Best_Result tracker, updated in place
        :param epoch: current epoch
        :param config: config
        :param test: whether this is the test set (False means dev)
        :return: None
        """
        model.eval()
        label_alphabet = config.create_alphabet.label_alphabet
        eval_PRF = EvalPRF()
        gold_labels = []
        predict_labels = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch_features in data_iter:
                word, char, mask, sentence_length, tags = self._get_model_args(batch_features)
                logit = model(word, char, sentence_length, train=False)
                if not self.use_crf:
                    predict_ids = torch_max(logit)
                    for id_batch in range(batch_features.batch_length):
                        inst = batch_features.inst[id_batch]
                        label_ids = predict_ids[id_batch]
                        predict_labels.append([label_alphabet.from_id(label_ids[id_word])
                                               for id_word in range(inst.words_size)])
                        gold_labels.append(inst.labels)
                else:
                    path_score, best_paths = model.crf_layer(logit, mask)
                    for id_batch in range(batch_features.batch_length):
                        inst = batch_features.inst[id_batch]
                        gold_labels.append(inst.labels)
                        # Drop the padded positions beyond the real sentence length.
                        label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
                        predict_labels.append([label_alphabet.from_id(i) for i in label_ids])

        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label, gold_labels=g_label, eval=eval_instance)
        p, r, f = eval_instance.getFscore()

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        print("{} eval: precision = {:.6f}% recall = {:.6f}% , f-score = {:.6f}%, [TAG-ACC = {:.6f}%]"
              .format(test_flag, p, r, f, 0.0000))
        if test is True:
            print("The Current Best Dev F-score: {:.6f}, Locate on {} Epoch.".format(
                best_score.best_dev_score, best_score.best_epoch))
            print("The Current Best Test Result: precision = {:.6f}% recall = {:.6f}% , f-score = {:.6f}%"
                  .format(best_score.p, best_score.r, best_score.f))
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):
        """Reset ``eval_acc`` and fill it with token-level tag accuracy counts for one batch.

        :param eval_acc: eval instance
        :param batch_features: batch data feature
        :param logit: model output
        :param config: config
        :return: None
        """
        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        label_alphabet = config.create_alphabet.label_alphabet
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            gold_label = inst.labels
            predict_label = [label_alphabet.from_id(label_ids[id_word])
                             for id_word in range(inst.words_size)]
            assert len(predict_label) == len(gold_label)
            eval_acc.correct_num += sum(1 for p_label, g_label in zip(predict_label, gold_label)
                                        if p_label == g_label)
            eval_acc.gold_num += len(gold_label)
class Train(object):
    """Train a dependency parser (``self.parser``) and report UAS/LAS on dev/test after every epoch."""

    def __init__(self, **kwargs):
        """
        :param kwargs:
            Args of data:
                train_iter : train batch data iterator
                dev_iter : dev batch data iterator
                test_iter : test batch data iterator
            Args of train:
                model : parser wrapper (its ``.model`` attribute is the nn model)
                config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.parser = kwargs["model"]
        self.config = kwargs["config"]
        self.device = self.config.device
        self.cuda = bool(self.device != cpu_device)
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(
            name=self.config.learning_algorithm,
            model=self.parser.model,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            # Clipping is applied explicitly in _optimizer_batch_step; "None" presumably
            # disables the Optimizer wrapper's own clipping (TODO confirm).
            grad_clip="None",
            betas=(0.9, 0.9),
            eps=1.0e-12)
        if self.config.learning_algorithm == "SGD":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")
        else:
            self.loss_function = nn.CrossEntropyLoss(reduction="mean")
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use: whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            # Clipping was requested, so the configured max norm must be a real number.
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.parser.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Multiply the learning rate by ``lr_rate_decay`` every ``max_patience`` epochs.

        The rate never goes below ``min_lrate``.

        :param config: config
        :param epoch: epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Apply inverse-time decay ``init_lr / (1 + lr_rate_decay * epoch)`` to all param groups.

        Args:
            epoch: int, epoch
            init_lr: initial lr
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Clip, step and zero gradients every ``update_batch_size`` batches and on the last batch."""
        if backward_count % config.update_batch_size == 0 or backward_count == self.train_iter_len:
            self._clip_model_norm(self.config.clip_max_norm_use, self.config.clip_max_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """Stop once dev has not improved for ``early_max_patience`` consecutive epochs.

        :param epoch: current epoch (1-based)
        """
        if epoch > self.best_score.best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print("Early Stop Train. Best Score Locate on {} Epoch.".format(
                    self.best_score.best_epoch))
                sys.exit()
        else:
            # Dev improved this epoch: restart the count of consecutive stagnant epochs.
            self.best_score.early_current_patience = 0

    def _model2file(self, model, config, epoch):
        """Save every epoch's model or only the best one, depending on the config flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path, config.model_name, self.best_score)
        else:
            print()

    def train(self):
        """Run the training loop over all epochs.

        Each epoch covers LR scheduling, gradient-accumulated updates and running UAS/LAS logging.
        It then evaluates on dev/test, saves a checkpoint and checks early stopping.
        """
        epochs = self.config.epochs
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config, epoch=epoch, new_lr=new_lr)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")), end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.parser.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            overall_arc_correct, overall_label_correct, overall_total_arcs = 0, 0, 0
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                words, ext_words, tags, masks = (batch_features.words, batch_features.ext_words,
                                                 batch_features.tags, batch_features.masks)
                heads, rels, lengths = batch_features.heads, batch_features.rels, batch_features.lengths
                sumLength = sum(lengths)
                self.parser.forward(words, ext_words, tags, masks)
                loss = self.parser.compute_loss(heads, rels, lengths)
                # The optimizer only steps every update_batch_size batches, so scale each
                # batch's contribution to the accumulated gradient accordingly.
                loss = loss / self.config.update_batch_size
                loss_value = loss.data.cpu().numpy()
                loss.backward()
                self._optimizer_batch_step(config=self.config, backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    arc_correct, label_correct, total_arcs = self.parser.compute_accuracy(heads, rels)
                    overall_arc_correct += arc_correct
                    overall_label_correct += label_correct
                    overall_total_arcs += total_arcs
                    uas = overall_arc_correct.item() * 100.0 / overall_total_arcs
                    las = overall_label_correct.item() * 100.0 / overall_total_arcs
                    sys.stdout.write(
                        "\nbatch_count = [{}/{}] , loss is {:.6f}, length: {}, ARC: {:.6f}, REL: {:.6f}"
                        .format(batch_count + 1, self.train_iter_len, float(loss_value),
                                sumLength, float(uas), float(las)))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(parser=self.parser, epoch=epoch, config=self.config)
            self._model2file(model=self.parser.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, parser, epoch, config):
        """Evaluate ``parser`` on the dev set, then on the test set, printing timings.

        :param parser: parser wrapper
        :param epoch: current epoch
        :param config: config
        :return: None
        """
        eval_start_time = time.time()
        self._eval_batch(self.dev_iter, parser, self.best_score, epoch, config, test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        eval_start_time = time.time()
        self._eval_batch(self.test_iter, parser, self.best_score, epoch, config, test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def get_one_batch(self, insts):
        """Return the ``sentence`` attribute of every instance in ``insts``.

        :param insts: batch instances
        :return: list of sentences, in input order
        """
        return [inst.sentence for inst in insts]

    def _eval_batch(self, data_iter, parser, best_score, epoch, config, test=False):
        """Parse ``data_iter``, print UAS/LAS and update ``best_score``.

        UAS is used as the model-selection score.

        :param data_iter: eval batch data iterator
        :param parser: parser wrapper
        :param best_score: Best_Result tracker, updated in place
        :param epoch: current epoch
        :param config: config (provides ``alphabet``)
        :param test: whether this is the test set (False means dev)
        :return: None
        """
        parser.model.eval()
        arc_total_test, arc_correct_test, rel_total_test, rel_correct_test = 0, 0, 0, 0
        alphabet = config.alphabet
        for batch_features in data_iter:
            one_batch = self.get_one_batch(batch_features.insts)
            words, ext_words, tags, masks = (batch_features.words, batch_features.ext_words,
                                             batch_features.tags, batch_features.masks)
            lengths = batch_features.lengths
            arcs_batch, rels_batch = parser.parse(words, ext_words, tags, lengths, masks)
            trees = batch_variable_depTree(one_batch, arcs_batch, rels_batch, lengths, alphabet)
            for count, tree in enumerate(trees):
                # Gold sentence first, predicted tree second.
                arc_total, arc_correct, rel_total, rel_correct = evalDepTree(one_batch[count], tree)
                arc_total_test += arc_total
                arc_correct_test += arc_correct
                rel_total_test += rel_total
                rel_correct_test += rel_correct

        # Guard against an empty iterator (or one with no scorable arcs).
        uas = arc_correct_test * 100.0 / arc_total_test if arc_total_test else 0.0
        las = rel_correct_test * 100.0 / rel_total_test if rel_total_test else 0.0
        f = uas
        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.f = f
        print("{}:".format(test_flag))
        print("UAS = %d/%d = %.2f, LAS = %d/%d =%.2f" % (arc_correct_test, arc_total_test, uas,
                                                          rel_correct_test, rel_total_test, las))
        if test is True:
            print("The Current Best Dev score: {:.6f}, Locate on {} Epoch.".format(
                best_score.best_dev_score, best_score.best_epoch))
            best_score.best_test = False
class Train(object):
    """Train a joint word-segmentation / POS-tagging model.

    Segmentation and POS P/R/F on dev and test are reported after every epoch.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs:
            Args of data:
                train_iter : train batch data iterator
                dev_iter : dev batch data iterator
                test_iter : test batch data iterator
            Args of train:
                model : nn model
                config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        # `size_average` is deprecated in PyTorch; `reduction` is the equivalent spelling.
        if self.config.learning_algorithm == "SGD":
            self.loss_function = nn.CrossEntropyLoss(reduction="sum")
        else:
            self.loss_function = nn.CrossEntropyLoss(reduction="mean")
        print(self.optimizer)
        self.best_score = Best_Result()
        (self.train_eval, self.dev_eval_seg, self.dev_eval_pos,
         self.test_eval_seg, self.test_eval_pos) = Eval(), Eval(), Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        :param clip_max_norm_use: whether to use clip max norm for nn model
        :param clip_max_norm: clip max norm max values [float or None]
        :return:
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            # Clipping was requested, so the configured max norm must be a real number.
            assert isinstance(gclip, float)
            # clip_grad_norm (no underscore) is deprecated; clip_grad_norm_ is the in-place API.
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Multiply the learning rate by ``lr_rate_decay`` every ``max_patience`` epochs.

        The rate never goes below ``min_lrate``.

        :param config: config
        :param epoch: epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Decay the learning rate as ``init_lr / (1 + lr_rate_decay * epoch)``.

        Args:
            epoch: int, number of iterations (epochs) completed
            init_lr: initial learning rate
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step the optimizer every ``backward_batch_size`` batches and on the last batch."""
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """Stop once dev has not improved for ``early_max_patience`` consecutive epochs.

        :param epoch: current epoch (1-based)
        """
        if epoch > self.best_score.best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print("Early Stop Train. Best Score Locate on {} Epoch.".format(
                    self.best_score.best_epoch))
                sys.exit()
        else:
            # Dev improved this epoch: restart the count of consecutive stagnant epochs.
            self.best_score.early_current_patience = 0

    def _model2file(self, model, config, epoch):
        """Save every epoch's model or only the best one, depending on the config flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path, config.model_name, self.best_score)
        else:
            print()

    def train(self):
        """Run the training loop over all epochs.

        Each epoch covers LR scheduling, gradient-accumulated updates and periodic logging.
        It then evaluates on dev/test, saves a checkpoint and checks early stopping.
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config, epoch=epoch, new_lr=new_lr)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")), end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                maxCharSize = batch_features.char_features.size()[1]
                decoder_out, state = self.model(batch_features, train=True)
                self.cal_train_acc(batch_features, self.train_eval, batch_count,
                                   decoder_out, maxCharSize, self.config)
                loss = torch.nn.functional.nll_loss(decoder_out, batch_features.gold_features)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config, backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    # loss is a 0-dim tensor; indexing it (loss.data[0]) fails on modern PyTorch.
                    sys.stdout.write(
                        "\nBatch_count = [{}/{}] , Loss is {:.6f} , (Correct/Total_num) = Accuracy ({}/{})"
                        " = {:.6f}%".format(batch_count + 1, self.train_iter_len, loss.item(),
                                            self.train_eval.correct_num, self.train_eval.gold_num,
                                            self.train_eval.acc() * 100))
            end_time = time.time()
            print("\nTrain Time {:.4f}".format(end_time - start_time))
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """Evaluate ``model`` on the dev set, then on the test set, printing timings.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        """
        self.dev_eval_pos.clear()
        self.dev_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter, model, self.dev_eval_seg, self.dev_eval_pos,
                        self.best_score, epoch, config, test=False)
        eval_end_time = time.time()
        print("Dev Time {:.4f}".format(eval_end_time - eval_start_time))

        self.test_eval_pos.clear()
        self.test_eval_seg.clear()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.test_eval_seg, self.test_eval_pos,
                        self.best_score, epoch, config, test=True)
        eval_end_time = time.time()
        print("Test Time {:.4f}".format(eval_end_time - eval_start_time))

    def eval_batch(self, data_iter, model, eval_seg, eval_pos, best_score, epoch, config, test=False):
        """Decode ``data_iter``, accumulate seg/POS P/R/F and update ``best_score``.

        The POS F-score is used for model selection.

        :param data_iter: eval data iterator
        :param model: nn model
        :param eval_seg: seg eval
        :param eval_pos: pos eval
        :param best_score: best score
        :param epoch: current epoch
        :param config: config
        :param test: test
        :return: None
        """
        model.eval()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch_features in data_iter:
                decoder_out, state = model(batch_features, train=False)
                for i in range(batch_features.batch_length):
                    self.jointPRF_Batch(batch_features.inst[i], state.words[i],
                                        state.pos_labels[i], eval_seg, eval_pos)

        seg_p, seg_r, seg_f = eval_seg.getFscore()
        pos_p, pos_r, pos_f = eval_pos.getFscore()

        test_flag = "Test"
        if test is False:
            test_flag = "Dev"
            best_score.current_dev_score = pos_f
            if pos_f >= best_score.best_dev_score:
                best_score.best_dev_score = pos_f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = pos_p
            best_score.r = pos_r
            best_score.f = pos_f
        print(test_flag + " ---->")
        print("seg: precision = {:.4f}% recall = {:.4f}% , f-score = {:.4f}%".format(
            seg_p, seg_r, seg_f))
        print("pos: precision = {:.4f}% recall = {:.4f}% , f-score = {:.4f}%".format(
            pos_p, pos_r, pos_f))
        if test is True:
            print("The Current Best Dev F-score: {:.4f}%, Locate on {} Epoch.".format(
                best_score.best_dev_score, best_score.best_epoch))
            best_score.best_test = False

    @staticmethod
    def jointPRF_Batch(inst, state_words, state_posLabel, seg_eval, pos_eval):
        """Add one sentence's segmentation and POS counts to ``seg_eval`` and ``pos_eval``.

        Each predicted word becomes a span string ``"[start,end]"`` over character offsets.
        The POS key appends the POS label to that span string.
        A prediction is correct when its string appears in ``inst.gold_seg`` / ``inst.gold_pos``.

        :param inst: gold instance with ``gold_seg`` and ``gold_pos`` collections
        :param state_words: predicted words
        :param state_posLabel: predicted POS label for each word
        :param seg_eval: segmentation counter (gold_num / predict_num / correct_num)
        :param pos_eval: POS counter (gold_num / predict_num / correct_num)
        :return: None
        """
        count = 0
        predict_seg = []
        predict_pos = []
        for idx, w in enumerate(state_words):
            span = '[' + str(count) + ',' + str(count + len(w)) + ']'
            predict_seg.append(span)
            predict_pos.append(span + state_posLabel[idx])
            count += len(w)

        # Set lookups avoid a linear scan of the gold list for every prediction.
        gold_seg = set(inst.gold_seg)
        gold_pos = set(inst.gold_pos)

        seg_eval.gold_num += len(inst.gold_seg)
        seg_eval.predict_num += len(predict_seg)
        seg_eval.correct_num += sum(1 for p in predict_seg if p in gold_seg)

        pos_eval.gold_num += len(inst.gold_pos)
        pos_eval.predict_num += len(predict_pos)
        pos_eval.correct_num += sum(1 for p in predict_pos if p in gold_pos)

    def cal_train_acc(self, batch_features, train_eval, batch_count, decoder_out, maxCharSize, args):
        """Reset ``train_eval`` and count how many per-character actions match the gold actions.

        :param batch_features: batch data
        :param train_eval: accuracy counter (correct_num / gold_num)
        :param batch_count: batch index (unused)
        :param decoder_out: decoder scores; each sentence occupies ``maxCharSize`` consecutive rows
        :param maxCharSize: padded number of characters per sentence
        :param args: config providing ``label_size``
        :return: None
        """
        train_eval.clear()
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            for id_char in range(inst.chars_size):
                actionID = self.getMaxindex(decoder_out[id_batch * maxCharSize + id_char], args)
                if actionID == inst.gold_index[id_char]:
                    train_eval.correct_num += 1
            train_eval.gold_num += inst.chars_size

    @staticmethod
    def getMaxindex(decode_out_acc, config):
        """Return the index of the first maximum among the first ``config.label_size`` scores.

        :param decode_out_acc: 1-D score tensor for one character
        :param config: config providing ``label_size``
        :return: int index of the best-scoring label
        """
        scores = decode_out_acc.data
        best_score = scores[0]
        best_index = 0
        for idx in range(1, config.label_size):
            if scores[idx] > best_score:
                best_score = scores[idx]
                best_index = idx
        return best_index
class Train(object):
    """Train and evaluate a sequence-labelling model (softmax or CRF decoding).

    Constructor kwargs:
        config: config object; must provide ``logger`` plus every
            hyper-parameter read in this class.
        train_iter: train batch iterator (must provide ``reset_flag4trainset``).
        dev_iter: dev batch iterator.
        test_iter: test batch iterator.
        model: nn model; when ``config.use_crf`` is set it must expose
            ``crf_layer``.
        target: stored on ``self.target``; not read elsewhere in this class.
    """

    def __init__(self, **kwargs):
        self.config = kwargs["config"]
        self.config.logger.info("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.use_crf = self.config.use_crf
        self.target = kwargs["target"]
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.arg_paddingId,
            use_crf=self.use_crf)
        self.config.logger.info(self.optimizer)
        self.config.logger.info(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        """Select the training loss.

        :param learning_algorithm: optimizer name; "SGD" selects a summed
            cross-entropy, anything else a mean-reduced one
        :param label_paddingId: label id ignored by the cross-entropy loss
        :param use_crf: if truthy, use the model CRF layer's
            ``neg_log_likelihood_loss`` instead of cross-entropy
        :return: a callable loss
        """
        if use_crf:
            return self.model.crf_layer.neg_log_likelihood_loss
        reduction = "sum" if learning_algorithm == "SGD" else "mean"
        return nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                   reduction=reduction)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """Clip the gradient norm of all model parameters.

        :param clip_max_norm_use: whether to clip at all (must be ``True``)
        :param clip_max_norm: max norm; the string "None" is rejected by the
            assertion when clipping is enabled
        :return: None
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Step-decay the learning rate every ``config.max_patience`` epochs.

        When ``config.use_lr_decay`` is True, ``epoch`` is past
        ``max_patience``, ``epoch - 1`` is a multiple of ``max_patience`` and
        ``new_lr`` is still above ``min_lrate``, the rate is multiplied by
        ``lr_rate_decay`` (floored at ``min_lrate``) and applied to the
        optimizer via ``set_lrate``.

        :param config: config
        :param epoch: current epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Apply inverse-time decay: ``lr = init_lr / (1 + lr_rate_decay * epoch)``.

        Args:
            epoch: int, epoch (the caller passes 0 for the first epoch)
            init_lr: initial learning rate
        Returns:
            self.optimizer, with every param group's ``lr`` overwritten
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step the optimizer once every ``config.backward_batch_size`` batches.

        Gradients accumulate across ``backward_batch_size`` backward passes.
        A trailing partial group at the end of an epoch is not stepped; its
        gradients are discarded by the ``zero_grad`` at the next epoch start.

        :param config: config
        :param backward_count: 1-based index of the current batch in the epoch
        :return: None
        """
        if backward_count % config.backward_batch_size == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch, config):
        """Count an epoch without dev improvement and decide whether to stop.

        :param epoch: current epoch
        :param config: unused; kept for interface compatibility
        :return: True when patience is exhausted (the training summary has
            been written), False otherwise
        """
        best_epoch = self.best_score.best_epoch
        if epoch <= best_epoch:
            return False
        # NOTE(review): patience is never reset here when dev improves, so it
        # counts all non-improving epochs, not consecutive ones — confirm intent.
        self.best_score.early_current_patience += 1
        self.config.logger.info("Dev Has Not Promote {} / {}".format(
            self.best_score.early_current_patience, self.early_max_patience))
        if self.best_score.early_current_patience < self.early_max_patience:
            return False
        self.end_of_epoch = epoch
        self.config.logger.info(
            "\n\nEarly Stop Train. Best Score Locate on {} Epoch.".
            format(self.best_score.best_epoch))
        self.save_training_summary()
        return True

    @staticmethod
    def _get_model_args(batch_features):
        """Unpack the model inputs and gold tags from a batch.

        :param batch_features: Batch Instance
        :return: (elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd,
            x_prd_posi, mask, sentence_length, tags)
        """
        elmo_char_seqs = batch_features.elmo_char_seqs
        elmo_word_seqs = batch_features.elmo_word_seqs
        word = batch_features.word_features
        lang = batch_features.lang
        pos = batch_features.pos_features
        prd = batch_features.prd_features
        x_prd_posi = batch_features.prd_posi_features
        mask = batch_features.mask
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return (elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd,
                x_prd_posi, mask, sentence_length, tags)

    def _calculate_loss(self, feats, mask, tags):
        """Compute the loss for one batch.

        Args:
            feats: size = (batch_size, seq_len, tag_size)
            mask: size = (batch_size, seq_len)
            tags: size = (batch_size, seq_len)
        Returns:
            the loss; with CRF and ``average_batch`` it is divided by batch size
        """
        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            # Flatten to one row per token for cross-entropy.
            lstm_feats = feats.view(batch_size * max_len, -1)
            return self.loss_function(lstm_feats, tags.view(-1))
        loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            loss_value /= float(feats.size(0))
        return loss_value

    def train(self):
        """Run the training loop with per-epoch dev/test evaluation.

        Returns early when ``_early_stop`` reports exhausted patience; otherwise
        writes the training summary after the last epoch.
        :return: None
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        self.config.logger.info('\n\n')
        self.config.logger.info('=-' * 50)
        for epoch in range(1, epochs + 1):
            self.train_iter.reset_flag4trainset()
            self.config.logger.info("\n\n### Epoch: {}/{} ###".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            self.config.logger.info("current lr: {}".format(
                self.optimizer.param_groups[0].get("lr")))
            start_time = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            self.config.logger.info('=-' * 10)
            for batch_count, batch_features in enumerate(
                    tqdm.tqdm(self.train_iter), start=1):
                (elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd,
                 x_prd_posi, mask, sentence_length,
                 tags) = self._get_model_args(batch_features)
                logit = self.model(elmo_char_seqs, elmo_word_seqs, word, lang,
                                   pos, prd, x_prd_posi, mask, sentence_length,
                                   train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=batch_count)
                # Decode with the same criterion (truthiness of use_crf) that
                # _loss used to pick the training loss.
                if self.use_crf:
                    p, r, f, acc_ = self.getAccCRF(self.train_eval,
                                                   batch_features, logit, mask,
                                                   self.config)
                else:
                    p, r, f, acc_ = self.getAcc(self.train_eval,
                                                batch_features, logit,
                                                self.config)
                self.config.logger.info(
                    "batch_count:{} , loss: {:.4f}, p: {:.4f}% r: {:.4f}% , f: {:.4f}%, ACC: {:.4f}%"
                    .format(batch_count, loss.item(), p, r, f, acc_))
            end_time = time.time()
            self.config.logger.info("Train Time {:.3f}".format(end_time -
                                                               start_time))
            self.config.logger.info('=-' * 10)
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 10)
            if self._early_stop(epoch=epoch, config=self.config):
                return
            self.config.logger.info('=-' * 15)
        self.save_training_summary()

    def save_training_summary(self):
        """Log a summary of features, model, batch size and best scores."""
        # NOTE(review): this only logs the "Copy ... as backup." message; no
        # file copy is performed in this method.
        self.config.logger.info(
            "Copy the last model ckps to {} as backup.".format(
                self.config.save_dir))
        self.config.logger.info(
            "save the training summary at end of the log file.")
        self.config.logger.info("\n")
        self.config.logger.info("*" * 25)
        self.config.logger.info("*" * 10)
        self.config.logger.info("features:")
        if self.config.is_predicate:
            self.config.logger.info("\tpredicate, dim: %d" %
                                    self.config.prd_embed_dim)
        self.config.logger.info("*" * 10)
        self.config.logger.info("model:")
        self.config.logger.info(self.model)
        self.config.logger.info("*" * 10)
        self.config.logger.info("training:")
        self.config.logger.info('\tbatch size: %d' % self.config.batch_size)
        self.config.logger.info("*" * 10)
        self.config.logger.info("best performance:")
        self.config.logger.info("\tbest at epoch: %d" %
                                self.best_score.best_epoch)
        self.config.logger.info("\tdev(%):")
        self.config.logger.info("\t\tprecision, %.5f" %
                                self.best_score.best_dev_p_score)
        self.config.logger.info("\t\trecall, %.5f" %
                                self.best_score.best_dev_r_score)
        self.config.logger.info("\t\tf1, %.5f" %
                                self.best_score.best_dev_f1_score)
        self.config.logger.info("\ttest(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.p)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.r)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.f)
        self.config.logger.info("*" * 25)

    def eval(self, model, epoch, config):
        """Evaluate on dev, then on test, logging timings.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        :return: None
        """
        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter, model, self.dev_eval, self.best_score,
                        epoch, config, test=False)
        eval_end_time = time.time()
        self.config.logger.info("Dev Time: {:.3f}".format(eval_end_time -
                                                          eval_start_time))
        self.config.logger.info('=-' * 10)

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.test_eval, self.best_score,
                        epoch, config, test=True)
        eval_end_time = time.time()
        self.config.logger.info("Test Time: {:.3f}".format(eval_end_time -
                                                           eval_start_time))

    def _model2file(self, model, config, epoch):
        """Save the model according to the config's save flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        :return: None
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config, config.save_model_dir,
                           config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config, config.save_model_dir,
                            config.model_name, self.best_score)
        else:
            # Logger.info requires a message; log an empty line (the old
            # print() separator) instead of calling it with no arguments.
            self.config.logger.info("")

    def _predict_labels(self, data_iter, model, config):
        """Run ``model`` in eval mode over ``data_iter`` and decode label strings.

        Softmax path: argmax ids from ``torch_max`` per word. CRF path: best
        paths from ``model.crf_layer``, truncated to each sentence length.

        :param data_iter: batch iterator
        :param model: model to run
        :param config: config (provides ``argvocab.i2c``)
        :return: (predict_labels, gold_labels, all_sentence_length)
        """
        model.eval()
        gold_labels = []
        predict_labels = []
        all_sentence_length = []
        for batch_features in tqdm.tqdm(data_iter):
            (elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd, x_prd_posi,
             mask, sentence_length, _tags) = self._get_model_args(batch_features)
            logit = model(elmo_char_seqs, elmo_word_seqs, word, lang, pos, prd,
                          x_prd_posi, mask, sentence_length, train=False)
            all_sentence_length.extend(sentence_length)
            if self.use_crf:
                _path_score, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy(
                    )[:inst.words_size]
                    predict_labels.append(
                        [config.argvocab.i2c[int(i)] for i in label_ids])
            else:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    # Index with this word's predicted id (previously an
                    # undefined name `i` was used here).
                    predict_labels.append([
                        config.argvocab.i2c[int(label_ids[id_word])]
                        for id_word in range(inst.words_size)
                    ])
                    gold_labels.append(inst.labels)
        return predict_labels, gold_labels, all_sentence_length

    def eval_batch(self,
                   data_iter,
                   model,
                   eval_instance,
                   best_score,
                   epoch,
                   config,
                   test=False):
        """Evaluate one split and update ``best_score``.

        On dev (``test is False``) the best dev scores/epoch are updated when
        f1 ties or improves, and ``best_test`` is raised so that the following
        test evaluation records its scores; the test pass then clears it.

        :param data_iter: eval batch data iterator
        :param model: eval model
        :param eval_instance: Eval used to compute p/r/f/acc
        :param best_score: Best_Result to update
        :param epoch: current epoch
        :param config: config
        :param test: whether this is the test split
        :return: None
        """
        test_flag = "Dev" if test is False else "Test"
        predict_labels, gold_labels, all_sentence_length = self._predict_labels(
            data_iter, model, config)
        p, r, f, acc_ = eval_instance.getFscore(predict_labels, gold_labels,
                                                all_sentence_length)
        if test is False:
            best_score.current_dev_score = f
            if f >= best_score.best_dev_f1_score:
                best_score.best_dev_f1_score = f
                best_score.best_dev_p_score = p
                best_score.best_dev_r_score = r
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        self.config.logger.info(
            "{} at current epoch, p: {:.4f}% r: {:.4f}% , f: {:.4f}%, ACC: {:.3f}%"
            .format(test_flag, p, r, f, acc_))
        if test is False:
            self.config.logger.info(
                "Till now, The Best Dev Result: p: {:.4f}% r: {:.4f}% , f: {:.4f}%, Locate on {} Epoch."
                .format(best_score.best_dev_p_score,
                        best_score.best_dev_r_score,
                        best_score.best_dev_f1_score, best_score.best_epoch))
        elif test is True:
            self.config.logger.info(
                "Till now, The Best Test Result: p: {:.4f}% r: {:.4f}% , f: {:.4f}%, Locate on {} Epoch."
                .format(best_score.p, best_score.r, best_score.f,
                        best_score.best_epoch))
            best_score.best_test = False

    def eval_external_batch(self, data_iter, config, meta_info=''):
        """Evaluate ``self.model`` on an arbitrary iterator and log the scores.

        :param data_iter: eval batch data iterator
        :param config: config
        :param meta_info: label included in the log line
        :return: None
        """
        evaluator = Eval()
        predict_labels, gold_labels, all_sentence_length = self._predict_labels(
            data_iter, self.model, config)
        p, r, f, acc_ = evaluator.getFscore(predict_labels, gold_labels,
                                            all_sentence_length)
        self.config.logger.info(
            "eval on {}%, p: {:.4f}% r: {:.4f}% , f: {:.4f}%, ACC: {:.4f}%".
            format(meta_info, p, r, f, acc_))

    @staticmethod
    def getAcc(eval_train, batch_features, logit, config):
        """Training-batch p/r/f/acc using argmax (softmax) decoding.

        :param eval_train: eval instance (cleared first)
        :param batch_features: batch data feature
        :param logit: model output
        :param config: config
        :return: (p, r, f, acc)
        """
        eval_train.clear_PRF()
        predict_ids = torch_max(logit)
        predict_labels = []
        gold_labels = []
        batch_length = []
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            gold_label = inst.labels
            # int() so tensor elements index the vocabulary like the CRF path.
            predict_label = [
                config.argvocab.i2c[int(label_ids[id_word])]
                for id_word in range(inst.words_size)
            ]
            predict_labels.append(predict_label)
            gold_labels.append(gold_label)
            batch_length.append(inst.words_size)
            assert len(predict_label) == len(gold_label)
        p, r, f, acc_ = eval_train.getFscore(predict_labels, gold_labels,
                                             batch_length)
        return p, r, f, acc_

    def getAccCRF(self, eval_train, batch_features, logit, mask, config):
        """Training-batch p/r/f/acc using the CRF layer's best paths.

        :param eval_train: eval instance (cleared first)
        :param batch_features: batch data feature
        :param logit: model output
        :param mask: batch mask passed to the CRF layer
        :param config: config
        :return: (p, r, f, acc)
        """
        eval_train.clear_PRF()
        predict_labels = []
        gold_labels = []
        batch_length = []
        _path_score, best_paths = self.model.crf_layer(logit, mask)
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            gold_labels.append(inst.labels)
            label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
            label = [config.argvocab.i2c[int(i)] for i in label_ids]
            predict_labels.append(label)
            batch_length.append(inst.words_size)
            assert len(label) == len(inst.labels)
        p, r, f, acc_ = eval_train.getFscore(predict_labels, gold_labels,
                                             batch_length)
        return p, r, f, acc_
class Train(object):
    """
    Train: multi-task (accusation + law article) training / evaluation driver.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs:
        Args of data:
            train_iter : train batch data iterator
            dev_iter : dev batch data iterator
            test_iter : test batch data iterator
        Args of train:
            model : nn model
            config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.device = self.config.device
        # Passed to F1_measure; True whenever the device is not the CPU device.
        self.cuda = bool(self.device != cpu_device)
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        # SGD uses a summed loss; every other optimizer a mean-reduced one.
        reduction = "sum" if self.config.learning_algorithm == "SGD" else "mean"
        self.loss_function = nn.CrossEntropyLoss(reduction=reduction)
        print(self.optimizer)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

        # accusation evaluators: one micro Eval per split plus one Eval per
        # accusation class (macro) per split
        self.accu_train_eval_micro, self.accu_dev_eval_micro, self.accu_test_eval_micro = Eval(
        ), Eval(), Eval()
        accu_class_num = self.config.accu_class_num
        self.accu_train_eval_macro = [Eval() for _ in range(accu_class_num)]
        self.accu_dev_eval_macro = [Eval() for _ in range(accu_class_num)]
        self.accu_test_eval_macro = [Eval() for _ in range(accu_class_num)]

        # law evaluators: same layout, sized by the number of law classes
        self.law_train_eval_micro, self.law_dev_eval_micro, self.law_test_eval_micro = Eval(
        ), Eval(), Eval()
        law_class_num = self.config.law_class_num
        self.law_train_eval_macro = [Eval() for _ in range(law_class_num)]
        self.law_dev_eval_macro = [Eval() for _ in range(law_class_num)]
        self.law_test_eval_macro = [Eval() for _ in range(law_class_num)]

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """Clip the gradient norm of all model parameters.

        :param clip_max_norm_use: whether to clip at all (must be ``True``)
        :param clip_max_norm: max norm; the string "None" is rejected by the
            assertion when clipping is enabled
        :return: None
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Step-decay the learning rate every ``config.max_patience`` epochs.

        When ``config.use_lr_decay`` is True, ``epoch`` is past
        ``max_patience``, ``epoch - 1`` is a multiple of ``max_patience`` and
        ``new_lr`` is still above ``min_lrate``, the rate is multiplied by
        ``lr_rate_decay`` (floored at ``min_lrate``) and applied via
        ``set_lrate``.

        :param config: config
        :param epoch: current epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Apply inverse-time decay: ``lr = init_lr / (1 + lr_rate_decay * epoch)``.

        Args:
            epoch: int, epoch
            init_lr: initial lr
        Returns:
            self.optimizer, with every param group's ``lr`` overwritten
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step the optimizer every ``backward_batch_size`` batches and on the last batch.

        :param config: config
        :param backward_count: 1-based index of the current batch in the epoch
        :return: None
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """Count an epoch without dev improvement; exit the process when patience runs out.

        :param epoch: current epoch
        :return: None
        :raises SystemExit: when ``early_current_patience`` reaches
            ``early_max_patience``
        """
        best_epoch = self.best_score.best_epoch
        if epoch <= best_epoch:
            return
        self.best_score.early_current_patience += 1
        print("Dev Has Not Promote {} / {}".format(
            self.best_score.early_current_patience, self.early_max_patience))
        if self.best_score.early_current_patience >= self.early_max_patience:
            print("Early Stop Train. Best Score Locate on {} Epoch.".format(
                self.best_score.best_epoch))
            # sys.exit rather than the site-module exit(), which is intended
            # for the interactive shell only.
            sys.exit()

    def train(self):
        """Run the training loop with per-epoch dev/test evaluation.

        :return: None
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(
                epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config,
                                      epoch=epoch,
                                      new_lr=new_lr)
            print("now lr is {}".format(
                self.optimizer.param_groups[0].get("lr")),
                  end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                # 1-based batch index; drives both gradient accumulation and logging.
                backward_count += 1
                accu, law, _, _ = self.model(batch_features)
                # Merge the first two dims so each row is one classification.
                accu_logit = accu.view(
                    accu.size(0) * accu.size(1), accu.size(2))
                law_logit = law.view(law.size(0) * law.size(1), law.size(2))
                loss_accu = self.loss_function(
                    accu_logit, batch_features.accu_label_features)
                loss_law = self.loss_function(
                    law_logit, batch_features.law_label_features)
                # The two task losses are weighted equally.
                total_loss = (loss_accu + loss_law) / 2
                total_loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                if backward_count % self.config.log_interval == 0:
                    self.accu_train_eval_micro.clear_PRF()
                    for i in range(self.config.accu_class_num):
                        self.accu_train_eval_macro[i].clear_PRF()
                    F1_measure(accu,
                               batch_features.accu_label_features,
                               self.accu_train_eval_micro,
                               self.accu_train_eval_macro,
                               cuda=self.cuda)
                    _, (_, _, f1_micro), _ = getFscore_Avg(
                        self.accu_train_eval_micro, self.accu_train_eval_macro,
                        accu.size(1))
                    sys.stdout.write(
                        "\nbatch_count = [{}/{}] , total_loss is {:.6f}, [accu-Micro-F1 is {:.6f}%]"
                        .format(batch_count + 1, self.train_iter_len,
                                total_loss.item(), f1_micro))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """Reset the dev/test evaluators and evaluate on dev, then test.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        :return: None
        """
        self.accu_dev_eval_micro.clear_PRF()
        for i in range(self.config.accu_class_num):
            self.accu_dev_eval_macro[i].clear_PRF()
        self.law_dev_eval_micro.clear_PRF()
        for i in range(self.config.law_class_num):
            self.law_dev_eval_macro[i].clear_PRF()
        eval_start_time = time.time()
        self._eval_batch(self.dev_iter,
                         model,
                         self.accu_dev_eval_micro,
                         self.accu_dev_eval_macro,
                         self.law_dev_eval_micro,
                         self.law_dev_eval_macro,
                         self.best_score,
                         epoch,
                         config,
                         test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        self.accu_test_eval_micro.clear_PRF()
        for i in range(self.config.accu_class_num):
            self.accu_test_eval_macro[i].clear_PRF()
        self.law_test_eval_micro.clear_PRF()
        for i in range(self.config.law_class_num):
            self.law_test_eval_macro[i].clear_PRF()
        eval_start_time = time.time()
        self._eval_batch(self.test_iter,
                         model,
                         self.accu_test_eval_micro,
                         self.accu_test_eval_macro,
                         self.law_test_eval_micro,
                         self.law_test_eval_macro,
                         self.best_score,
                         epoch,
                         config,
                         test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """Save the model according to the config's save flags.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        :return: None
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path,
                            config.model_name, self.best_score)
        else:
            print()

    def _eval_batch(self,
                    data_iter,
                    model,
                    accu_eval_micro,
                    accu_eval_macro,
                    law_eval_micro,
                    law_eval_macro,
                    best_score,
                    epoch,
                    config,
                    test=False):
        """Evaluate one split for both tasks and update ``best_score``.

        Model selection uses the accusation task's first ``getFscore_Avg``
        result (printed as "Macro_Micro_Avg").

        :param data_iter: eval batch data iterator
        :param model: nn model
        :param accu_eval_micro: accusation micro Eval
        :param accu_eval_macro: per-class accusation Evals
        :param law_eval_micro: law micro Eval
        :param law_eval_macro: per-class law Evals
        :param best_score: Best_Result to update
        :param epoch: current epoch
        :param config: config
        :param test: whether this is the test split
        :return: None
        """
        model.eval()
        for batch_features in data_iter:
            accu, law, _, _ = model(batch_features)
            F1_measure(accu,
                       batch_features.accu_label_features,
                       accu_eval_micro,
                       accu_eval_macro,
                       cuda=self.cuda)
            F1_measure(law,
                       batch_features.law_label_features,
                       law_eval_micro,
                       law_eval_macro,
                       cuda=self.cuda)

        # NOTE(review): `accu`/`law` come from the last batch, so an empty
        # data_iter fails here with UnboundLocalError.
        accu_macro_micro_avg, accu_micro, accu_macro = getFscore_Avg(
            accu_eval_micro, accu_eval_macro, accu.size(1))
        law_macro_micro_avg, law_micro, law_macro = getFscore_Avg(
            law_eval_micro, law_eval_macro, law.size(1))
        accu_p, accu_r, accu_f = accu_macro_micro_avg
        accu_p_ma, accu_r_ma, accu_f_ma = accu_macro
        accu_p_mi, accu_r_mi, accu_f_mi = accu_micro
        law_p, law_r, law_f = law_macro_micro_avg
        law_p_ma, law_r_ma, law_f_ma = law_macro
        law_p_mi, law_r_mi, law_f_mi = law_micro

        p, r, f = accu_p, accu_r, accu_f

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = f
            if f >= best_score.best_dev_score:
                best_score.best_dev_score = f
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = p
            best_score.r = r
            best_score.f = f
        print("{}:".format(test_flag))
        print("Macro_Micro_Avg ===>>> ")
        print(
            "Eval: accu --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p, accu_r, accu_f))
        print(
            "Eval: law --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p, law_r, law_f))
        print("Macro ===>>> ")
        print(
            "Eval: accu --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p_ma, accu_r_ma, accu_f_ma))
        print(
            "Eval: law --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p_ma, law_r_ma, law_f_ma))
        print("Micro ===>>> ")
        print(
            "Eval: accu --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(accu_p_mi, accu_r_mi, accu_f_mi))
        print(
            "Eval: law --- Precision = {:.6f}% Recall = {:.6f}% , F-Score = {:.6f}%"
            .format(law_p_mi, law_r_mi, law_f_mi))
        if test is True:
            print(
                "The Current Best accu Dev F-score: {:.6f}, Locate on {} Epoch."
                .format(best_score.best_dev_score, best_score.best_epoch))
            # The test pass consumes the "new best dev" flag set by the dev pass.
            best_score.best_test = False
class Train(object):
    """Drive training, per-epoch dev/test evaluation, checkpointing and
    early stopping. All settings come from ``kwargs["config"]`` and all
    progress is reported through ``config.logger``.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs:
            config     : configuration object (must expose ``logger``)
            train_iter : training batches (shuffled in place every epoch,
                         so it must be a mutable sequence)
            dev_iter   : dev batches
            test_iter  : test batches
            model      : the network to train
        """
        self.config = kwargs["config"]
        self.config.logger.info("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.use_crf = self.config.use_crf
        self.average_batch = self.config.average_batch
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm,
                                   model=self.model,
                                   lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay,
                                   grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(
            learning_algorithm=self.config.learning_algorithm,
            label_paddingId=self.config.label_paddingId,
            use_crf=self.use_crf)
        self.config.logger.info(self.optimizer)
        self.config.logger.info(self.loss_function)
        self.best_score = Best_Result()
        self.train_eval, self.dev_eval, self.test_eval = Eval(), Eval(), Eval()
        self.train_iter_len = len(self.train_iter)
        # Last epoch reached; read by save_training_summary(). It is set here
        # (and refreshed each epoch in train()) so that the summary also works
        # when training ends without early stopping.
        self.end_of_epoch = 0

    def _loss(self, learning_algorithm, label_paddingId, use_crf=False):
        """Pick the loss callable.

        :param learning_algorithm: optimizer name; "SGD" selects a summed
            cross-entropy, anything else a mean-reduced one
        :param label_paddingId: label id ignored by the cross-entropy loss
        :param use_crf: if True, use the model's CRF negative log-likelihood
        :return: loss callable
        """
        if use_crf:
            return self.model.crf_layer.neg_log_likelihood_loss
        reduction = "sum" if learning_algorithm == "SGD" else "mean"
        return nn.CrossEntropyLoss(ignore_index=label_paddingId,
                                   reduction=reduction)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """Clip the model's gradient norm in place.

        :param clip_max_norm_use: clipping only happens when this is True
        :param clip_max_norm: max norm; the string "None" is rejected by the
            assertion when clipping is enabled
        """
        if clip_max_norm_use is True:
            gclip = None if clip_max_norm == "None" else float(clip_max_norm)
            assert isinstance(gclip, float)
            utils.clip_grad_norm_(self.model.parameters(), max_norm=gclip)

    def _dynamic_lr(self, config, epoch, new_lr):
        """Step-decay the learning rate every ``config.max_patience`` epochs.

        The decay is never taken below ``config.min_lrate``.

        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, epoch, init_lr):
        """Apply ``lr = init_lr / (1 + lr_rate_decay * epoch)`` to every
        optimizer param group.

        NOTE(review): unlike _dynamic_lr this does not consult
        ``config.use_lr_decay``; the decay is always applied. Confirm that
        this is intended.

        :param epoch: zero-based epoch index
        :param init_lr: initial learning rate
        :return: the optimizer
        """
        lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """Step and zero the optimizer.

        This happens every ``config.backward_batch_size`` backward passes
        (gradient accumulation) and also on the last batch of the epoch.
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch, config):
        """Count epochs after the best dev epoch and stop training when
        ``early_max_patience`` is reached.

        On stopping, the training summary is logged and the process exits
        via ``SystemExit``.

        NOTE(review): ``early_current_patience`` is never reset when dev
        improves, so it counts all non-improving epochs, not only
        consecutive ones. Confirm that this is intended.
        """
        import sys  # local import: the file's import header is not visible here

        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            self.config.logger.info("Dev Has Not Promote {} / {}".format(
                self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                self.end_of_epoch = epoch
                self.config.logger.info(
                    "\n\nEarly Stop Train. Best Score Locate on {} Epoch.".
                    format(self.best_score.best_epoch))
                self.save_training_summary()
                # sys.exit() instead of the site-injected builtin exit(),
                # which is unavailable under `python -S` / frozen builds.
                sys.exit()

    @staticmethod
    def _get_model_args(batch_features):
        """Unpack a batch into model inputs.

        :param batch_features: batch instance
        :return: (word ids, mask of word ids > 0, sentence lengths, gold tags)
        """
        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, mask, sentence_length, tags

    def _calculate_loss(self, feats, mask, tags):
        """Compute the batch loss.

        Without a CRF, ``feats`` is flattened over its first two dimensions
        and ``tags`` to 1-D before the cross-entropy call. With a CRF the
        NLL loss is called as ``loss(feats, mask, tags)``. It is divided by
        the batch size when ``average_batch`` is set.
        """
        if not self.use_crf:
            batch_size, max_len = feats.size(0), feats.size(1)
            lstm_feats = feats.view(batch_size * max_len, -1)
            tags = tags.view(-1)
            return self.loss_function(lstm_feats, tags)
        loss_value = self.loss_function(feats, mask, tags)
        if self.average_batch:
            batch_size = feats.size(0)
            loss_value /= float(batch_size)
        return loss_value

    def train(self):
        """Run the full training loop.

        Each epoch does: lr decay, a shuffled pass over ``train_iter``,
        dev/test evaluation, checkpointing and an early-stop check. The
        training summary is logged at the end.
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate
        self.config.logger.info('\n\n')
        self.config.logger.info('=-' * 50)
        self.config.logger.info('batch number: %d' % len(self.train_iter))
        for epoch in range(1, epochs + 1):
            # Keep the summary's "end at epoch" accurate for normal completion.
            self.end_of_epoch = epoch
            self.config.logger.info("\n\n### Epoch: {}/{} ###".format(
                epoch, epochs))
            self.optimizer = self._decay_learning_rate(
                epoch=epoch - 1, init_lr=self.config.learning_rate)
            self.config.logger.info("current lr: {}".format(
                self.optimizer.param_groups[0].get("lr")))
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            self.config.logger.info('=-' * 10)
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                word, mask, sentence_length, tags = self._get_model_args(
                    batch_features)
                logit = self.model(word, sentence_length, train=True)
                loss = self._calculate_loss(logit, mask, tags)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config,
                                           backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    self.getAcc(self.train_eval, batch_features, logit,
                                self.config)
                    self.config.logger.info(
                        "batch_count:{} , loss: {:.4f}, [TAG-ACC: {:.4f}%]".
                        format(batch_count + 1, loss.item(),
                               self.train_eval.acc()))
            end_time = time.time()
            self.config.logger.info("Train Time {:.3f}".format(end_time -
                                                               start_time))
            self.config.logger.info('=-' * 10)
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 10)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch, config=self.config)
            self.config.logger.info('=-' * 15)
        self.save_training_summary()

    def save_training_summary(self):
        """Back up the checkpoint directory, then log a dataset / model /
        best-score summary.

        The backup is best-effort. A missing source directory or a failing
        copy is logged and does not prevent the summary from being written.
        """
        self.config.logger.info(
            "Copy the last model ckps to {} as backup.".format(
                self.config.save_dir))
        backup_dir = "/".join(
            [self.config.save_dir, self.config.save_model_dir + "_bak"])
        if os.path.isdir(self.config.save_model_dir):
            try:
                shutil.copytree(self.config.save_model_dir, backup_dir)
            except OSError as err:  # e.g. backup dir already exists
                self.config.logger.info(
                    "Backup to {} failed: {}".format(backup_dir, err))
        else:
            self.config.logger.info("No model directory {} to back up.".format(
                self.config.save_model_dir))
        self.config.logger.info(
            "save the training summary at end of the log file.")
        self.config.logger.info("\n")
        self.config.logger.info("*" * 25)

        par_path = os.path.dirname(self.config.train_file)
        self.config.logger.info("dataset:\n\t %s" % par_path)
        self.config.logger.info("\ttrain set count: %d" % self.config.train_cnt)
        self.config.logger.info("\tdev set count: %d" % self.config.dev_cnt)
        self.config.logger.info("\ttest set count: %d" % self.config.test_cnt)

        self.config.logger.info("*" * 10)
        self.config.logger.info("model:")
        self.config.logger.info(self.model)

        self.config.logger.info("*" * 10)
        self.config.logger.info("training:")
        self.config.logger.info('\tbatch size: %d' % self.config.batch_size)
        self.config.logger.info('\tbatch count: %d' % len(self.train_iter))

        self.config.logger.info("*" * 10)
        self.config.logger.info("best performance:")
        self.config.logger.info("\tend at epoch: %d" % self.end_of_epoch)
        self.config.logger.info("\tbest at epoch: %d" % self.best_score.best_epoch)
        self.config.logger.info("\tdev(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.best_dev_p_score)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.best_dev_r_score)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.best_dev_f1_score)
        self.config.logger.info("\ttest(%):")
        self.config.logger.info("\t\tprecision, %.5f" % self.best_score.p)
        self.config.logger.info("\t\trecall, %.5f" % self.best_score.r)
        self.config.logger.info("\t\tf1, %.5f" % self.best_score.f)

        self.config.logger.info("*" * 25)

    def eval(self, model, epoch, config):
        """Evaluate on the dev set, then the test set, and log the timings."""
        self.dev_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter, model, self.dev_eval, self.best_score,
                        epoch, config, test=False)
        eval_end_time = time.time()
        self.config.logger.info("Dev Time: {:.3f}".format(eval_end_time -
                                                          eval_start_time))
        self.config.logger.info('=-' * 10)

        self.test_eval.clear_PRF()
        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.test_eval, self.best_score,
                        epoch, config, test=True)
        eval_end_time = time.time()
        self.config.logger.info("Test Time: {:.3f}".format(eval_end_time -
                                                           eval_start_time))

    def _model2file(self, model, config, epoch):
        """Save every epoch's model or only the best one, depending on config.

        When neither saving strategy is enabled, this only logs that nothing
        was saved.
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config, config.save_model_dir,
                           config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config, config.save_model_dir,
                            config.model_name, self.best_score)
        else:
            # Logger.info() requires a message; the old bare call raised
            # TypeError whenever saving was disabled.
            self.config.logger.info(
                "Model not saved at epoch {}: no saving strategy enabled.".format(epoch))

    def eval_batch(self, data_iter, model, eval_instance, best_score, epoch,
                   config, test=False):
        """Decode ``data_iter`` with ``model``, then log P/R/F and token
        accuracy, and update ``best_score``.

        :param data_iter: batches to evaluate
        :param model: network to evaluate (switched to eval mode)
        :param eval_instance: accumulator filled via ``EvalPRF.evalPRF``
        :param best_score: tracks the best dev epoch and the test scores
            recorded at that epoch
        :param epoch: current epoch number
        :param config: configuration; labels are decoded with
            ``config.create_alphabet.label_alphabet``
        :param test: False for the dev set, True for the test set
        """
        test_flag = "Dev" if test is False else "Test"
        model.eval()  # set flag for pytorch
        eval_PRF = EvalPRF()
        label_alphabet = config.create_alphabet.label_alphabet
        gold_labels = []
        predict_labels = []
        for batch_features in data_iter:
            word, mask, sentence_length, tags = self._get_model_args(
                batch_features)
            logit = model(word, sentence_length, train=False)
            if self.use_crf is False:
                predict_ids = torch_max(logit)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    label_ids = predict_ids[id_batch]
                    predict_label = [
                        label_alphabet.from_id(label_ids[id_word])
                        for id_word in range(inst.words_size)
                    ]
                    gold_labels.append(inst.labels)
                    predict_labels.append(predict_label)
            else:
                path_score, best_paths = model.crf_layer(logit, mask)
                for id_batch in range(batch_features.batch_length):
                    inst = batch_features.inst[id_batch]
                    gold_labels.append(inst.labels)
                    label_ids = best_paths[id_batch].cpu().data.numpy(
                    )[:inst.words_size]
                    predict_labels.append(
                        [label_alphabet.from_id(int(i)) for i in label_ids])

        for p_label, g_label in zip(predict_labels, gold_labels):
            eval_PRF.evalPRF(predict_labels=p_label,
                             gold_labels=g_label,
                             eval=eval_instance)

        # Token-level tag accuracy over all predicted sequences.
        total_len = sum(len(predict_label) for predict_label in predict_labels)
        cor = sum(1 for p_seq, g_seq in zip(predict_labels, gold_labels)
                  for p_tag, g_tag in zip(p_seq, g_seq) if p_tag == g_tag)
        # An empty iterator (or zero tokens) previously raised ZeroDivisionError.
        acc_ = cor / total_len * 100 if total_len else 0.0

        p, r, f = eval_instance.getFscore()

        if test is False:  # dev
            best_score.current_dev_score = f
            if f >= best_score.best_dev_f1_score:
                best_score.best_dev_f1_score = f
                best_score.best_dev_p_score = p
                best_score.best_dev_r_score = r
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:  # test
            best_score.p = p
            best_score.r = r
            best_score.f = f
        self.config.logger.info(
            "{} at current epoch, precision: {:.4f}% recall: {:.4f}% , f-score: {:.4f}%, [TAG-ACC: {:.3f}%]"
            .format(test_flag, p, r, f, acc_))
        if test is False:
            self.config.logger.info(
                "Till now, The Best Dev Result: precision: {:.4f}% recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.best_dev_p_score, best_score.best_dev_r_score,
                        best_score.best_dev_f1_score, best_score.best_epoch))
        elif test is True:
            self.config.logger.info(
                "Till now, The Best Test Result: precision: {:.4f}% recall: {:.4f}% , f-score: {:.4f}%, Locate on {} Epoch."
                .format(best_score.p, best_score.r, best_score.f,
                        best_score.best_epoch))
            best_score.best_test = False

    @staticmethod
    def getAcc(eval_acc, batch_features, logit, config):
        """Reset ``eval_acc`` and fill it with this batch's token accuracy.

        Predictions are taken from ``torch_max(logit)``.
        """
        eval_acc.clear_PRF()
        predict_ids = torch_max(logit)
        for id_batch in range(batch_features.batch_length):
            inst = batch_features.inst[id_batch]
            label_ids = predict_ids[id_batch]
            predict_label = []
            gold_lable = inst.labels
            for id_word in range(inst.words_size):
                predict_label.append(
                    config.create_alphabet.label_alphabet.from_id(
                        label_ids[id_word]))
            assert len(predict_label) == len(gold_lable)
            cor = 0
            for p_lable, g_lable in zip(predict_label, gold_lable):
                if p_lable == g_lable:
                    cor += 1
            eval_acc.correct_num += cor
            eval_acc.gold_num += len(gold_lable)