def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, d_model, n_head, d_k, d_v,
             dropout_ratio, dropout1, use_cuda=False):
    super(BERT_LSTM_CRF, self).__init__()
    self.embedding_dim = embedding_dim
    self.hidden_dim = hidden_dim
    self.word_embeds = BertModel.from_pretrained(bert_config)
    self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout_ratio)
    self.dropout1 = nn.Dropout(p=dropout1)
    self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
    self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
    self.tagset_size = tagset_size
def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers,
             dropout_ratio, dropout1, use_cuda=False):
    super(BertLstmCrf, self).__init__()
    self.embedding_dim = embedding_dim
    self.hidden_dim = hidden_dim
    self.word_embeds = BertModel.from_pretrained(bert_config)
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                        dropout=dropout_ratio, batch_first=True)
    self.rnn_layers = rnn_layers
    self.dropout1 = nn.Dropout(p=dropout1)
    self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
    self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
    self.tagset_size = tagset_size
def __init__(self, args, tagset_size, embedding_dim, hidden_dim, rnn_layers,
             dropout_ratio, dropout1, use_cuda=False):
    super(BERT_LSTM_CRF, self).__init__()
    self.embedding_dim = embedding_dim
    self.hidden_dim = hidden_dim
    self.word_embeds = BertModel(config=BertConfig.from_json_file(args.bert_config_json))
    self.word_embeds.load_state_dict(torch.load('./ckpts/9134_bert_weight.bin'))
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                        dropout=dropout_ratio, batch_first=True)
    self.rnn_layers = rnn_layers
    self.dropout1 = nn.Dropout(p=dropout1)
    self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
    self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
    self.tagset_size = tagset_size
class BERT_LSTM_CRF(nn.Module):
    def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers,
                 dropout_ratio, dropout1, use_cuda):
        super(BERT_LSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.word_embeds = BertModel.from_pretrained(bert_config)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
        self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
        self.tagset_size = tagset_size
        self.use_cuda = use_cuda

    def rand_init_hidden(self, batch_size):
        h0 = Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim))
        c0 = Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim))
        if self.use_cuda:
            return h0.cuda(), c0.cuda()
        return h0, c0

    def get_output_score(self, sentence, attention_mask=None):
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        hidden = self.rand_init_hidden(batch_size)
        lstm_out, hidden = self.lstm(embeds, hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def forward(self, sentence, masks):
        lstm_feats = self.get_output_score(sentence)
        scores, tag_seq = self.crf._viterbi_decode(lstm_feats, masks.byte())
        return tag_seq

    def neg_log_likelihood_loss(self, sentence, mask, tags):
        lstm_feats = self.get_output_score(sentence)
        loss_value = self.crf.neg_log_likelihood_loss(lstm_feats, mask, tags)
        batch_size = lstm_feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
class BERT_LSTM_CRF(nn.Module):
    """bert_lstm_crf model"""

    def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers,
                 dropout_ratio, dropout1, use_cuda=False):
        super(BERT_LSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.word_embeds = BertModel.from_pretrained(bert_config)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
        self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
        self.tagset_size = tagset_size

    def rand_init_hidden(self, batch_size):
        """Randomly initialize the LSTM hidden state (h_0, c_0) on DEVICE."""
        return (Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim, device=DEVICE)),
                Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim, device=DEVICE)))

    def forward(self, sentence, attention_mask=None):
        """
        args:
            sentence (batch_size, seq_len): token ids of the input sentences
            attention_mask (batch_size, seq_len): BERT attention mask
        return:
            lstm_feats (batch_size, seq_len, tagset_size + 2): emission scores fed to the CRF
        """
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        hidden = self.rand_init_hidden(batch_size)
        lstm_out, hidden = self.lstm(embeds, hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        """
        feats: size=(batch_size, seq_len, tag_size)
        mask:  size=(batch_size, seq_len)
        tags:  size=(batch_size, seq_len)
        return: mean negative log-likelihood over the batch
        """
        loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags)
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
class BertLstmCrf(nn.Module):
    def __init__(self, bert_config, tagset_size, embedding_dim, hidden_dim, rnn_layers,
                 dropout_ratio, dropout1, use_cuda=False):
        super(BertLstmCrf, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.word_embeds = BertModel.from_pretrained(bert_config)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
        self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
        self.tagset_size = tagset_size

    def rand_init_hidden(self, batch_size):
        return (Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)),
                Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)))

    def forward(self, sentence, attention_mask=None):
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        hidden = self.rand_init_hidden(sentence.size(0))
        if embeds.is_cuda:
            hidden = (hidden[0].cuda(), hidden[1].cuda())
        lstm_out, hidden = self.lstm(embeds, hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(sentence.size(0), sentence.size(1), -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags)
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
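# A minimal usage sketch for the BertLstmCrf module above. The checkpoint name,
# hyperparameter values, vocabulary size, and random tensors are placeholders for
# illustration only; the byte mask passed to loss() follows the masks.byte()
# convention used elsewhere in these snippets and is an assumption here.
import torch

model = BertLstmCrf(bert_config='bert-base-chinese', tagset_size=9, embedding_dim=768,
                    hidden_dim=256, rnn_layers=1, dropout_ratio=0.5, dropout1=0.5, use_cuda=False)

input_ids = torch.randint(0, 21128, (2, 32))           # (batch_size, seq_len) token ids (placeholder vocab)
attention_mask = torch.ones(2, 32, dtype=torch.long)   # all positions treated as real tokens
tags = torch.randint(0, 9, (2, 32))                     # gold label ids

feats = model(input_ids, attention_mask)                # (batch_size, seq_len, tagset_size + 2)
loss = model.loss(feats, attention_mask.byte(), tags)   # mean negative log-likelihood over the batch
loss.backward()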
def __init__(self, args, num_labels):
    super(BertCRF, self).__init__()
    # BERT encoder
    self.bert_config = AutoConfig.from_pretrained(os.path.join(args.bert_path, 'config.json'))
    self.bert = AutoModel.from_pretrained(os.path.join(args.bert_path, 'pytorch_model.bin'),
                                          config=self.bert_config)
    # per-token classification head
    self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
    self.classifier = nn.Linear(self.bert_config.hidden_size, num_labels)
    # CRF layer for sequence prediction
    self.crf = CRF(num_tags=num_labels, batch_first=True)
def __init__(self, args, num_labels):
    super().__init__()
    self.device = torch.device(
        'cuda:{}'.format(args.device) if torch.cuda.is_available() and args.device != '-1' else 'cpu')
    self.output_size = num_labels
    self.rnn_layers = args.lstm_rnn_layers
    self.hidden_dim = args.lstm_hidden_dim
    self.bidirectional = args.lstm_bidirectional
    # BERT model used as the word-embedding layer
    self.bert_config = AutoConfig.from_pretrained(os.path.join(args.bert_path, 'config.json'))
    self.word_embeds = AutoModel.from_pretrained(os.path.join(args.bert_path, 'pytorch_model.bin'),
                                                 config=self.bert_config)
    # LSTM layers on top of the BERT output
    self.lstm = nn.LSTM(
        input_size=self.bert_config.hidden_size,  # number of expected features in the input
        hidden_size=self.hidden_dim,              # number of features in the hidden state
        num_layers=self.rnn_layers,               # number of stacked recurrent layers
        batch_first=True,
        bidirectional=self.bidirectional)         # if True, becomes a bidirectional LSTM
    # dropout layer
    self.dropout = nn.Dropout(args.lstm_dropout)
    # linear layer: maps the output of the LSTM into tag space
    if self.bidirectional:
        self.hidden2tag = nn.Linear(self.hidden_dim * 2, self.output_size)
    else:
        self.hidden2tag = nn.Linear(self.hidden_dim, self.output_size)
    # CRF layer
    self.crf = CRF(num_tags=self.output_size, batch_first=True)
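# Hedged sketch of how the layers defined in the __init__ above are typically wired
# together in a forward pass. This is an illustration under assumptions, not the
# original implementation: the method signature, shapes, and the negation of the
# pytorch-crf log-likelihood to obtain a loss are choices made here.
def forward(self, input_ids, attention_mask, tags=None):
    bert_out = self.word_embeds(input_ids, attention_mask=attention_mask)
    sequence_output = bert_out[0]                          # (batch, seq_len, hidden_size)
    lstm_out, _ = self.lstm(sequence_output)               # (batch, seq_len, hidden_dim * directions)
    emissions = self.hidden2tag(self.dropout(lstm_out))    # (batch, seq_len, num_labels)
    mask = attention_mask.bool()
    if tags is not None:
        # CRF(batch_first=True) from pytorch-crf returns a log-likelihood; negate it for a loss
        return -self.crf(emissions, tags, mask=mask, reduction='mean')
    return self.crf.decode(emissions, mask=mask)           # list of best tag sequences per example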
    }))
char_highway = Highway(Config({
    'num_layers': 2,
    'size': char_cnn.output_size,
    'activation': 'selu'
}))
lstm = LSTM(Config({
    'input_size': word_embed.output_size + char_cnn.output_size,
    'hidden_size': train_args['lstm_hidden_size'],
    'forget_bias': 1.0,
    'batch_first': True,
    'bidirectional': True
}))
crf = CRF(Config({'label_vocab': label_vocab}))
output_linear = Linear(Config({
    'in_features': lstm.output_size,
    'out_features': len(label_vocab)
}))
word_embed.load_state_dict(state['model']['word_embed'])
char_cnn.load_state_dict(state['model']['char_cnn'])
char_highway.load_state_dict(state['model']['char_highway'])
lstm.load_state_dict(state['model']['lstm'])
crf.load_state_dict(state['model']['crf'])
output_linear.load_state_dict(state['model']['output_linear'])
lstm_crf = LstmCrf(token_vocab=token_vocab,
                   label_vocab=label_vocab,
                   char_vocab=char_vocab,
                   word_embedding=word_embed,
        elif labels_row[idx] == "I":
            if idx > 0 and labels_row[idx - 1] == "B":
                continue
            else:
                keywords_tmp.append(words_row[idx])
    keywords.append(keywords_tmp)
    return keywords


if __name__ == "__main__":
    dataset = load_data(".\\kpwr-1.1\\*\\result.csv")
    features = create_features_list(dataset)
    dataset["features"] = features
    train, test = train_test_split(dataset)
    CRF.train(train['features'], train['label_base'])
    preds = CRF.test(test['features'])
    keywords_true = test['base_keywords_in_text']
    keywords_pred = get_keywords_from_labels(test['base_words_list'], preds)
    prec_h, rec_h, f1_h = evaluator.hard_evaluation(keywords_true, keywords_pred)
    prec_s, rec_s, f1_s = evaluator.soft_evaluation(keywords_true, keywords_pred)
    print(
        f"Soft evaluation: Precision: {np.mean(prec_s)*100}, Recall: {np.mean(rec_s)*100}, F1 score: {np.mean(f1_s)*100}"
    )
    print(
    'filters': charcnn_filters
}))
char_highway = Highway(Config({
    'num_layers': 2,
    'size': char_cnn.output_size,
    'activation': 'selu'
}))
lstm = LSTM(Config({
    'input_size': word_embed.output_size + char_cnn.output_size,
    'hidden_size': args.lstm_hidden_size,
    'forget_bias': 1.0,
    'batch_first': True,
    'bidirectional': True
}))
crf = CRF(Config({
    'label_vocab': label_vocab
}))
output_linear = Linear(Config({
    'in_features': lstm.output_size,
    'out_features': len(label_vocab)
}))
# LSTM-CRF model
lstm_crf = LstmCrf(
    token_vocab=token_vocab,
    label_vocab=label_vocab,
    char_vocab=char_vocab,
    word_embedding=word_embed,
    char_embedding=char_cnn,
    crf=crf,
    lstm=lstm,
class BERT_LSTM_CRF(nn.Module):
    """
    bert_lstm_crf model
    bert_model = BertModel(config=BertConfig.from_json_file(args.bert_config_json))
    """

    def __init__(self, args, tagset_size, embedding_dim, hidden_dim, rnn_layers,
                 dropout_ratio, dropout1, use_cuda=False):
        super(BERT_LSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.word_embeds = BertModel(config=BertConfig.from_json_file(args.bert_config_json))
        self.word_embeds.load_state_dict(torch.load('./ckpts/9134_bert_weight.bin'))
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        self.rnn_layers = rnn_layers
        self.dropout1 = nn.Dropout(p=dropout1)
        self.crf = CRF(target_size=tagset_size, average_batch=True, use_cuda=use_cuda)
        self.liner = nn.Linear(hidden_dim * 2, tagset_size + 2)
        self.tagset_size = tagset_size

    def rand_init_hidden(self, batch_size):
        """Randomly initialize the LSTM hidden state (h_0, c_0)."""
        return (Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)),
                Variable(torch.randn(2 * self.rnn_layers, batch_size, self.hidden_dim)))

    def forward(self, sentence, attention_mask=None):
        """
        args:
            sentence (batch_size, seq_len): token ids of the input sentences
            attention_mask (batch_size, seq_len): BERT attention mask
        return:
            lstm_feats (batch_size, seq_len, tagset_size + 2): emission scores fed to the CRF
        """
        batch_size = sentence.size(0)
        seq_length = sentence.size(1)
        embeds, _ = self.word_embeds(sentence, attention_mask=attention_mask, output_all_encoded_layers=False)
        hidden = self.rand_init_hidden(batch_size)  # note: not passed below, so the LSTM starts from zero states
        lstm_out, hidden = self.lstm(embeds)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * 2)
        d_lstm_out = self.dropout1(lstm_out)
        l_out = self.liner(d_lstm_out)
        lstm_feats = l_out.contiguous().view(batch_size, seq_length, -1)
        return lstm_feats

    def loss(self, feats, mask, tags):
        """
        feats: size=(batch_size, seq_len, tag_size)
        mask:  size=(batch_size, seq_len)
        tags:  size=(batch_size, seq_len)
        return: mean negative log-likelihood over the batch
        """
        loss_value = self.crf.neg_log_likelihood_loss(feats, mask, tags)
        batch_size = feats.size(0)
        loss_value /= float(batch_size)
        return loss_value
    y_pred = []
    test_loss = Variable(torch.Tensor([0]))
    with torch.no_grad():
        for obs_seq, state_seq in zip(X, y):
            y_pred.append(model.decode(obs_seq))
            loss = model.nll_loss(obs_seq, state_seq)
            test_loss += loss
    test_loss /= X.shape[0]
    test_acc = model.score(y, y_pred)
    print('Test loss: {:.4f}\tTest Accuracy: {:.4f}'.format(test_loss.item(), test_acc))


if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_dataset()
    args = Args()
    model = CRF(n_states=2, n_obs=6)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    for epoch in range(args.n_epochs):
        train(model, X_train, y_train, optimizer, epoch, args)
        test(model, X_test, y_test)
char_highway = Highway(Config({
    'num_layers': 2,
    'size': char_cnn.output_size,
    'activation': 'selu'
}))
lstm = LSTM(Config({
    'input_size': word_embed_1.output_size + char_cnn.output_size,
    'hidden_size': args.lstm_hidden_size,
    'forget_bias': 1.0,
    'batch_first': True,
    'bidirectional': True
}))
# CRF layer for task 1
crf_1 = CRF(Config({'label_vocab': label_vocab_1}))
# CRF layer for task 2
crf_2 = CRF(Config({'label_vocab': label_vocab_2}))
# Linear layers for task 1
shared_output_linear_1 = Linear(Config({
    'in_features': lstm.output_size,
    'out_features': len(label_vocab_1)
}))
spec_output_linear_1_1 = Linear(Config({
    'in_features': lstm.output_size,
    'out_features': len(label_vocab_1)
}))
spec_output_linear_1_2 = Linear(Config({
charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])]
                   for f in args.charcnn_filters.split(';')]
char_embed = CharCNN(len(char_vocab), args.char_embed_dim, filters=charcnn_filters)
char_hw = Highway(char_embed.output_size, layer_num=args.charhw_layer, activation=args.charhw_func)
feat_dim = word_embed.embedding_dim + char_embed.output_size
lstm = LSTM(feat_dim, args.lstm_hidden_size, batch_first=True, bidirectional=True,
            forget_bias=args.lstm_forget_bias)
crf = CRF(label_size=len(label_vocab) + 2)
linear = Linears(in_features=lstm.output_size, out_features=len(label_vocab),
                 hiddens=[lstm.output_size // 2])
lstm_crf = LstmCrf(token_vocab, label_vocab, char_vocab,
                   word_embedding=word_embed,
                   char_embedding=char_embed,
                   crf=crf,
                   lstm=lstm,
                   univ_fc_layer=linear,
                   embed_dropout_prob=args.feat_dropout,
                   lstm_dropout_prob=args.lstm_dropout,
                   char_highway=char_hw if args.use_highway else None)
if use_gpu:
if labels_row[idx] == "B" and idx + 1 < len( words_row) and labels_row[idx + 1] == "I": keywords_tmp.append(words_row[idx] + " " + words_row[idx + 1]) elif labels_row[idx] == "I": if idx > 0 and labels_row[idx - 1] == "B": continue else: keywords_tmp.append(words_row[idx]) keywords.append(keywords_tmp) return keywords dataset = load_processed_data() dataset = create_features_list(dataset) train, test = CRF.split_data(dataset) CRF.train(train['features'], train['label_base']) preds = CRF.test(test['features']) keywords_true = test['base_keywords_in_text'] keywords_pred = get_keywords_from_labels(test['base_words_list'], preds) prec_h, rec_h, f1_h = evaluator.hard_evaluation(keywords_true, keywords_pred) prec_s, rec_s, f1_s = evaluator.soft_evaluation(keywords_true, keywords_pred) print("Sotf evalution: Precission: %.2f, Recall: %.2f, F1Score: %.2f" % (np.mean(precision_soft_list) * 100, np.mean(recall_soft_list) * 100, np.mean(f1_soft_list) * 100)) print("Hard evalution: Precission: %.2f, Recall: %.2f, F1Score: %.2f" % (np.mean(precision_hard_list) * 100, np.mean(recall_hard_list) * 100, np.mean(f1_hard_list) * 100))
                      train_args['word_embed_dim'],
                      sparse=True,
                      padding_idx=C.PAD_INDEX)
char_embed = CharCNN(len(char_vocab), train_args['char_embed_dim'], filters=charcnn_filters)
char_hw = Highway(char_embed.output_size, layer_num=train_args['charhw_layer'],
                  activation=train_args['charhw_func'])
feat_dim = word_embed.embedding_dim + char_embed.output_size
lstm = LSTM(feat_dim, train_args['lstm_hidden_size'], batch_first=True, bidirectional=True,
            forget_bias=train_args['lstm_forget_bias'])
crf = CRF(label_size=len(label_vocab) + 2)
linear = Linear(in_features=lstm.output_size, out_features=len(label_vocab))
lstm_crf = LstmCrf(token_vocab, label_vocab, char_vocab,
                   word_embedding=word_embed,
                   char_embedding=char_embed,
                   crf=crf,
                   lstm=lstm,
                   univ_fc_layer=linear,
                   embed_dropout_prob=train_args['feat_dropout'],
                   lstm_dropout_prob=train_args['lstm_dropout'],
                   char_highway=char_hw if train_args['use_highway'] else None)
word_embed.load_state_dict(state['model']['word_embed'])
char_embed.load_state_dict(state['model']['char_embed'])
charcnn_filters = [[int(f.split(',')[0]), int(f.split(',')[1])]
                   for f in args.charcnn_filters.split(';')]
char_embed = CharCNN(len(char_vocab), args.char_embed_dim, filters=charcnn_filters)
char_hw = Highway(char_embed.output_size, layer_num=args.charhw_layer, activation=args.charhw_func)
feat_dim = args.word_embed_dim + char_embed.output_size
lstm = LSTM(feat_dim, args.lstm_hidden_size, batch_first=True, bidirectional=True,
            forget_bias=args.lstm_forget_bias)
crf_1 = CRF(label_size=len(label_vocab_1) + 2)
crf_2 = CRF(label_size=len(label_vocab_2) + 2)
# Linear layers for task 1
shared_linear_1 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_1))
spec_linear_1_1 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_1))
spec_linear_1_2 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_1))
# Linear layers for task 2
shared_linear_2 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_2))
spec_linear_2_1 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_2))
spec_linear_2_2 = Linear(in_features=lstm.output_size, out_features=len(label_vocab_2))
                        hid_dim,
                        act_num_classes,
                        n_layers=args.d_layers,
                        dropout=args.dropout).to(device)
topic_l_decoder = Decoder(emb_dim, hid_dim, topic_num_classes,
                          n_layers=args.d_layers, dropout=args.dropout).to(device)
topic_r_decoder = Decoder(emb_dim, hid_dim, topic_num_classes,
                          n_layers=args.d_layers, dropout=args.dropout).to(device)
act_CRF = CRF(act_num_classes, device).to(device)
topic_CRF = CRF(topic_num_classes, device).to(device)
seq2seq = Seq2Seq(encoder, act_l_decoder, act_r_decoder, topic_l_decoder, topic_r_decoder,
                  act_CRF, topic_CRF, hid_dim, act_num_classes, topic_num_classes, device).to(device)
optimizer = optim.Adam(seq2seq.parameters(), lr=args.lr, eps=args.eps, weight_decay=args.decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')
print(seq2seq)


def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
# Obtain tag2idx, idx2tag, word2idx, idx2word
tag2idx, idx2tag = helper.tag2idx, helper.idx2tag
word2idx, idx2word = helper.word2idx, helper.idx2word
# Load embedding
embed_mat = helper.load_embed(args.embed_path, args.embed_dim)

#%% Define model
vocab_size = len(word2idx)
if args.max_vocab_size:
    vocab_size = args.max_vocab_size
if args.model == 'crf':
    model = CRF(vocab_size=vocab_size,
                embed_dim=args.embed_dim,
                num_tags=len(tag2idx),
                embed_matrix=embed_mat)
if args.model == 'lstm_crf':
    model = BiLSTM_CRF(vocab_size=vocab_size,
                       embed_dim=args.embed_dim,
                       hidden_dim=args.hidden_dim,
                       num_tags=len(tag2idx),
                       embed_matrix=embed_mat)

#%% Optimizer and scheduler
# loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Setting "from_logits=True" may be more numerically stable
total_steps = len(list(train_batches)) * args.epochs
warm_steps = int(total_steps * args.warm_frac)
lr_scheduler = utils.LrScheduler(args.lr, warm_steps, total_steps)