def test_auto_encoding_type_infer(self):
    """SpanFPreRecMetric should infer ``encoding_type`` from the tag vocabulary alone.

    Random vocabularies are built for 'bio', 'bioes' and 'bmeso', and a 'bmes'
    vocabulary is taken from ``_generate_tags``; each must be recognised. A
    vocabulary made only of digit strings has no tag prefixes, and building the
    metric from it must raise.
    """
    import random

    schemes = ['bio', 'bioes', 'bmeso']

    def random_vocab(scheme):
        # Each tag is '<prefix>-<label>', except the bare 'o' tag.
        tag_vocab = Vocabulary(unknown=None, padding=None)
        for _ in range(random.randint(10, 100)):
            label = str(random.randint(1, 10))
            for prefix in scheme:
                tag_vocab.add_word('o' if prefix == 'o' else f'{prefix}-{label}')
        return tag_vocab

    vocab_by_scheme = {scheme: random_vocab(scheme) for scheme in schemes}

    for scheme in schemes:
        with self.subTest(e=scheme):
            inferred = SpanFPreRecMetric(tag_vocab=vocab_by_scheme[scheme]).encoding_type
            assert inferred == scheme

    bmes_tags = _generate_tags('bmes')
    bmes_vocab = Vocabulary()
    for tag in bmes_tags:
        bmes_vocab.add_word(tag)
    assert SpanFPreRecMetric(bmes_vocab).encoding_type == 'bmes'

    # Cases that cannot be inferred: digits carry no tag prefix.
    digit_vocab = Vocabulary()
    for digit in range(10):
        digit_vocab.add_word(str(digit))
    with self.assertRaises(Exception):
        SpanFPreRecMetric(digit_vocab)
def test_encoding_type(self):
    """SpanFPreRecMetric must complain when tag_vocab does not match encoding_type.

    - A vocabulary built for scheme ``e1`` is accepted with ``encoding_type=e1``.
    - With another scheme ``e2`` it must raise AssertionError, unless every tag
      prefix used by ``e1`` also belongs to ``e2``; such pairs are skipped.
    - None of the 'bio'/'bioes'/'bmeso' vocabularies fits 'bmes'.
    - A b/m/e/s vocabulary used with 'bmeso' (missing the 'o' tag) must warn,
      both with and without the Vocabulary's unknown/padding entries.
    """
    import random
    from itertools import product

    schemes = ['bio', 'bioes', 'bmeso']
    vocabs = {}
    for encoding_type in schemes:
        vocab = Vocabulary(unknown=None, padding=None)
        for _ in range(random.randint(10, 100)):
            label = str(random.randint(1, 10))
            for tag in encoding_type:
                if tag != 'o':
                    vocab.add_word(f'{tag}-{label}')
                else:
                    vocab.add_word('o')
        vocabs[encoding_type] = vocab

    for e1, e2 in product(schemes, schemes):
        with self.subTest(e1=e1, e2=e2):
            if e1 == e2:
                SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
            elif not set(e1) <= set(e2):
                # e1 uses a prefix that e2 does not know about.
                with self.assertRaises(AssertionError):
                    SpanFPreRecMetric(vocabs[e1], encoding_type=e2)

    for encoding_type in schemes:
        with self.assertRaises(AssertionError):
            SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes')

    # Each case gets its own assertWarns: inside a single shared block, a warning
    # from one case would satisfy the assertion even if the other case stayed silent.
    with self.assertWarns(Warning):
        vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
        SpanFPreRecMetric(vocab, encoding_type='bmeso')
    with self.assertWarns(Warning):
        # Default Vocabulary(), presumably with its unknown/padding entries present.
        vocab = Vocabulary().add_word_lst(list('bmes'))
        SpanFPreRecMetric(vocab, encoding_type='bmeso')
def test(self, file_path):
    """Evaluate the saved POS pipeline on a CoNLL-X file.

    The file's "pos_tags" field is used as the gold tags. They are indexed with
    the saved tag vocabulary into a "truth" field, and the saved pipeline is run
    to produce "predict".

    :param file_path: path of the test file, read with ConllxDataLoader.
    :return: dict with keys "F1", "precision" and "recall": the metric's
        'f', 'pre' and 'rec' values as percentages rounded to 2 decimals.
    """
    test_data = ConllxDataLoader().load(file_path)

    save_dict = self._dict
    tag_vocab = save_dict["tag_vocab"]
    pipeline = save_dict["pipeline"]
    index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)

    test_data.rename_field("pos_tags", "tag")

    # Put the gold-tag indexer in front of the saved pipeline for this call only.
    # The pipeline object is shared through self._dict. Without restoring it,
    # every call would prepend another indexer, and any later use of the pipeline
    # would require a "tag" field.
    original_stages = pipeline.pipeline
    pipeline.pipeline = [index_tag] + original_stages
    try:
        pipeline(test_data)
    finally:
        pipeline.pipeline = original_stages

    test_data.set_target("truth")
    prediction = test_data.field_arrays["predict"].content
    truth = test_data.field_arrays["truth"].content
    seq_len = test_data.field_arrays["word_seq_origin_len"].content

    # Pad by hand with index 0 so every row has the same length and stacks into one tensor.
    max_length = max(len(seq) for seq in prediction)
    for idx in range(len(prediction)):
        prediction[idx] = list(prediction[idx]) + [0] * (max_length - len(prediction[idx]))
        truth[idx] = list(truth[idx]) + [0] * (max_length - len(truth[idx]))

    evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
                                  seq_lens="word_seq_origin_len")
    # Tag ids and lengths are integers; cast so they reach the metric as ints, not floats.
    evaluator({"predict": torch.Tensor(prediction).long(),
               "word_seq_origin_len": torch.Tensor(seq_len).long()},
              {"truth": torch.Tensor(truth).long()})
    test_result = evaluator.get_metric()
    f1 = round(test_result['f'] * 100, 2)
    pre = round(test_result['pre'] * 100, 2)
    rec = round(test_result['rec'] * 100, 2)

    return {"F1": f1, "precision": pre, "recall": rec}
def test_case3(self):
    """SpanFPreRecMetric on BIO tags with only_gross=False: per-label and overall scores."""
    num_labels = 4

    # The vocabulary is populated through its word_count Counter of generated BIO tags.
    bio_vocab = Vocabulary(unknown=None, padding=None)
    bio_vocab.word_count = Counter(_generate_tags('BIO', num_labels))
    metric = SpanFPreRecMetric(tag_vocab=bio_vocab, only_gross=False)

    # Logits of shape (2, 6, 9): two sequences of six steps, nine tag scores each.
    logits = torch.FloatTensor([
        [[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, -0.3782, 0.8240],
         [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, -0.3562, -1.4116],
         [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, 2.0023, 0.7075],
         [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, 0.3832, -0.1540],
         [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, -1.3508, -0.9513],
         [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, -0.0842, -0.4294]],
        [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, -1.4138, -0.8853],
         [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, -1.0726, 0.0364],
         [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, -0.8836, -0.9320],
         [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, -1.6857, 1.1571],
         [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, -0.5837, 1.0184],
         [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, -0.9025, 0.0864]],
    ])
    gold = torch.LongTensor([[3, 6, 0, 8, 2, 4],
                             [4, 1, 7, 0, 4, 7]])
    lengths = torch.LongTensor([6, 6])
    metric({'pred': logits, 'seq_len': lengths}, {'target': gold})

    expected = {
        'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333,
        'pre-2': 0.5, 'rec-2': 0.5, 'f-2': 0.5,
        'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0,
        'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0,
        'pre': 0.222222, 'rec': 0.181818, 'f': 0.2,
    }
    self.assertDictEqual(expected, metric.get_metric())
def test(self, filepath):
    """Compute word-segmentation F1, precision and recall on a CoNLL-style file.

    Expected file layout::

        1 编者按 编者按 NN O 11 nmod:topic
        2 : : PU O 11 punct
        3 7月 7月 NT DATE 4 compound:nn
        4 12日 12日 NT DATE 11 nmod:tmod
        5 , , PU O 11 punct

        1 这 这 DT O 3 det
        2 款 款 M O 1 mark:clf
        3 飞行 飞行 NN O 8 nsubj
        4 从 从 P O 5 case
        5 外型 外型 NN O 8 nmod:prep

    Sentences are separated by a blank line, and every non-empty line has 7 columns.

    :param filepath: str, path of the file to evaluate.
    :return: dict with keys "F1", "precision" and "recall", taken from the
        'f', 'pre' and 'rec' entries reported under 'SpanFPreRecMetric'.
    """
    tag_proc = self._dict['tag_proc']
    stages = self.pipeline.pipeline
    cws_model = stages[-2].model

    # Keep every stage except the last two, and put the tag processor at
    # position 1 (presumably so the gold tags get indexed for evaluation).
    kept = stages[:-2]
    eval_pipeline = Pipeline(kept[:1] + [tag_proc] + kept[1:])

    te_dataset = ConllCWSReader().load(filepath)
    eval_pipeline(te_dataset)

    from ..core.tester import Tester
    from ..core.metrics import SpanFPreRecMetric

    tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()),
                    batch_size=64, verbose=0)
    scores = tester.test()['SpanFPreRecMetric']
    return {"F1": scores['f'], "precision": scores['pre'], "recall": scores['rec']}
if args.init == 'uniform': nn.init.xavier_uniform_(p) print_info('xavier uniform init:{}'.format(n)) elif args.init == 'norm': print_info('xavier norm init:{}'.format(n)) nn.init.xavier_normal_(p) except: print_info(n) exit(1208) print_info('{}init pram{}'.format('*' * 15, '*' * 15)) loss = LossInForward() encoding_type = 'bmeso' f1_metric = SpanFPreRecMetric(vocabs['label'], pred='pred', target='target', seq_len='seq_len', encoding_type=encoding_type) acc_metric = AccuracyMetric( pred='pred', target='target', seq_len='seq_len', ) acc_metric.set_metric_name('label_acc') metrics = [f1_metric, acc_metric] if args.self_supervised: chars_acc_metric = AccuracyMetric(pred='chars_pred', target='chars_target', seq_len='seq_len') chars_acc_metric.set_metric_name('chars_acc') metrics.append(chars_acc_metric)
def test_case4(self):
    """Compare SpanFPreRecMetric (encoding_type='bmes') with a hand-written span F1.

    Three random BMES-tagged sequences of 9 words each are generated. Random
    logits are scored by both the metric and ``get_eval``, and 'f', 'pre' and
    'rec' are compared to 5 decimal places.
    """
    def _generate_samples():
        """Return (padded tag-index matrix, raw tag lists, lengths, vocabulary) for 3 sequences."""
        target = []
        seq_len = []
        vocab = Vocabulary(unknown=None, padding=None)
        for _ in range(3):
            target_i = []
            seq_len_i = 0
            for _ in range(9):
                word_len = np.random.randint(1, 5)
                seq_len_i += word_len
                if word_len == 1:
                    target_i.append('S')
                else:
                    target_i.append('B')
                    target_i.extend(['M'] * (word_len - 2))
                    target_i.append('E')
            vocab.add_word_lst(target_i)
            target.append(target_i)
            seq_len.append(seq_len_i)
        # Integer dtype so torch.from_numpy gives integer tag ids rather than float64.
        target_ = np.zeros((3, max(seq_len)), dtype=np.int64)
        for i in range(3):
            target_[i, :seq_len[i]] = [vocab.to_index(t) for t in target[i]]
        return target_, target, seq_len, vocab

    def get_eval(raw_target, pred, vocab, seq_len):
        """Reference span-level f/pre/rec for BMES tags, using the argmax of ``pred``."""
        pred = pred.argmax(dim=-1).tolist()
        tp = 0
        gold = 0
        seg = 0
        pred_target = []
        for i in range(len(seq_len)):
            tags = [vocab.to_word(p) for p in pred[i][:seq_len[i]]]
            # Decode predicted tags into spans. 'M'/'E' extends the current span
            # only right after 'B'/'M'; every other tag opens a new span.
            spans = []
            prev_bmes_tag = None
            for idx, tag in enumerate(tags):
                if tag in ('M', 'E') and prev_bmes_tag in ('B', 'M'):
                    spans[-1][1] = idx
                else:
                    spans.append([idx, idx])
                prev_bmes_tag = tag
            # Re-encode the decoded spans as well-formed BMES tags.
            tmp = []
            for span_start, span_end in spans:
                if span_end > span_start:
                    tmp.extend(['B'] + ['M'] * (span_end - span_start - 1) + ['E'])
                else:
                    tmp.append('S')
            pred_target.append(tmp)

        for i in range(len(seq_len)):
            raw_pred = pred_target[i]
            start = 0
            for j in range(seq_len[i]):
                if raw_target[i][j] in ('E', 'S'):
                    # A gold word spans [start, j]; it counts as a hit when the
                    # re-encoded prediction matches it tag for tag.
                    if raw_target[i][start:j + 1] == raw_pred[start:j + 1]:
                        tp += 1
                    start = j + 1
                    gold += 1
                if raw_pred[j] in ('E', 'S'):
                    seg += 1

        pre = round(tp / seg, 6)
        rec = round(tp / gold, 6)
        # With zero true positives pre + rec == 0: report f = 0 instead of dividing by zero.
        f = round(2 * pre * rec / (pre + rec), 6) if pre + rec > 0 else 0.0
        return {'f': f, 'pre': pre, 'rec': rec}

    target, raw_target, seq_len, vocab = _generate_samples()
    # Size the class dimension by the vocabulary. A random sample may lack a tag
    # (e.g. no word of length >= 3 means no 'M'), and argmax must never produce
    # an index the vocabulary does not contain.
    pred = torch.randn(3, max(seq_len), len(vocab))

    expected_metric = get_eval(raw_target, pred, vocab, seq_len)
    metric = SpanFPreRecMetric(vocab, encoding_type='bmes')
    metric({'pred': pred, 'seq_len': torch.LongTensor(seq_len)},
           {'target': torch.from_numpy(target)})
    metric_value = metric.get_metric()
    for key, value in expected_metric.items():
        self.assertAlmostEqual(value, metric_value[key], places=5)
def tese_case3(self):
    """Check SpanFPreRecMetric against fixed reference scores for BIO and BMES tags.

    NOTE(review): the name starts with 'tese_' rather than 'test', so unittest's
    default loader does not collect this method and it never runs. This is
    presumably deliberate; confirm before re-enabling. It passes 'seq_lens' to the
    metric, while other tests here pass 'seq_len', and its expected values are
    unrounded floats compared exactly with assertDictEqual.
    """
    from fastNLP.core.vocabulary import Vocabulary
    from collections import Counter
    from fastNLP.core.metrics import SpanFPreRecMetric

    # Cross-check against AllenNLP that the F metric is computed correctly.
    #
    def generate_allen_tags(encoding_type, number_labels=4):
        # Map every '<prefix>-<label>' tag (and a single shared 'O') to
        # len(vocab) + 1 at insertion time.
        vocab = {}
        for i in range(number_labels):
            label = str(i)
            for tag in encoding_type:
                if tag == 'O':
                    if tag not in vocab:
                        vocab['O'] = len(vocab) + 1
                    continue
                vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # this value actually serves as the tag's count
        return vocab

    number_labels = 4
    # bio tag
    fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
    fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
    fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
    # Logits of shape (2, 5, 9).
    bio_sequence = torch.FloatTensor(
        [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, 0.0470, 0.0971],
          [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, 0.7987, -0.3970],
          [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, 0.6880, 1.4348],
          [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, -1.6876, -0.8917],
          [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, 1.4217, 0.2622]],
         [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, 1.3592, -0.8973],
          [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, -0.4025, -0.3417],
          [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, 0.2861, -0.3966],
          [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, 0.0213, 1.4777],
          [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, 1.3024, 0.2001]]])
    bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
                                   [5., 6., 8., 6., 0.]])
    fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target})
    expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775,
                      'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0,
                      'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845,
                      'f': 0.12499999999994846}
    self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())

    #bmes tag
    # Logits of shape (2, 5, 16).
    bmes_sequence = torch.FloatTensor(
        [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, -0.4085, 0.4084, -0.4185, 1.4172, -0.9162,
           -0.2679, 0.3332, -0.3505, -0.6002],
          [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, 1.6199, -0.8877, 0.1449, 0.8995, -0.5810,
           0.1041, 0.1002, 0.4439, 0.2514],
          [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, -1.1006, -0.8540, -0.1600, -0.9519, -0.2749,
           -0.4948, -1.4753, 0.5802, -0.0516],
          [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, 4.1249, -0.4173, -0.2043, 1.7755, 1.1110,
           -1.7069, -0.0390, -0.9242, -0.0333],
          [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, 0.7734, 1.0968, 0.4132, -1.3647, -0.5762,
           0.6678, 0.8809, -0.3779, -0.3195]],
         [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, 0.0685, 0.3663, -0.6789, 0.0097, 1.0327,
           -0.0212, -0.9957, -0.1103, 0.4417],
          [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, -0.0879, -0.4412, -1.3195, -0.0657, -0.2986,
           0.7214, 0.0631, -0.6386, 0.2797],
          [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, 0.4123, -0.8447, -0.5502, 0.3520, -0.2832,
           0.5019, -0.1522, 1.1237, -1.5385],
          [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, -0.3680, -0.8902, -0.6983, -1.1583, 1.2779,
           0.2197, 0.1376, -0.0591, -0.2461],
          [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, 2.6973, -0.8308, -1.4939, 0.9865, -0.3935,
           0.2743, 0.1142, -0.7344, -1.2046]]])
    bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.],
                                    [6., 15., 6., 15., 5.]])

    fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
    fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
    fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
    # NOTE(review): lengths of 20 exceed the 5 time steps in bmes_sequence -- confirm this is intended.
    fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})

    expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
                       'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
                       'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
                       'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
                       'pre': 0.499999999999995, 'rec': 0.499999999999995}

    self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
def train(train_data_path, dev_data_path, checkpoint=None, save=None):
    """Train an AdvSeqLabel tagger on CoNLL-X data and save the model with an inference pipeline.

    :param train_data_path: training set, read with ConllxDataLoader.
    :param dev_data_path: dev set, indexed with the training vocabularies and used by
        the Trainer with metric_key "f".
    :param checkpoint: optional path of a model saved with torch.save. When given, it
        is loaded instead of building a new AdvSeqLabel.
    :param save: directory passed to the Trainer as save_path; "model_pp.pkl" is written
        into it. When None, "model_pp.pkl" goes to the current working directory
        (previously this crashed in os.path.join after training had finished).
    """
    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
    print("config loaded")

    # Data Loader
    print("loading training set...")
    dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
    print("loading dev set...")
    dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
    print(dataset)
    print("================= dataset ready =====================")

    dataset.rename_field("tag", "truth")
    dev_data.rename_field("tag", "truth")

    vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
    tag_proc = VocabIndexerProcessor("truth", is_input=True)
    seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)
    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    # Index the dev set with the vocabularies built on the training set.
    word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab
    dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]], new_field_name="word_seq")
    dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]], new_field_name="truth")
    dev_data.apply(lambda ins: len(ins["word_seq"]), new_field_name="word_seq_origin_len")

    # set input & target
    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dev_data.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")
    dev_data.set_target("truth", "word_seq_origin_len")

    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))

    # define a model
    if checkpoint is None:
        pre_trained = None  # no pre-trained embedding is used
        model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained)
        print(model)
    else:
        model = torch.load(checkpoint)

    # call trainer to train
    trainer = Trainer(dataset, model, loss=None,
                      metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", target="truth",
                                                seq_lens="word_seq_origin_len"),
                      dev_data=dev_data, metric_key="f",
                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
    trainer.train(load_best_model=True)

    # save model & pipeline
    model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
    id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

    pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    if save is None:
        pipeline_path = "model_pp.pkl"
    else:
        os.makedirs(save, exist_ok=True)
        pipeline_path = os.path.join(save, "model_pp.pkl")
    torch.save(save_dict, pipeline_path)
    print("pipeline saved")
def train(checkpoint=None, train_data_path="/home/hyan/train.conllx"):
    """Train an AdvSeqLabel POS tagger and save the model with an inference pipeline.

    :param checkpoint: optional path of a model saved with torch.save. When given, it
        is loaded instead of building a new AdvSeqLabel.
    :param train_data_path: training file read with ZhConllPOSReader. The default is
        the path that used to be hard-coded, so existing calls behave the same.
    """
    import os

    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
    print("config loaded")

    # Data Loader
    dataset = ZhConllPOSReader().load(train_data_path)
    print(dataset)
    print("dataset transformed")

    dataset.rename_field("tag", "truth")

    vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
    tag_proc = VocabIndexerProcessor("truth")
    seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")

    print("processors defined")

    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))

    # define a model
    if checkpoint is None:
        pre_trained = None  # no pre-trained embedding is used
        model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained)
        print(model)
    else:
        model = torch.load(checkpoint)

    # The Trainer gets this directory as save_path, and best_model.pkl is written
    # into it at the end, so make sure it exists.
    save_dir = "./save"
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): dev_data is the training set itself, so the "f" used to pick the
    # best model measures fit on the training data, not generalization.
    trainer = Trainer(dataset, model, loss=None,
                      metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", target="truth",
                                                seq_lens="word_seq_origin_len"),
                      dev_data=dataset, metric_key="f",
                      use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path=save_dir)
    trainer.train(load_best_model=True)

    # save model & pipeline
    model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
    id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

    pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    torch.save(save_dict, "model_pp.pkl")
    print("pipeline saved")

    torch.save(model, os.path.join(save_dir, "best_model.pkl"))
args = {"word_emb_dim": 300, "rnn_hidden_units": 300, "num_classes": len(vocab[1]), "init_embedding": embedding, "vocab_size": len(vocab[0])} print(args) model = SeqLabelingForSLSTM(args) if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) metrics, metric_key = None, None if arg.dataset == "ner": metrics = SpanFPreRecMetric(vocab[1], pred='predict', target='truth', seq_lens='word_seq_origin_len') metric_key = "f" elif arg.dataset == "pos": metrics = AccuracyMetric(pred='predict', target='truth', seq_lens='word_seq_origin_len') metric_key = "acc" trainer = Trainer( train_data=train_dataset, model=model, loss=None, # loss=CrossEntropyLoss(pred='predict', target='truth'), metrics=metrics, n_epochs=20, batch_size=arg.batch_size, print_every=1,