Example #1
0
def train_model(args):
    """Train a ``model2.BERT_CB`` answer selector and return it.

    Each context's loss mixes the bandit loss returned by
    ``ContextualBandit.train`` with a binary cross-entropy loss against the
    gold labels: ``gamma * bandit_loss + (1 - gamma) * bce_loss``. The model
    is evaluated with ``evaluate.ext_model_eval`` on ``args.eval_data`` every
    ``n_val`` batches. The whole model object is saved whenever the first
    evaluation value improves on the best seen so far.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``data_dir`` (its first path component must be ``wiki_qa`` or
            ``trec_qa``), ``model_file``, ``load_ext``, ``lr``, ``lr_2``,
            ``lr_sch`` (1: plateau scheduler stepped once per epoch,
            2: CyclicLR stepped once per batch), ``gamma``, ``beta``
            (Adam beta1), ``train_batch``, ``batch_size``, ``epochs_ext``,
            ``clip_grad`` and the model options ``input_dim``, ``dropout``,
            ``highway`` and ``nn_layers``.

    Returns:
        The model object after the last epoch. This is not necessarily the
        best-scoring model written to disk.

    Raises:
        ValueError: If the dataset named by ``args.data_dir`` is unsupported.
    """
    print(args)
    print("generating config")
    config = Config(
        input_dim=args.input_dim,
        dropout=args.dropout,
        highway=args.highway,
        nn_layers=args.nn_layers,
    )
    dataset_name = args.data_dir.split('/')[0]
    lr_tag = str(args.lr) + "-" + str(args.lr_sch)

    def _run_name(prefix, lr_label):
        # Dotted hyper-parameter string shared by the checkpoint file name
        # and the log file name; only the prefix and the lr label differ.
        return ".".join((
            prefix, str(args.rl_baseline_method), args.sampling_method,
            "gamma", str(args.gamma), "beta", str(args.beta),
            "batch", str(args.train_batch),
            lr_label, lr_tag,
            "bsz", str(args.batch_size),
            "data", dataset_name, args.eval_data,
            "input_dim", str(config.input_dim),
            "max_num", str(args.max_num_of_ans),
            "reward", str(args.reward_type),
            "dropout", str(args.dropout) + "-" + str(args.clip_grad),
            "highway", str(args.highway),
            "nn-" + str(args.nn_layers), 'ans'))

    model_name = _run_name(args.model_file, "learning_rate")
    log_name = _run_name("log_bert/model", "lr")

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    # Hard-coded training-set size per dataset (presumably the number of
    # training questions -- confirm). It sets the CyclicLR period and the
    # evaluation interval below.
    if dataset_name == "wiki_qa":
        num_data = 873
    elif dataset_name == "trec_qa":
        num_data = 1229
    else:
        raise ValueError(
            "unsupported dataset %r (expected 'wiki_qa' or 'trec_qa')" %
            dataset_name)
    # init statistics
    reward_list = []
    loss_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              sample_method=args.sampling_method)

    print("Loaded the Bandit")

    bert_cb = model2.BERT_CB(config)

    print("Loaded the model")

    bert_cb.cuda()
    vocab = "vocab"

    if args.load_ext:
        # Resume: replace the fresh model with the pickled one and use its
        # evaluation score as the baseline that later saves must beat.
        model_name = args.model_file
        print("loading existing model%s" % model_name)
        bert_cb = torch.load(model_name,
                             map_location=lambda storage, loc: storage)
        bert_cb.cuda()
        model_save_name = model_name
        log_name = "/".join(("log_bert", model_name.split("/")[1]))
        print("finish loading and evaluate model %s" % model_name)
        best_eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                   args.eval_data)[0]
    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')
    # Loss and Optimizer: only trainable parameters are optimized.
    optimizer_ans = torch.optim.Adam(
        [param for param in bert_cb.parameters() if param.requires_grad],
        lr=args.lr,
        betas=(args.beta, 0.999),
        weight_decay=1e-6)
    if args.lr_sch == 1:
        scheduler = ReduceLROnPlateau(optimizer_ans,
                                      'max',
                                      verbose=1,
                                      factor=0.9,
                                      patience=3,
                                      cooldown=3,
                                      min_lr=9e-5,
                                      epsilon=1e-6)
        if best_eval_reward:
            scheduler.step(best_eval_reward, 0)
            print("init_scheduler")
    elif args.lr_sch == 2:
        # Clamp to >= 1: a zero half-period would break CyclicLR when
        # train_batch exceeds num_data.
        half_cycle = max(1, 3 * int(num_data / args.train_batch))
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer_ans,
            args.lr,
            args.lr_2,
            step_size_up=half_cycle,
            step_size_down=half_cycle,
            mode='exp_range',
            gamma=0.98,
            cycle_momentum=False)
    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma
    # Evaluation interval in batches. Clamp it to >= 1 so the modulo test in
    # the loop cannot divide by zero when train_batch is large.
    if num_data < 2000:
        n_val = max(1, int(num_data / (5 * args.train_batch)))
    else:
        n_val = max(1, int(num_data / (7 * args.train_batch)))
    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)  #-1
            step_in_epoch = 0
            for dataset in train_iter:
                for contexts in tqdm(
                        BatchDataLoader(dataset,
                                        batch_size=args.train_batch,
                                        shuffle=True)):
                    try:
                        bert_cb.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        prt = False  # stays defined even for an empty batch
                        for context in contexts:
                            pre_processed, a_len, sorted_id = model2.bert_preprocess(
                                context.answers)
                            q_a = torch.autograd.Variable(
                                pre_processed.type(torch.float))
                            a_len = torch.autograd.Variable(a_len)

                            outputs = bert_cb(q_a, a_len)
                            # Reorder labels with the permutation returned by
                            # bert_preprocess so they line up with outputs.
                            context.labels = np.array(
                                context.labels)[sorted_id]

                            # Print diagnostics for roughly 1 in 100 contexts.
                            prt = bool(args.prt_inf and
                                       np.random.randint(0, 100) == 0)

                            loss_t, reward_t = bandit.train(
                                outputs,
                                context,
                                max_num_of_ans=args.max_num_of_ans,
                                reward_type=args.reward_type,
                                prt=prt)

                            # Supervised target: 1.0 for every answer whose
                            # gold label is positive, else 0.0.
                            true_labels = np.zeros(len(context.labels))
                            gold_labels = np.array(context.labels)
                            true_labels[gold_labels > 0] = 1.0
                            ml_loss = F.binary_cross_entropy(
                                outputs.view(-1),
                                torch.tensor(true_labels).type(
                                    torch.float).cuda())

                            loss_e = ((gamma * loss_t) +
                                      ((1 - gamma) * ml_loss))
                            # Gradients accumulate over every context in the
                            # batch; the optimizer steps once per batch below.
                            loss_e.backward()
                            loss += loss_e.item()
                            reward += reward_t
                        # NOTE(review): divides by the nominal batch size, so a
                        # smaller final batch under-reports its mean.
                        loss = loss / args.train_batch
                        reward = reward / args.train_batch
                        if prt:
                            print('Probabilities: ',
                                  outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)

                        reward_list.append(reward)
                        loss_list.append(loss)

                        if args.clip_grad:
                            torch.nn.utils.clip_grad_norm_(
                                bert_cb.parameters(),
                                args.clip_grad)  # gradient clipping
                        optimizer_ans.step()
                        optimizer_ans.zero_grad()
                        if args.lr_sch == 2:
                            scheduler.step()
                        logging.info('Epoch %d Step %d Reward %.4f Loss %.4f' %
                                     (epoch, step_in_epoch, reward, loss))
                    except Exception as e:
                        # Best effort: report the failure and move on to the
                        # next batch.
                        print(e)
                        traceback.print_exc()
                        # Discard gradients the failed batch already
                        # accumulated so they do not leak into the next step.
                        optimizer_ans.zero_grad()

                    if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                        logging.info('Epoch ' + str(epoch) + ' Step ' +
                                     str(step_in_epoch) + ' reward: ' +
                                     str(np.mean(reward_list)) + ' loss: ' +
                                     str(np.mean(loss_list)))
                        reward_list = []
                        loss_list = []

                    if (step_in_epoch) % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        bert_cb.eval()
                        eval_reward = evaluate.ext_model_eval(
                            bert_cb, vocab, args, args.eval_data)

                        if eval_reward[0] > best_eval_reward:
                            best_eval_reward = eval_reward[0]
                            print(
                                "saving model %s with eval_reward:" %
                                model_save_name, eval_reward)
                            logging.debug("saving model" +
                                          str(model_save_name) +
                                          "with eval_reward:" +
                                          str(eval_reward))
                            torch.save(bert_cb, model_name)
                        print('epoch ' + str(epoch) +
                              ' reward in validation: ' + str(eval_reward))
                        logging.debug('epoch ' + str(epoch) +
                                      ' reward in validation: ' +
                                      str(eval_reward))
                        logging.debug('time elapsed:' +
                                      str(time.time() - start_time))
            if args.lr_sch == 1:
                # The plateau scheduler is driven by the end-of-epoch score.
                bert_cb.eval()
                eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                      args.eval_data)
                scheduler.step(eval_reward[0], epoch)
    return bert_cb
Example #2
0
def main():
    """Train a classifier on ``dataset.CommandsDataset`` spectrogram inputs.

    The model, the train/validate loaders and the optimizer are built from
    ``parser`` arguments. The function can resume from ``args.resume``. When
    not resuming and ``args.ft_epochs > 0``, it first fine-tunes only the
    classifier head. It then trains for the remaining epochs. After each
    epoch it:

    * saves a recovery checkpoint;
    * validates;
    * steps the plateau scheduler (when no decay epochs are set);
    * appends a row to ``summary.csv`` in the output directory;
    * saves a checkpoint tagged with the validation loss.

    Raises:
        ValueError: If ``args.opt`` names an unsupported optimizer.
    """
    import sys  # sys.exit is always available; the ``exit`` builtin needs ``site``

    args = parser.parse_args()

    output_base = args.output if args.output else './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        args.gp,
        'f' + str(args.fold)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    train_input_root = os.path.join(args.data)
    batch_size = args.batch_size
    num_epochs = args.epochs
    wav_size = (16000,)
    num_classes = len(dataset.get_labels())

    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        in_chs=1,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    dataset_train = dataset.CommandsDataset(
        root=train_input_root,
        mode='train',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = dataset.CommandsDataset(
        root=train_input_root,
        mode='validate',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    # A single CrossEntropyLoss instance serves both training and validation.
    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    opt_params = list(model.parameters())
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        optimizer = optim.SGD(
            opt_params, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'nadam':
        optimizer = nadam.Nadam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(
            opt_params, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # An explicit raise is used instead of assert: assert is stripped
        # under -O, which would leave `optimizer` unbound.
        raise ValueError("Invalid optimizer: %s" % args.opt)
    del opt_params

    if not args.decay_epochs:
        print('No decay epoch set, using plateau scheduler.')
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=10)
    else:
        lr_scheduler = None

    # optionally resume from a checkpoint
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                if 'args' in checkpoint:
                    print(checkpoint['args'])
                # Strip the 'module.' prefix added by DataParallel so the
                # weights load into the bare model.
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    name = k[7:] if k.startswith('module') else k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'loss' in checkpoint:
                    train_loss_fn.load_state_dict(checkpoint['loss'])
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            sys.exit(1)

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Optional fine-tune of only the final classifier weights for specified number of epochs (or part of)
    if not args.resume and args.ft_epochs > 0.:
        if isinstance(model, torch.nn.DataParallel):
            classifier_params = model.module.get_classifier().parameters()
        else:
            classifier_params = model.get_classifier().parameters()
        if opt_name == 'adam':
            finetune_optimizer = optim.Adam(
                classifier_params,
                lr=args.ft_lr, weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                classifier_params,
                lr=args.ft_lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

        # A fractional ft_epochs is handled by limiting the batch count of
        # the last fine-tune epoch.
        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        finetune_final_batches = int(np.ceil((1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(0, finetune_epochs_int):
            if fepoch == finetune_epochs_int - 1 and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, train_loss_fn, args,
                output_dir=output_dir, batch_limit=batch_limit)

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch, initial_lr=args.lr,
                    decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # save a recovery in case validation blows up
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args,
                output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # save proper checkpoint with eval metric
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass
    if best_loss is None:
        # No epoch completed: interrupted early, or start_epoch >= num_epochs.
        print('*** No checkpoint saved; best loss unavailable.')
    else:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
Example #3
0
        # forward + backward + optimize
        outputs = net(inputs)

        #labels += 10
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        #n_image_total += labels.size()[0]
        # print statistics
        running_loss += loss.data[0]
        if n_image_total % 2000 == 1999:    # print every 2000 mini-batches
        #if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            is_best_changed, is_lr_decayed = scheduler.step(running_loss / 2000, n_image_total) # update lr if needed
            if is_lr_just_decayed and (not is_best_changed):
                shall_stop = True
                break
            is_lr_just_decayed = is_lr_decayed
            running_loss = 0.0

        n_image_total += 1
    if shall_stop:
        break

print('Finished Training')

dataiter = iter(testloader)
images, labels = dataiter.next()
Example #4
0
def main():
    """Train a ``CnnOcrModel`` with CTC loss and periodic validation.

    Every ``args.snapshot_num_iterations`` iterations the model is scored on
    the validation set, and the plateau scheduler is stepped with the
    validation WER. A snapshot is written to ``<snapshot_prefix>-cur_snapshot.pth``
    and copied to ``<snapshot_prefix>-best_model.pth`` whenever the WER
    improves. When the scheduler lowers the LR, the best weights so far are
    reloaded. The process exits with status 0 once the scheduler reports
    ``finished``.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # cuDNN autotuning (benchmark mode) is left OFF. It only pays off when
    # input shapes repeat, and line-image widths presumably vary between
    # batches (see SortByWidthCollater).
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        # batch_in_epoch runs 1..epoch_size. The old `iteration % epoch_size`
        # printed 0 for the last batch of each epoch.
        for batch_in_epoch, batch in enumerate(train_dataloader, 1):
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, batch_in_epoch, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Reduce learning rate on plateau. NOTE(review): this
                # scheduler's step() returns truthy when it lowered the LR and
                # exposes `.finished`, so it is not torch's class.
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    # Continue from the best weights seen so far at the new LR.
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    logger.info("Done.")
def main():
    """Train the OCR model (``CnnOcrModel``) configured by command-line args.

    Steps:
      1. Parse arguments with ``get_args()`` and validate flag combinations.
      2. Build the image transform pipelines.  The training pipeline applies
         the random augmentations enabled by ``args.daves_augment``,
         ``args.synth_input`` and ``args.stripe``.  The validation pipeline
         uses only the deterministic steps (colour conversion, resize to
         ``args.line_height``, optional grayscale conversion and black/white
         inversion), so validation metrics - and therefore best-model
         selection - are not perturbed by random augmentation.
      3. Build the train / validation (and optional test) datasets and data
         loaders.  Create a fresh model, or restore one from
         ``args.load_from_snapshot``; a restored model's alphabet overrides
         the alphabet learned by the datasets.
      4. Train for ``args.nepochs`` epochs.  Every
         ``args.snapshot_num_iterations`` iterations the model is evaluated
         on the validation set and a snapshot is saved to
         ``<snapshot_prefix>-cur_snapshot.pth``.  Whenever validation WER
         improves, that snapshot is copied to
         ``<snapshot_prefix>-best_model.pth``.

    When the plateau scheduler lowers the learning rate, the best snapshot's
    model weights are reloaded before training continues.  When the
    scheduler reports it is finished, the process exits with status 0.

    If ``args.test_datadir`` is set, hypotheses for the test and validation
    sets are written to ``args.test_outdir`` at every snapshot once the
    validation CER has dropped below 0.5.

    Exits with status 1 if ``args.test_datadir`` is given without
    ``args.test_outdir``.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()

    # Validate flag combinations before any expensive dataset loading.
    if args.test_datadir is not None and args.test_outdir is None:
        print(
            "Error, must specify both --test-datadir and --test-outdir together",
            file=sys.stderr)
        sys.exit(1)

    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    def deterministic_post_scale_steps():
        """Return fresh instances of the non-random steps run after resizing."""
        steps = []
        if args.cvtGray:
            steps.append(imagetransforms.ConvertGray())
        # Only do for grayscale
        if args.num_in_channels == 1:
            steps.append(imagetransforms.InvertBlackWhite())
        return steps

    # ---- Random augmentations applied before the final resize (train only) ----
    pre_scale_augs = []
    if args.daves_augment:
        pre_scale_augs.append(daves_augment.ImageAug())

    if args.synth_input:
        # Randomly rotate image from -2 degrees to +2 degrees
        pre_scale_augs.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2,
                                                                        2)))

        # Choose one of the methods to blur/pixel-ify the image (or pick
        # identity and leave it unchanged)
        pre_scale_augs.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        # The augmenters' bound ``augment_image`` methods are passed directly.
        # Lambdas defined inside main() cannot be pickled, which would break
        # DataLoader workers started with the "spawn" method.
        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        pre_scale_augs.append(
            imagetransforms.Randomize(0.5, aug_cn.augment_image))

        # With some probability, choose one of:
        #   Grayscale:  convert to grayscale and add back into color-image with random alpha
        #   Emboss:  Emboss image with random strength
        #   Invert:  Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        pre_scale_augs.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    aug_gray.augment_image,
                    aug_emboss.augment_image,
                    aug_invert.augment_image,
                    aug_invert2.augment_image,
                ])))

        if args.tight_crop:
            # To keep padding reasonably consistent, first resize the image
            # to the target line height, then pad it.  The image is resized
            # to the target line height again by the main pipeline below.
            pre_scale_augs.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))
        else:
            # Randomly crop close to the top/bottom and left/right of lines.
            # The fractions are rough guesses: up to 5% of the ends and up
            # to 10% of the tops/bottoms are chopped off.
            pre_scale_augs.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            pre_scale_augs.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

    # ---- Random augmentations applied after the resize (train only) ----
    post_scale_augs = []
    if args.stripe:
        post_scale_augs.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    # Always convert to colour first so the augmentations work (for now);
    # convert back to grayscale afterwards if requested.  The resize is done
    # after the degradation steps.
    line_img_transforms = imagetransforms.Compose(
        [imagetransforms.ConvertColor()] + pre_scale_augs +
        [imagetransforms.Scale(new_h=args.line_height)] +
        deterministic_post_scale_steps() + post_scale_augs +
        [imagetransforms.ToTensor()])

    # Validation uses the same deterministic steps, without random
    # augmentation.
    val_line_img_transforms = imagetransforms.Compose(
        [imagetransforms.ConvertColor(),
         imagetransforms.Scale(new_h=args.line_height)] +
        deterministic_post_scale_steps() + [imagetransforms.ToTensor()])

    # cuDNN autotuning is deliberately left disabled.
    # NOTE(review): presumably because line-image widths vary between
    # batches, which would make the autotuner re-benchmark for each new
    # input shape - confirm.
    torch.backends.cudnn.benchmark = False

    is_union = len(args.datadir) != 1
    if not is_union:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        val_line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             val_line_img_transforms)

    if args.test_datadir is not None:
        os.makedirs(args.test_outdir, exist_ok=True)

        # NOTE(review): unlike the train/validation pipelines, this applies
        # no colour conversion, grayscale conversion or black/white
        # inversion - confirm the test images need none of these.
        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        train_dataset.alphabet = model.alphabet
        validation_dataset.alphabet = model.alphabet
        if is_union:
            # A dataset union also holds its member datasets; override each
            # member's alphabet as well.
            for union in (train_dataset, validation_dataset):
                for ds in union.datasets:
                    ds.alphabet = model.alphabet
    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Data loaders are created only after the dataset alphabet may have been
    # overridden by a pre-trained model above.
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    if args.max_val_size > 0:
        # Presumably a fixed random subset of at most max_val_size items.
        val_sampler = GroupedSampler(validation_dataset,
                                     max_items=args.max_val_size,
                                     fixed_rand=True)
    else:
        val_sampler = GroupedSampler(validation_dataset, rand=False)
    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=val_sampler,
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)

    # step() is expected to return truthy when it lowered the LR;
    # .finished signals that training should stop.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)

    # Histories kept for plotting LR changes against WER, CER and loss.
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    # Becomes True the first time validation CER drops below 0.5 and stays
    # True afterwards.
    do_test_write = False
    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        for batch_in_epoch, batch in enumerate(train_dataloader, start=1):
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            # Normalise by batch size (assumes train() returns the loss
            # summed over the batch - confirm).
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, batch_in_epoch, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            if iteration % snapshot_every_n_iterations != 0:
                continue

            logger.info("Testing on validation set")
            val_loss, val_cer, val_wer = test_on_val(validation_dataloader,
                                                     model, ctc_loss)

            # Only start writing test-set hypotheses once CER is starting to
            # get non-random.
            if val_cer < 0.5:
                do_test_write = True

            if args.test_datadir is not None and do_test_write:
                out_hyp_outdomain_file = os.path.join(
                    args.test_outdir, "hyp-%07d.outdomain.utf8" % iteration)
                out_hyp_indomain_file = os.path.join(
                    args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                out_meta_file = os.path.join(args.test_outdir,
                                             "hyp-%07d.meta" % iteration)
                test_on_val_writeout(test_dataloader, model,
                                     out_hyp_outdomain_file)
                test_on_val_writeout(validation_dataloader, model,
                                     out_hyp_indomain_file)
                with open(out_meta_file, 'w') as fh_out:
                    fh_out.write("%d,%f,%f,%f\n" %
                                 (iteration, val_cer, val_wer, val_loss))

            # Reduce learning rate on plateau
            early_exit = False
            lowered_lr = False
            if scheduler.step(val_wer):
                lowered_lr = True
                lr_points.append(iteration / snapshot_every_n_iterations)
                early_exit = bool(scheduler.finished)
                # For bookkeeping only (stored in the snapshot as 'cur_lr').
                lr_alpha = max(lr_alpha * scheduler.factor, scheduler.min_lr)

            logger.info(
                "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                (val_loss, val_cer, val_wer))

            torch.save(
                {
                    'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_hyper_params': model.get_hyper_params(),
                    'rtl': args.rtl,
                    'cur_lr': lr_alpha,
                    'val_loss': val_loss,
                    'val_cer': val_cer,
                    'val_wer': val_wer,
                    'line_height': args.line_height
                }, snapshot_path)

            # Plotting lr_change on wer, cer and loss.
            wer_array.append(val_wer)
            cer_array.append(val_cer)
            iteration_points.append(iteration / snapshot_every_n_iterations)

            if val_wer < best_val_wer:
                logger.info(
                    "Best model so far, copying snapshot to best model file")
                best_val_wer = val_wer
                shutil.copyfile(snapshot_path, best_model_path)

            logger.info("Running WER: %s" % str(wer_array))
            logger.info("Done with validation, moving on.")

            if early_exit:
                logger.info("Early exit")
                sys.exit(0)

            if lowered_lr:
                # Only the model weights are restored; the optimizer keeps
                # its current state.
                logger.info(
                    "Switching to best model parameters before continuing with lower LR"
                )
                weights = torch.load(best_model_path)
                model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    logger.info("Done.")