def eval(epoch, config, model, validating_data, logger, dev_dataset, visualizer=None):
    """Run one validation pass and return the corpus-level CER, in percent.

    NOTE: the name shadows the builtin ``eval``; kept for API compatibility.

    Args:
        epoch: current epoch number (used only in log messages / visualizer).
        config: experiment config; reads ``config.training.num_gpu`` and
            ``config.training.show_interval``.
        model: model exposing ``eval()`` and ``recognize(inputs, inputs_length)``.
        validating_data: sized iterable of batches
            ``(inputs, targets, origin_length, inputs_length, targets_length)``.
        logger: logger with an ``info`` method for progress messages.
        dev_dataset: dataset providing an ``index2word`` mapping for printing.
        visualizer: optional summary writer with ``add_scalar``; may be None.

    Returns:
        Character error rate in percent (0.0 when ``validating_data`` is empty).
    """
    model.eval()
    # NOTE(review): total_loss is never accumulated, so the reported
    # AverageLoss is always 0 — confirm whether a loss should be computed here.
    total_loss = 0
    total_dist = 0
    total_word = 0
    batch_steps = len(validating_data)
    cer = 0.0  # defined up front so an empty loader cannot raise UnboundLocalError
    for step, (inputs, targets, origin_length, inputs_length,
               targets_length) in enumerate(validating_data):
        if config.training.num_gpu > 0:
            inputs, inputs_length = inputs.cuda(), inputs_length.cuda()
            targets, targets_length = targets.cuda(), targets_length.cuda()
        # Trim batch padding down to the longest sequence actually present.
        max_inputs_length = origin_length.max().item()
        max_targets_length = targets_length.max().item()
        inputs = inputs[:, :max_inputs_length, :]
        targets = targets[:, :max_targets_length]
        preds = model.recognize(inputs, inputs_length)
        # Copy targets to host ONCE, then slice per sample
        # (previously .cpu().numpy() ran inside the comprehension, one copy per sample).
        targets_np = targets.cpu().numpy()
        transcripts = [
            targets_np[i][:targets_length[i].item()]
            for i in range(targets.size(0))
        ]
        # Print the first hypothesis/reference pair of the batch; the ''
        # default keeps join() from crashing on an out-of-vocabulary index
        # (dict.get without a default returns None, which join rejects).
        print(''.join(
            dev_dataset.index2word.get(index, '') for index in preds[0]))
        print(''.join(
            dev_dataset.index2word.get(index, '') for index in transcripts[0]))
        dist, num_words = computer_cer(preds, transcripts)
        total_dist += dist
        total_word += num_words
        cer = total_dist / total_word * 100
        if step % config.training.show_interval == 0:
            process = step / batch_steps * 100
            logger.info('-Validation-Epoch:%d(%.5f%%), CER: %.5f %%' %
                        (epoch, process, cer))
    # max(..., 1) guards the division for an empty loader.
    val_loss = total_loss / max(batch_steps, 1)
    logger.info(
        '-Validation-Epoch:%4d, AverageLoss:%.5f, AverageCER: %.5f %%' %
        (epoch, val_loss, cer))
    if visualizer is not None:
        visualizer.add_scalar('cer', cer, epoch)
    return cer
def test(config, model, test_dataset, validate_data, logger): model.eval() total_dist = 0 total_word = 0 batch_steps = len(validate_data) for step, (inputs, targets, origin_length, inputs_length, targets_length) in enumerate(validate_data): if config.training.num_gpu > 0: inputs, inputs_length = inputs.cuda(), inputs_length.cuda() targets, targets_length = targets.cuda(), targets_length.cuda() max_inputs_length = inputs_length.max().item() max_targets_length = targets_length.max().item() # inputs = inputs[:, :max_inputs_length, :] targets = targets[:, :max_targets_length] preds = model.recognize(inputs, inputs_length) transcripts = [ targets.cpu().numpy()[i][:targets_length[i].item()] for i in range(targets.size(0)) ] print(''.join( [test_dataset.index2word.get(index) for index in preds[0]])) print(''.join( [test_dataset.index2word.get(index) for index in transcripts[0]])) dist, num_words = computer_cer(preds, transcripts) total_dist += dist total_word += num_words cer = total_dist / total_word * 100 if step % config.training.show_interval == 0: process = step / batch_steps * 100 logger.info('-Validation-Epoch:%d(%.5f%%), CER: %.5f %%' % (1, process, cer)) logger.info('-Validation-Epoch:%4d, AverageCER: %.5f %%' % (1, cer)) return cer
if l != previous: new_labels.append(l) previous = l # 删除blank new_labels = [l for l in new_labels if l != blank] return new_labels pred_outs = [remove_blank(pred) for pred in preds] targets_y = [remove_blank(label) for label in targets_y] print(''.join( dev_dataset.index2word.get(index) for index in pred_outs[0])) print(''.join( dev_dataset.index2word.get(index) for index in targets_y[0])) diff, total = computer_cer(pred_outs, targets_y) # print(diff) # print(total) total_diff += diff total_nums += total wer = total_diff / total_nums * 100 print('ctc model wer : {}%'.format(wer)) if wer < old_wer: old_wer = wer print('complete trained model save!') save_path = os.path.join(home_dir, 'ctc_model') if not os.path.exists(save_path): os.makedirs(save_path) torch.save( model.state_dict(), 'ctc_model/{}_{:.4f}_enecoder_model'.format(
def eval_cer(config, model, validating_data): rle_decoder = GreedyDecoder(int2base, blank_index=0) base_decoder = GreedyDecoder(int2base, blank_index=0) model.eval() total_loss = 0 total_dist = 0 total_rle_dist = 0 total_word = 0 total_rle_word = 0 total_match, total_insert, total_delete, total_mismatch, total_length = 0, 0, 0, 0, 0 batch_steps = len(validating_data) for step, (regions, inputs, inputs_length, targets, targets_length, rle_bases, rle_bases_length, rles, rles_length) in enumerate(validating_data): if config.training.num_gpu > 0: inputs, inputs_length = inputs.cuda(), inputs_length.cuda() targets, targets_length = targets.cuda(), targets_length.cuda() rle_bases, rle_bases_length, rles, rles_length = rle_bases.cuda( ), rle_bases_length.cuda(), rles.cuda(), rles_length.cuda() max_inputs_length = inputs_length.max().item() max_targets_length = targets_length.max().item() max_rle_bases_length = rle_bases_length.max().item() max_rles_length = rles_length.max().item() inputs = inputs[:, :max_inputs_length, :] # [N, max_inputs_length, c] targets = targets[:, :max_targets_length] # [N, max_targets_length] rle_bases = rle_bases[:, :max_rle_bases_length] rles = rles[:, :max_rles_length] base_logits, rle_logits = model.recognize(inputs, inputs_length) # [B,L,o] pred_base_strings, offset = base_decoder.decode( base_logits, inputs_length) pred_rle_strings, offset = rle_decoder.decode(rle_logits, inputs_length) pred_base_strings = [v.upper() for v in pred_base_strings] pred_rle_strings = [v.upper() for v in pred_rle_strings] # print(pred_base_strings, pred_rle_strings) # print('preds') # print(preds) base_targets = [ rle_bases.cpu().numpy()[i][:rle_bases_length[i].item()].tolist() for i in range(rle_bases.size(0)) ] base_transcripts = [] for i in range(rle_bases.size(0)): rle_target_seq = "" for v in base_targets[i]: rle_target_seq += int2base[v] base_transcripts.append(''.join(rle_target_seq)) rle_targets = [ 
targets.cpu().numpy()[i][:targets_length[i].item()].tolist() for i in range(rles.size(0)) ] rle_transcripts = [] for i in range(targets.size(0)): rle_target_seq = "" for v in rle_targets[i]: rle_target_seq += int2base[v] rle_transcripts.append(''.join(rle_target_seq)) # # print('transcripts') # # print(transcripts) # dist, num_words = computer_cer(pred_base_strings, base_transcripts) total_dist += dist total_word += num_words # print('base:',pred_base_strings, transcripts) rle_dist, rle_num_words = computer_cer(pred_rle_strings, rle_transcripts) total_rle_dist += rle_dist total_rle_word += rle_num_words # print('rle:',pred_rle_strings, rle_transcripts) # print() match, insert, delete, mismatch, batch_length, insert_set = calculate_identity( pred_rle_strings, rle_transcripts) for idx in insert_set: print(regions[idx]) print("pred :", pred_rle_strings[idx]) print("label:", rle_transcripts[idx]) print(inputs.cpu()[idx][:inputs_length[idx]][:, :7]) print('-' * 10) total_match += match total_insert += insert total_delete += delete total_mismatch += mismatch total_length += batch_length # cer = total_dist / total_word * 100 rle_cer = total_rle_dist / total_rle_word * 100 if step % config.training.show_interval == 0: process = step / batch_steps * 100 print('-Validation:(%.5f%%), CER: %.5f %%, Rle CER: %.5f %%' % (process, cer, rle_cer)) val_loss = total_loss / (step + 1) print( '-Validation:, AverageLoss:%.5f, AverageCER: %.5f %%, AverageRleCER: %.5f %%, Mismatch: %.5f %%, Insertion: %.5f %%, Deletion: %.5f %%' % (val_loss, cer, rle_cer, total_mismatch / total_length * 100, total_insert / total_length * 100, total_delete / total_length * 100))