def get_random_madcat_test_sample(lh=30):
    """Return one randomly chosen sample from the MADCAT test split.

    The dataset is constructed lazily on first call and cached in the
    module-level ``MadcatTestDataset`` global; ``lh`` (target line height
    in pixels) only has an effect on that first call.
    """
    global MadcatTestDataset

    if MadcatTestDataset is None:
        # NOTE: dataset location is hardcoded for now(!)
        preprocess = imagetransforms.Compose([
            imagetransforms.Scale(new_h=lh),
            imagetransforms.InvertBlackWhite(),
            imagetransforms.Pad(0, 0),
            imagetransforms.ToTensor(),
        ])
        MadcatTestDataset = MadcatDataset("/lfs2/srawls/madcat", "test",
                                          ArabicAlphabet(), lh, preprocess)

    # Uniformly sample an index over the cached dataset.
    pick = random.randint(0, len(MadcatTestDataset) - 1)
    return MadcatTestDataset[pick]
def get_random_iam_test_sample(lh=30):
    """Return one randomly chosen sample from the IAM test split.

    The dataset is constructed lazily on first call and cached in the
    module-level ``IamTestDataset`` global; ``lh`` (target line height in
    pixels) only has an effect on that first call.
    """
    global IamTestDataset

    if IamTestDataset is None:
        # NOTE: dataset and LM-units paths are hardcoded for now(!)
        preprocess = imagetransforms.Compose([
            imagetransforms.Scale(new_h=lh),
            imagetransforms.InvertBlackWhite(),
            imagetransforms.Pad(0, 0),
            imagetransforms.ToTensor(),
        ])
        alphabet = EnglishAlphabet(
            lm_units_path=
            "/nfs/isicvlnas01/users/jmathai//experiments/lm_grid_search/iam-grid-data/IAM-LM-4-kndiscount-interpolate-0.9/IAM-LM/units.txt"
        )
        IamTestDataset = IAMDataset(
            "/nfs/isicvlnas01/users/srawls/ocr-dev/data/iam/", "test",
            alphabet, lh, preprocess)

    # Uniformly sample an index over the cached dataset.
    pick = random.randint(0, len(IamTestDataset) - 1)
    return IamTestDataset[pick]
import sys
import cv2

# Script-level setup: load the IAM baseline model and attach its LM.
iam = True
if iam:
    model_path = "/nas/home/srawls/ocr/experiments/iam-baseline-best_model.pth"
    lm_path = "/nfs/isicvlnas01/users/jmathai/experiments/iam_lm_augment_more_data/IAM-LM/"
    line_height = 120

    # Preprocessing pipeline applied to each line image before inference.
    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # Language-model resources (WFST decoding graph plus symbol tables).
    lm_units = os.path.join(lm_path, 'units.txt')
    lm_words = os.path.join(lm_path, 'words.txt')
    lm_wfst = os.path.join(lm_path, 'TLG.fst')

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    model = CnnOcrModel.FromSavedWeights(model_path)
    model.eval()

    # BUG FIX: previously called init_lm(lm_wfst, lm_words, acoustic_weight=0.8),
    # leaving lm_units computed but unused. Other call sites in this codebase use
    # init_lm(wfst, words, units, acoustic_weight=...), so pass lm_units through.
    model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)
def main():
    """Train the CNN-LSTM OCR model with CTC loss.

    Reads hyper-parameters from get_args(), snapshots the model every
    `snapshot_num_iterations` iterations, tracks best validation WER, and
    reduces the learning rate on plateau (with optional early exit).
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()

    # Two snapshot files: rolling "current" snapshot and best-so-far model.
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    # Preprocessing applied to every line image (train and validation).
    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # cudnn autotune is explicitly disabled here (variable-width inputs
    # would otherwise trigger repeated re-benchmarking).
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    # Batches are grouped/sorted by width so padding waste is minimized.
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)
    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    # Either resume from a snapshot or build a fresh model.
    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)
    # NOTE(review): this project's ReduceLROnPlateau appears to be a custom
    # variant — scheduler.step() is used as a boolean and it exposes
    # .finished/.factor/.min_lr, unlike torch's scheduler. Confirm its source.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)

    # Bookkeeping arrays for post-hoc plotting of training curves.
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []
    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()
        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size
            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Do something with loss, running average, plot to some backend server, etc

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # for bookkeeping only (mirrors the scheduler's LR update)
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                # Always persist the rolling snapshot with enough metadata
                # to rebuild the model (hyper-params, LR, metrics).
                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    # Restart from the best weights whenever the LR drops.
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
def main():
    """Decode the test set with a saved OCR model, optionally with an LM.

    Writes hyp-chars.txt (LM-free decode) and, when --lm-path is given,
    hyp-lm-chars.txt (WFST LM decode) to args.outdir, one line per
    utterance in the form "<hyp> (<utt-id>)".
    """
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = [
        imagetransforms.Scale(new_h=model.input_line_height)
    ]

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages
    if model.rtl:
        line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())
    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        # WFST decoding graph plus symbol tables.
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')

    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    if have_lm:
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    hyp_output = []
    hyp_lm_output = []

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # No need for backprop during validation test
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            sys.stdout.write(".")
            sys.stdout.flush()

            # BUG FIX: `async` became a reserved keyword in Python 3.7;
            # the PyTorch argument was renamed to `non_blocking`.
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Do LM-free decoding
            hyp_transcriptions = model.decode_without_lm(
                model_output, model_output_actual_lengths, uxxxx=True)

            # Optionally, do LM decoding
            if have_lm:
                hyp_transcriptions_lm = model.decode_with_lm(
                    model_output, model_output_actual_lengths, uxxxx=True)

            for i in range(len(hyp_transcriptions)):
                hyp_output.append(
                    (metadata['utt-ids'][i], hyp_transcriptions[i]))
                if have_lm:
                    hyp_lm_output.append(
                        (metadata['utt-ids'][i], hyp_transcriptions_lm[i]))

    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")
    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)
    if have_lm:
        print("\t%s" % hyp_lm_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))

    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))
def main():
    """Decode the test set, offloading CPU decoding to a worker process.

    The GPU forward pass runs in this process; model outputs are pushed to
    a multiprocessing queue consumed by decode_thread, which writes results
    to args.outdir. A None sentinel on the queue signals completion.
    """
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = []

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    line_img_transforms.append(
        imagetransforms.Scale(new_h=model.input_line_height))

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages
    # if model.rtl:
    #     line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())
    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # max_allowed_width is effectively unbounded here (1e5 px).
    test_dataset = OcrDataset(args.datadir,
                              "test",
                              line_img_transforms,
                              max_allowed_width=1e5)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # Setup separate process + queue for handling CPU-portion of decoding
    input_queue = multiprocessing.Queue()
    decoding_p = multiprocessing.Process(target=decode_thread,
                                         args=(input_queue, args.outdir,
                                               model.alphabet, args.lm_path))
    decoding_p.start()

    # No need for backprop during validation test
    start_time = datetime.datetime.now()
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            # BUG FIX: `async` became a reserved keyword in Python 3.7;
            # the PyTorch argument was renamed to `non_blocking`.
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Put model output on the queue for background process to decode
            # (moved to CPU so the tensor can cross the process boundary).
            input_queue.put(
                (model_output.cpu(), model_output_actual_lengths, metadata))

    # Now we just need to wait for decode thread to finish
    input_queue.put(None)
    input_queue.close()
    decoding_p.join()
    end_time = datetime.datetime.now()

    print("Decoding took %f seconds" % (end_time - start_time).total_seconds())
def main():
    """Decode the test set and also dump reference transcriptions.

    Writes hyp-chars.txt, ref-chars.txt and (with --lm-path) hyp-lm-chars.txt
    to args.outdir, one "<text> (<utt-id>)" line per utterance.
    """
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=model.input_line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        # WFST decoding graph plus symbol tables.
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')

    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    if have_lm:
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    hyp_output = []
    hyp_lm_output = []
    ref_output = []

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # MODERNIZED: the original wrapped tensors in the long-deprecated
    # Variable(..., volatile=True) API and used `.cuda(async=True)`
    # (`async` is a reserved keyword since Python 3.7). Use torch.no_grad()
    # and non_blocking=True instead, matching the sibling decode scripts.
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            sys.stdout.write(".")
            sys.stdout.flush()

            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Do LM-free decoding
            hyp_transcriptions = model.decode_without_lm(
                model_output, model_output_actual_lengths, uxxxx=True)

            # Optionally, do LM decoding
            if have_lm:
                hyp_transcriptions_lm = model.decode_with_lm(
                    model_output, model_output_actual_lengths, uxxxx=True)

            # Targets for the whole batch are concatenated into one flat
            # vector; target_widths gives each utterance's length.
            cur_target_offset = 0
            target_np = target.numpy()
            for i in range(len(hyp_transcriptions)):
                width = int(target_widths[i])
                ref_transcription = form_target_transcription(
                    target_np[cur_target_offset:(cur_target_offset + width)],
                    model.alphabet)
                cur_target_offset += width

                hyp_output.append(
                    (metadata['utt-ids'][i], hyp_transcriptions[i]))
                if have_lm:
                    hyp_lm_output.append(
                        (metadata['utt-ids'][i], hyp_transcriptions_lm[i]))
                ref_output.append((metadata['utt-ids'][i], ref_transcription))

    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")
    ref_out_file = os.path.join(args.outdir, "ref-chars.txt")
    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)
    if have_lm:
        print("\t%s" % hyp_lm_out_file)
    print("\t%s" % ref_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))

    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))

    with open(ref_out_file, 'w') as fh:
        for uttid, ref in ref_output:
            fh.write("%s (%s)\n" % (ref, uttid))
def main():
    """Train the CNN-LSTM OCR model with heavy data augmentation.

    Like the baseline trainer, but adds: imgaug/synthetic augmentations,
    multi-dataset training (OcrDatasetUnion), alphabet override when
    resuming from a snapshot, and periodic hypothesis write-out on a
    held-out test set once validation CER drops below 0.5.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()

    # Two snapshot files: rolling "current" snapshot and best-so-far model.
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert color for the augmentations to work (for now)
    # Then later convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:
        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3,
                                      imagetransforms.RotateRandom(-2, 2)))

        # Choose one of methods to blur/pixel-ify image (or don't and choose identity)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        # Random contrast normalization (imgaug), applied half the time.
        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale: convert to grayscale and add back into color-image with random alpha
        #   Emboss: Emboss image with random strength
        #   Invert: Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

        # Randomly try to crop close to top/bottom and left/right of lines
        # For now we are just guessing (up to 5% of ends and up to 10% of
        # tops/bottoms chopped off)
        if args.tight_crop:
            # To make sure padding is reasonably consistent, we first resize image to target line height
            # Then add padding to this version of image
            # Below it will get resized again to target line height
            line_img_transforms.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))
        else:
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

        #line_img_transforms.append(imagetransforms.Randomize(0.2,
        #    imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
        #    ))

    # Make sure to do resize after degrade step above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only do for grayscale
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        # Occasionally draw a thin black stripe across the line image.
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())
    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # cudnn autotune is explicitly disabled here (variable-width inputs
    # would otherwise trigger repeated re-benchmarking).
    torch.backends.cudnn.benchmark = False

    # Single data dir -> plain dataset; multiple dirs -> union dataset.
    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)
        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)
        # Test-set images get no augmentation or color inversion.
        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        # Datasets must use the resumed model's alphabet so label indices
        # stay consistent with the saved weights.
        if len(args.datadir) == 1:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
        else:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
            # Union datasets also carry per-member datasets to update.
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet
    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Setting dataloader after we have a chance to (maybe!) over-ride the
    # dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    # Optionally cap validation to a fixed random subset for speed.
    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)
    # NOTE(review): this project's ReduceLROnPlateau appears to be a custom
    # variant — scheduler.step() is used as a boolean and it exposes
    # .finished/.factor/.min_lr, unlike torch's scheduler. Confirm its source.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)

    # Bookkeeping arrays for post-hoc plotting of training curves.
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []
    epoch_size = len(train_dataloader)
    do_test_write = False

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()
        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size
            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Only turn on test-on-testset when cer is starting to get non-random
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                if val_cer < 0.5:
                    do_test_write = True

                # Periodically dump hypotheses on both out-of-domain (test)
                # and in-domain (validation) data for offline inspection.
                if args.test_datadir is not None and (
                        iteration %
                        snapshot_every_n_iterations == 0) and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir,
                        "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # for bookkeeping only (mirrors the scheduler's LR update)
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                # Always persist the rolling snapshot with enough metadata
                # to rebuild the model (hyper-params, LR, metrics, rtl flag).
                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    # Restart from the best weights whenever the LR drops.
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")