Example #1
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # cudnn benchmark autotuning is disabled here; it only speeds things up when input sizes are fixed across batches
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # Main training loop over the OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Optionally post-process the loss here: keep a running average, plot it to a dashboard/backend server, etc.
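            # (Illustrative sketch, not part of the original script: one simple
            #  option is an exponential moving average of the per-sample loss,
            #  e.g.  smoothed_loss = 0.98 * smoothed_loss + 0.02 * loss,
            #  with smoothed_loss initialised to the first loss value.)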

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Reduce learning rate on plateau
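                # (Note: this ReduceLROnPlateau appears to be a custom variant;
                #  its step() presumably returns True when it lowers the LR and
                #  it sets .finished once min_lr is reached. The stock
                #  torch.optim.lr_scheduler version's step() returns None.)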
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # Record points for plotting LR changes against WER, CER and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
Example #2
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert to color so the augmentations work (for now),
    # then later convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:

        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2,
                                                                        2)))

        # Choose one of several methods to blur/pixelate the image (or pick the identity and leave it unchanged)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale:  convert to grayscale and blend back into the color image with a random alpha
        #   Emboss:  emboss the image with random strength
        #   Invert:  invert the image colors, either per-channel (always) or whole-image (10% of the time)
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

        # Randomly crop close to the top/bottom and left/right of lines.
        # For now these amounts are a guess (up to 5% off the ends and up to 10% off the top/bottom).

        if args.tight_crop:
            # To keep the padding reasonably consistent, we first resize the image to the target line height,
            # then add padding to that version of the image.
            # Below it will be resized again to the target line height.
            line_img_transforms.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))

        else:
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

        #line_img_transforms.append(imagetransforms.Randomize(0.2,
        #                                                     imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
        #                                                     ))

    # Make sure the resize happens after the degradation steps above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only invert black/white for single-channel (grayscale) input
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)
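    # (Assumptions about the helper transforms composed above, not verified
    #  against the imagetransforms module: Randomize(p, t) presumably applies
    #  transform t with probability p, and PickOne([...]) presumably applies
    #  exactly one transform drawn at random from the list.)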

    # cudnn benchmark autotuning is disabled here; it only speeds things up when input sizes are fixed across batches
    torch.backends.cudnn.benchmark = False

    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)

        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)

        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        if len(args.datadir) == 1:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
        else:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet

    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set up the dataloaders after we have had a chance to (maybe!) override the dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset, rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    do_test_write = False
    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # Main training loop over the OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Only start writing out test-set hypotheses once the CER is no longer essentially random (see do_test_write below)
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                if val_cer < 0.5:
                    do_test_write = True

                if args.test_datadir is not None and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir,
                        "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # Record points for plotting LR changes against WER, CER and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")