Example #1
def load_model(fpath, fdim=128, hu=256, gpu=False):

    mweights = torch.load(fpath)

    # For now, hardcode these params;
    # later they should be read from the serialized model file
    line_height = 30
    h_pad = 0
    v_pad = 0

    alphabet = EnglishAlphabet()
    #alphabet = ArabicAlphabet()

    model = CnnOcrModel(num_in_channels=1,
                        input_line_height=line_height + 2 * v_pad,
                        lstm_input_dim=fdim,
                        num_lstm_layers=3,
                        num_lstm_hidden_units=hu,
                        p_lstm_dropout=0.5,
                        alphabet=alphabet,
                        multigpu=True,
                        verbose=False,
                        gpu=gpu)

    model.load_state_dict(mweights['state_dict'])
    model.eval()

    return model
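A minimal usage sketch for the loader above; the checkpoint path is hypothetical, and fdim/hu must match the values the weights were trained with:

# Hypothetical checkpoint path, for illustration only
model = load_model("snapshots/ocr-best_model.pth", fdim=128, hu=256, gpu=True)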
Example #2
line_img_transforms = imagetransforms.Compose([
    imagetransforms.Scale(new_h=line_height),
    imagetransforms.InvertBlackWhite(),
    imagetransforms.ToTensor(),
])

lm_units = os.path.join(lm_path, 'units.txt')
lm_words = os.path.join(lm_path, 'words.txt')
lm_wfst = os.path.join(lm_path, 'TLG.fst')

# Set seed for consistency
torch.manual_seed(7)
torch.cuda.manual_seed_all(7)

model = CnnOcrModel.FromSavedWeights(model_path)
model.eval()

model.init_lm(lm_wfst, lm_words, acoustic_weight=0.8)

print("Starting on data")

#img = '/nas/home/srawls/aida-ocr/rex-text-detection/text-regions/004291text-region-Idx6-Pr0.781796395778656.jpg'
#for img_file in [img]:
base_dir = '/nas/home/srawls/aida-ocr/rex-text-detection/text-regions/'

bad_imgs = ['003991text-region-Idx57-Pr0.7098091840744019.jpg']

hyp_out_file = os.path.join(os.environ["TMPDIR"], "hyp-chars.txt")
hyp_lm_out_file = os.path.join(os.environ["TMPDIR"], "hyp-lm-chars.txt")
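Note that this fragment assumes TMPDIR is set in the environment and will raise a KeyError otherwise; a slightly more defensive variant (the /tmp fallback is an assumption, not part of the original):

# Fall back to /tmp when TMPDIR is unset (assumption)
tmp_dir = os.environ.get("TMPDIR", "/tmp")
hyp_out_file = os.path.join(tmp_dir, "hyp-chars.txt")
hyp_lm_out_file = os.path.join(tmp_dir, "hyp-lm-chars.txt")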
Example #3
if args.lm_path != "":
    lm_units = os.path.join(args.lm_path, 'units.txt')
    lm_words = os.path.join(args.lm_path, 'words.txt')
    lm_wfst = os.path.join(args.lm_path, 'TLG.fst')
else:
    lm_units = ""

# TODO: need a language arg (or, more likely, read it from the model)
#alphabet = EnglishAlphabet(lm_units_path=lm_units)

if lm_units == "":
    lm_units = "/nas/home/srawls/madcat-units.txt"
alphabet = ArabicAlphabet(lm_units_path=lm_units)

if args.model == "CnnOcrModel":
    model = CnnOcrModel.FromSavedWeights(model_path, alphabet=alphabet)
else:
    raise TypeError("model not recognized.")
model.eval()

if args.lm_path != "":
    logger.info("About to init LM with acoustic_weight:%s" % args.acoustic_weight)
     model.init_lm(lm_wfst, lm_words, acoustic_weight=args.acoustic_weight)
     logger.info("Done init'ing LM")


hyp_output = []
hyp_lm_output = []

iteration = 0
Example #4
import torch
from models.cnnlstm import CnnOcrModel

from warpctc_pytorch import CTCLoss
import time
import numpy as np

batchsize = 64
height = 30

print("batch size = %d" % batchsize)
print("height = %d" % height)

model = CnnOcrModel(num_in_channels=1,
                    input_line_height=height,
                    lstm_input_dim=128,
                    num_lstm_layers=3,
                    num_lstm_hidden_units=512,
                    p_lstm_dropout=0.5,
                    alphabet=range(120),
                    multigpu=False)

model.train()

criterion = CTCLoss().cuda()

# Setup fake constant target: five labels per batch element, every label = symbol 32
target = torch.autograd.Variable(torch.IntTensor(batchsize * 5).fill_(32))
target_widths = torch.autograd.Variable(torch.IntTensor([5] * batchsize))

print("torch.backends.cudnn.enabled = %s" % torch.backends.cudnn.enabled)
Example #5
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # cudnn benchmark mode only helps when input sizes are fixed;
    # line widths vary between batches, so leave it disabled
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, (iteration - 1) % epoch_size + 1, epoch_size,
                   epoch, loss, pretty_print_timespan(elapsed_time)))

            # TODO: track a running average of the loss, plot to a monitoring backend, etc.

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
Example #6
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = [
        imagetransforms.Scale(new_h=model.input_line_height)
    ]

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages
    if model.rtl:
        line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')

    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    if have_lm:
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    hyp_output = []
    hyp_lm_output = []

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # No need for backprop during validation test
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            sys.stdout.write(".")
            sys.stdout.flush()

            # Move inputs to the GPU
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Do LM-free decoding
            hyp_transcriptions = model.decode_without_lm(
                model_output, model_output_actual_lengths, uxxxx=True)

            # Optionally, do LM decoding
            if have_lm:
                hyp_transcriptions_lm = model.decode_with_lm(
                    model_output, model_output_actual_lengths, uxxxx=True)

            for i in range(len(hyp_transcriptions)):
                hyp_output.append(
                    (metadata['utt-ids'][i], hyp_transcriptions[i]))

                if have_lm:
                    hyp_lm_output.append(
                        (metadata['utt-ids'][i], hyp_transcriptions_lm[i]))

    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")

    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)

    if have_lm:
        print("\t%s" % hyp_lm_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))

    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))
Example #7
import GPUtil

import loggy
logger = loggy.setup_custom_logger('root', "train_cnn_lstm.py")

#alphabet = EnglishAlphabet("/nfs/isicvlnas01/users/jmathai/Experiments/iam_lm_augment_more_data/IAM-LM/units.txt")
alphabet = EnglishAlphabet(
    "/nfs/isicvlnas01/users/jmathai//experiments/lm_grid_search/iam-grid-data/IAM-LM-4-kndiscount-interpolate-0.9/IAM-LM/units.txt"
)
LINEH = 240

model = CnnOcrModel(verbose=True,
                    num_in_channels=1,
                    input_line_height=LINEH,
                    lstm_input_dim=128,
                    num_lstm_layers=3,
                    num_lstm_hidden_units=512,
                    p_lstm_dropout=0.5,
                    alphabet=alphabet,
                    multigpu=True)

print("")
print("")
print("")
torch.cuda.empty_cache()
GPUtil.showUtilization()

torch.backends.cudnn.benchmark = True

# Setup fake constant target
batchsize = 4
Example #8
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = []

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    line_img_transforms.append(
        imagetransforms.Scale(new_h=model.input_line_height))

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages
    #if model.rtl:
    #    line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    test_dataset = OcrDataset(args.datadir,
                              "test",
                              line_img_transforms,
                              max_allowed_width=1e5)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # Set up a separate process + queue to handle the CPU portion of decoding
    input_queue = multiprocessing.Queue()
    decoding_p = multiprocessing.Process(target=decode_thread,
                                         args=(input_queue, args.outdir,
                                               model.alphabet, args.lm_path))
    decoding_p.start()

    # No need for backprop during validation test
    start_time = datetime.datetime.now()
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            # Move inputs to the GPU
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Put model output on the queue for background process to decode
            input_queue.put(
                (model_output.cpu(), model_output_actual_lengths, metadata))

    # Now we just need to wait for decode thread to finish
    input_queue.put(None)
    input_queue.close()
    decoding_p.join()

    end_time = datetime.datetime.now()

    print("Decoding took %f seconds" % (end_time - start_time).total_seconds())
Example #9
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=model.input_line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])


    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')


    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)


    if have_lm:
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)


    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_data_threads,
                                                  sampler=GroupedSampler(test_dataset, rand=False),
                                                  collate_fn=SortByWidthCollater,
                                                  pin_memory=True,
                                                  drop_last=False)


    hyp_output = []
    hyp_lm_output = []
    ref_output = []


    print("About to process test set. Total # iterations is %d." % len(test_dataloader))

    # No need for backprop during the test pass (replaces the deprecated volatile=True Variable API)
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths, metadata) in enumerate(test_dataloader):
            sys.stdout.write(".")
            sys.stdout.flush()

            # Move inputs to the GPU
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(input_tensor, input_widths)

            # Do LM-free decoding
            hyp_transcriptions = model.decode_without_lm(model_output, model_output_actual_lengths, uxxxx=True)

            # Optionally, do LM decoding
            if have_lm:
                hyp_transcriptions_lm = model.decode_with_lm(model_output, model_output_actual_lengths, uxxxx=True)

            cur_target_offset = 0
            target_np = target.numpy()

            for i in range(len(hyp_transcriptions)):
                w = int(target_widths[i])
                ref_transcription = form_target_transcription(
                    target_np[cur_target_offset:(cur_target_offset + w)], model.alphabet)
                cur_target_offset += w

                hyp_output.append((metadata['utt-ids'][i], hyp_transcriptions[i]))

                if have_lm:
                    hyp_lm_output.append((metadata['utt-ids'][i], hyp_transcriptions_lm[i]))

                ref_output.append((metadata['utt-ids'][i], ref_transcription))


    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")
    ref_out_file = os.path.join(args.outdir, "ref-chars.txt")

    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)

    if have_lm:
        print("\t%s" % hyp_lm_out_file)

    print("\t%s" % ref_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))


    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))

    with open(ref_out_file, 'w') as fh:
        for uttid, ref in ref_output:
            fh.write("%s (%s)\n" % (ref, uttid))
Example #10
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert to color so the augmentations work (for now);
    # later, convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:

        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2,
                                                                        2)))

        # Choose one of several methods to blur/pixelate the image (or pick the identity and leave it unchanged)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale:  convert to grayscale and add back into color-image with random alpha
        #   Emboss:  Emboss image with random strength
        #   Invert:  Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

        # Randomly crop close to the top/bottom and left/right of lines.
        # For now the amounts are guesses (up to 5% off the ends, up to 10% off tops/bottoms)

        if args.tight_crop:
            # To keep padding reasonably consistent, we first resize the image to the target line height,
            # then add padding to this version of the image;
            # below it will get resized again to the target line height
            line_img_transforms.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))

        else:
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

        #line_img_transforms.append(imagetransforms.Randomize(0.2,
        #                                                     imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
        #                                                     ))

    # Make sure to do resize after degrade step above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only do for grayscale
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # cudnn benchmark mode only helps when input sizes are fixed;
    # line widths vary between batches, so leave it disabled
    torch.backends.cudnn.benchmark = False

    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)

        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)

        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        train_dataset.alphabet = model.alphabet
        validation_dataset.alphabet = model.alphabet
        if len(args.datadir) > 1:
            # OcrDatasetUnion also carries per-source sub-datasets
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet

    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set up the dataloader after we have had a chance to (maybe!) override the dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset, rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    do_test_write = False
    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, (iteration - 1) % epoch_size + 1, epoch_size,
                   epoch, loss, pretty_print_timespan(elapsed_time)))

            # Only turn on test-set writeout once the CER starts to look better than random
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                if val_cer < 0.5:
                    do_test_write = True

                if args.test_datadir is not None and (
                        iteration % snapshot_every_n_iterations
                        == 0) and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir,
                        "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")