def __init__(self, data_dir_list, split, transforms, alphabet=None):
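        """Union of several OcrDataset instances, one per entry in data_dir_list.

        LMDB environments are cached so a data_dir that appears more than once
        reuses the already-open environment; alphabets and size groups from all
        constituent datasets are then merged into a single view.
        """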
        logger.info("Loading OCR Dataset Union: [%s] split from [%s]." % (split, data_dir_list))

        self.datasets = []
        self.nentries = 0

        self.lmdb_cache = dict()
        for data_dir in data_dir_list:
            if data_dir in self.lmdb_cache:
                dataset = OcrDataset(data_dir, split, transforms, alphabet, preloaded_lmdb=self.lmdb_cache[data_dir])
            else:
                dataset = OcrDataset(data_dir, split, transforms, alphabet)
                self.lmdb_cache[data_dir] = dataset.lmdb_env
            self.datasets.append(dataset)
            self.nentries += len(dataset)

        # Different datasets might have different alphabets, so merge them into one unified alphabet
        self.merge_alphabets()

        # Merge size group stuff (ugh)
        self.size_group_keys = self.datasets[0].size_group_keys
        self.size_groups = dict()
        for cur_limit in self.size_group_keys:
            self.size_groups[cur_limit] = []


        accumulated_max_idx = 0
        for ds in self.datasets:
            # For now we only merge datasets that share the same set of size groups (need to change this requirement!)
            assert ds.size_group_keys == self.size_group_keys
            for cur_limit in self.size_group_keys:
                self.size_groups[cur_limit].extend([accumulated_max_idx + idx for idx in ds.size_groups[cur_limit]])
            accumulated_max_idx += ds.max_index
def evaluate(args, model, tokenizer, labels, mode, prefix=""):
    eval_dataset = OcrDataset(args.data_dir, tokenizer, labels, "test_v2")
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=args.eval_batch_size,
        collate_fn=None,
    )
    model.eval()

    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    if mode == "test_v2":
        dev_loss = 0
        dev_steps = 0
        logit_all, label_all, file_names_all = [], [], []
        logit_soft_all = []
        correct_files, error_files = [], []
        files = []
        entropy_all = []
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                inputs = {
                    "input_ids": batch['input_ids'].to(args.device),
                    "attention_mask": batch['attention_mask'].to(args.device),
                    "image": batch['image'].to(args.device),
                    "label": batch['label'].to(args.device),
                }

                if args.model_type in ["layoutlm"]:
                    inputs["bbox"] = batch['bbox'].to(args.device)
                inputs["token_type_ids"] = (
                    batch['token_type_ids'].to(args.device)
                    if args.model_type in ["bert", "layoutlm"] else None
                )  # RoBERTa doesn't use segment_ids

                logit, loss = model(**inputs)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training

                file_names = batch['file_names']
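                # The softmax probabilities and their entropy (computed below) are
                # kept per example as a simple confidence signal for error analysis.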

                logit_soft = F.softmax(logit, dim=-1)
                entropy = (-torch.sum(logit_soft * torch.log(logit_soft),
                                      -1)).detach().cpu().numpy().tolist()
                logit_soft = logit_soft.detach().cpu().numpy().tolist()
                dev_loss += loss.item()

                logit = logit.argmax(-1).detach().cpu().numpy().tolist()
                label = inputs['label'].detach().cpu().numpy().tolist()

                logit_all.extend(logit)
                logit_soft_all.extend(logit_soft)
                label_all.extend(label)
                file_names_all.extend(file_names)
                entropy_all.extend(entropy)

            dev_steps += 1

        # sklearn metrics take (y_true, y_pred)
        acc = accuracy_score(label_all, logit_all)
        f1 = f1_score(label_all, logit_all,
                      average=None)  # 'micro'/'macro'/'weighted'
        macro_f1 = f1_score(label_all, logit_all, average='macro')
        micro_f1 = f1_score(label_all, logit_all, average='micro')

        cm = confusion_matrix(label_all, logit_all)

        labels = open(args.labels).read().split('\n')
        label_map = {i: label for i, label in enumerate(labels)}
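        # For every example, record (file name, predicted label, gold label, entropy);
        # mispredictions additionally go to error_files, the rest to correct_files.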

        for i, (logit, label) in enumerate(zip(logit_all, label_all)):

            files.append((file_names_all[i], label_map[logit],
                          label_map[label], entropy_all[i]))
            if logit != label:
                #                 print(logit_soft_all[i], max(logit_soft_all[i]))
                #                 print('====='*20)
                error_files.append(
                    (file_names_all[i], label_map[logit], label_map[label]))


                # Alternative filtering heuristic (kept for reference, disabled):
                # if (label == 10 and np.argmax(logit_soft_all[i]) != 10 and max(logit_soft_all[i]) > 0.5) or \
                #    (label != 10 and np.argmax(logit_soft_all[i]) == 10 and max(logit_soft_all[i]) <= 0.3):
                #     error_files.append((file_names_all[i], label_map[logit], label_map[label]))
            else:
                correct_files.append(
                    (file_names_all[i], label_map[logit], label_map[label]))
        #print("Test: acc {:.5f}, f1 {}, cm {}".format(acc, f1, cm))
        print("Test: acc {:.5f}, macro_f1 {}, micro_f1 {}".format(
            acc, macro_f1, micro_f1))
        print("Test: f1 {}".format(f1))
        print("Test: cm {}".format(cm))
        return dev_loss / dev_steps, error_files, correct_files, files
Example #3
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])
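    # Line images are scaled to the target height, black/white inverted,
    # and finally converted to tensors.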

    # cudnn benchmark autotuning is turned off here (it mainly helps when input sizes are fixed, which variable-width line images are not)
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
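    # Assumes a ReduceLROnPlateau variant whose step() returns True when it
    # lowers the learning rate and that exposes .finished / .factor / .min_lr
    # (the stock torch.optim.lr_scheduler.ReduceLROnPlateau returns None).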
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Do something with loss, running average, plot to some backend server, etc

            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = [
        imagetransforms.Scale(new_h=model.input_line_height)
    ]

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages
    if model.rtl:
        line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')
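        # units.txt / words.txt / TLG.fst look like Kaldi-style WFST decoding
        # resources: token units, the word symbol table, and the composed
        # token/lexicon/grammar decoding graph.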

    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    if have_lm:
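        # acoustic_weight presumably scales the model's (acoustic) scores
        # relative to the language-model scores during decoding.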
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    hyp_output = []
    hyp_lm_output = []

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # No need for backprop during validation test
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            sys.stdout.write(".")
            sys.stdout.flush()

            # Move inputs to the GPU (non_blocking replaces the old `async`
            # keyword argument, which is a reserved word in Python 3.7+)
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Do LM-free decoding
            hyp_transcriptions = model.decode_without_lm(
                model_output, model_output_actual_lengths, uxxxx=True)

            # Optionally, do LM decoding
            if have_lm:
                hyp_transcriptions_lm = model.decode_with_lm(
                    model_output, model_output_actual_lengths, uxxxx=True)

            for i in range(len(hyp_transcriptions)):
                hyp_output.append(
                    (metadata['utt-ids'][i], hyp_transcriptions[i]))

                if have_lm:
                    hyp_lm_output.append(
                        (metadata['utt-ids'][i], hyp_transcriptions_lm[i]))

    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")
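    # Hypotheses are written one per line as "<hyp text> (<utt-id>)", the
    # trn-style layout accepted by scoring tools such as sclite.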

    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)

    if have_lm:
        print("\t%s" % hyp_lm_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))

    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))
Example #5
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = []

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    line_img_transforms.append(
        imagetransforms.Scale(new_h=model.input_line_height))

    # Only do for grayscale
    if model.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # For right-to-left languages:
    # if model.rtl:
    #     line_img_transforms.append(imagetransforms.HorizontalFlip())

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    test_dataset = OcrDataset(args.datadir,
                              "test",
                              line_img_transforms,
                              max_allowed_width=1e5)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_data_threads,
        sampler=GroupedSampler(test_dataset, rand=False),
        collate_fn=SortByWidthCollater,
        pin_memory=True,
        drop_last=False)

    print("About to process test set. Total # iterations is %d." %
          len(test_dataloader))

    # Set up a separate process + queue to handle the CPU portion of decoding
    input_queue = multiprocessing.Queue()
    decoding_p = multiprocessing.Process(target=decode_thread,
                                         args=(input_queue, args.outdir,
                                               model.alphabet, args.lm_path))
    decoding_p.start()
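    # Producer/consumer split: this process runs GPU forward passes and pushes
    # (model_output, lengths, metadata) tuples onto the queue, while decode_thread
    # consumes them on the CPU; a trailing None marks end-of-stream.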

    # No need for backprop during validation test
    start_time = datetime.datetime.now()
    with torch.no_grad():
        for idx, (input_tensor, target, input_widths, target_widths,
                  metadata) in enumerate(test_dataloader):
            # Move inputs to the GPU (non_blocking replaces the old `async`
            # keyword argument, which is a reserved word in Python 3.7+)
            input_tensor = input_tensor.cuda(non_blocking=True)

            # Call model
            model_output, model_output_actual_lengths = model(
                input_tensor, input_widths)

            # Put model output on the queue for background process to decode
            input_queue.put(
                (model_output.cpu(), model_output_actual_lengths, metadata))

    # Now we just need to wait for decode thread to finish
    input_queue.put(None)
    input_queue.close()
    decoding_p.join()

    end_time = datetime.datetime.now()

    print("Decoding took %f seconds" % (end_time - start_time).total_seconds())
Example #6
def main():
    args = get_args()

    model = CnnOcrModel.FromSavedWeights(args.model_path)
    model.eval()

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=model.input_line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])


    have_lm = (args.lm_path is not None) and (args.lm_path != "")

    if have_lm:
        lm_units = os.path.join(args.lm_path, 'units.txt')
        lm_words = os.path.join(args.lm_path, 'words.txt')
        lm_wfst = os.path.join(args.lm_path, 'TLG.fst')


    test_dataset = OcrDataset(args.datadir, "test", line_img_transforms)

    # Set seed for consistency
    torch.manual_seed(7)
    torch.cuda.manual_seed_all(7)


    if have_lm:
        model.init_lm(lm_wfst, lm_words, lm_units, acoustic_weight=0.8)


    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_data_threads,
                                                  sampler=GroupedSampler(test_dataset, rand=False),
                                                  collate_fn=SortByWidthCollater,
                                                  pin_memory=True,
                                                  drop_last=False)


    hyp_output = []
    hyp_lm_output = []
    ref_output = []


    print("About to process test set. Total # iterations is %d." % len(test_dataloader))

    for idx, (input_tensor, target, input_widths, target_widths, metadata) in enumerate(test_dataloader):
        sys.stdout.write(".")
        sys.stdout.flush()

        # Move the input to the GPU. The old Variable/volatile wrapping is
        # deprecated (and `async` is a reserved keyword in Python 3.7+), so
        # plain tensors with non_blocking=True are used instead; gradients are
        # never needed in this decode-only loop.
        input_tensor = input_tensor.cuda(non_blocking=True)

        # Call model
        model_output, model_output_actual_lengths = model(input_tensor, input_widths)

        # Do LM-free decoding
        hyp_transcriptions = model.decode_without_lm(model_output, model_output_actual_lengths, uxxxx=True)

        # Optionally, do LM decoding
        if have_lm:
            hyp_transcriptions_lm = model.decode_with_lm(model_output, model_output_actual_lengths, uxxxx=True)



        cur_target_offset = 0
        target_np = target.data.numpy()
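        # target holds the whole batch as one flat 1-D tensor; per-example
        # target_widths are used to slice out each reference transcript.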

        for i in range(len(hyp_transcriptions)):
            ref_transcription = form_target_transcription(
                target_np[cur_target_offset:(cur_target_offset + target_widths.data[i])], model.alphabet)
            cur_target_offset += target_widths.data[i]

            hyp_output.append((metadata['utt-ids'][i], hyp_transcriptions[i]))

            if have_lm:
                hyp_lm_output.append((metadata['utt-ids'][i], hyp_transcriptions_lm[i]))

            ref_output.append((metadata['utt-ids'][i], ref_transcription))


    hyp_out_file = os.path.join(args.outdir, "hyp-chars.txt")
    ref_out_file = os.path.join(args.outdir, "ref-chars.txt")

    if have_lm:
        hyp_lm_out_file = os.path.join(args.outdir, "hyp-lm-chars.txt")

    print("")
    print("Done. Now writing output files:")
    print("\t%s" % hyp_out_file)

    if have_lm:
        print("\t%s" % hyp_lm_out_file)

    print("\t%s" % ref_out_file)

    with open(hyp_out_file, 'w') as fh:
        for uttid, hyp in hyp_output:
            fh.write("%s (%s)\n" % (hyp, uttid))


    if have_lm:
        with open(hyp_lm_out_file, 'w') as fh:
            for uttid, hyp in hyp_lm_output:
                fh.write("%s (%s)\n" % (hyp, uttid))

    with open(ref_out_file, 'w') as fh:
        for uttid, ref in ref_output:
            fh.write("%s (%s)\n" % (ref, uttid))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--img_model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written.",
    )

    ## Other parameters
    parser.add_argument(
        "--labels",
        default="./CORD/labels.txt",
        type=str,
        help=
        "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--do_train",
                        action="store_true",
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action="store_true",
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_predict",
        action="store_true",
        help="Whether to run predictions on the test set.",
    )
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="Whether to run evaluation during training at each logging step.",
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )

    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=16,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,  # 5e-5
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs",
        default=3.0,
        type=float,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument("--logging_steps",
                        type=int,
                        default=500,
                        help="Log every X updates steps.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=50,
        help="Save checkpoint every X updates steps.",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="Overwrite the content of the output directory",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="For distant debugging.")
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="For distant debugging.")
    args = parser.parse_args()

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train):
        if not args.overwrite_output_dir:
            raise ValueError(
                "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
                .format(args.output_dir))
        else:
            if args.local_rank in [-1, 0]:
                shutil.rmtree(args.output_dir)

    if not os.path.exists(args.output_dir) and (args.do_eval
                                                or args.do_predict):
        raise ValueError(
            "Output directory ({}) does not exist. Please train and save the model before inference stage."
            .format(args.output_dir))

    if (not os.path.exists(args.output_dir) and args.do_train
            and args.local_rank in [-1, 0]):
        os.makedirs(args.output_dir)

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        filename=os.path.join(args.output_dir, "train.log")
        if args.local_rank in [-1, 0] else None,
        format="%(asctime)s - %(levelname)s - %(name)s -  %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    labels = open(args.labels).read().split('\n')
    num_labels = len(labels)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    model = DocumentClassifier(config, args.img_model_type, args.device, 2816,
                               num_labels)  # 1280/1536/3072

    pretrained_weights = torch.load(os.path.join(args.model_name_or_path,
                                                 'pytorch_model.bin'),
                                    map_location=torch.device('cpu'))

    # The checkpoint holds bare LayoutLM weights; prefix each key with
    # "layoutlm." (presumably the classifier's submodule name) so they line up,
    # and rely on strict=False below so layers absent from the checkpoint
    # (e.g. the new classification head) keep their fresh initialization.
    pretrained_weights_new = OrderedDict()
    for key, value in pretrained_weights.items():
        pretrained_weights_new['layoutlm.' + key] = value

    model.load_state_dict(pretrained_weights_new, strict=False)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:

        train_dataset = OcrDataset(args.data_dir, tokenizer, labels,
                                   "train_v2")

        global_step, tr_loss = train(args, train_dataset, model, tokenizer,
                                     labels)

        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (model.module if hasattr(model, "module") else model
                         )  # Take care of distributed/parallel training

        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                "-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result, _ = evaluate(
                args,
                model,
                tokenizer,
                labels,
                mode="dev_v2",
                prefix=global_step,
            )
            if global_step:
                result = {
                    "{}_{}".format(global_step, k): v
                    for k, v in result.items()
                }
            results.update(result)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(results.keys()):
                writer.write("{} = {}\n".format(key, str(results[key])))

    if args.do_predict and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(
            args.model_name_or_path, do_lower_case=args.do_lower_case)

        model = DocumentClassifier(config, args.img_model_type, args.device,
                                   1536, num_labels)  # 1280

        pretrained_weights = (torch.load(
            os.path.join(args.model_name_or_path, 'pytorch_model.bin')))
        model.load_state_dict(pretrained_weights)

        model.to(args.device)
        evaluate(args, model, tokenizer, labels, mode="dev_v2")
def evaluate(args, model, tokenizer, labels, mode, prefix=""):
    eval_dataset = OcrDataset(args.data_dir, tokenizer, labels, "dev_v2")
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=args.eval_batch_size,
        collate_fn=None,
    )
    model.eval()

    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    if mode == "dev_v2":
        dev_loss = 0
        dev_steps = 0
        logit_all, label_all = [], []
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                inputs = {
                    "input_ids": batch['input_ids'].to(args.device),
                    "attention_mask": batch['attention_mask'].to(args.device),
                    "image": batch['image'].to(args.device),
                    "label": batch['label'].to(args.device)
                }
                if args.model_type in ["layoutlm"]:
                    inputs["bbox"] = batch['bbox'].to(args.device)
                inputs["token_type_ids"] = (
                    batch['token_type_ids'].to(args.device)
                    if args.model_type in ["bert", "layoutlm"] else None
                )  # RoBERTa doesn't use segment_ids

                logit, loss = model(**inputs)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training

                dev_loss += loss.item()

                logit = logit.argmax(-1).detach().cpu().numpy().tolist()
                label = inputs['label'].detach().cpu().numpy().tolist()

                logit_all.extend(logit)
                label_all.extend(label)

            dev_steps += 1

        # sklearn metrics take (y_true, y_pred)
        acc = accuracy_score(label_all, logit_all)
        f1 = f1_score(label_all, logit_all,
                      average=None)  # 'micro'/'macro'/'weighted'
        macro_f1 = f1_score(label_all, logit_all, average='macro')
        micro_f1 = f1_score(label_all, logit_all, average='micro')

        cm = confusion_matrix(label_all, logit_all)
        #print("Test: acc {:.5f}, f1 {}, cm {}".format(acc, f1, cm))
        print("Test: acc {:.5f}, macro_f1 {}, micro_f1 {}".format(
            acc, macro_f1, micro_f1))
        print("Test: f1 {}".format(f1))
        return dev_loss / dev_steps
def main():
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert to color so the augmentations work (for now),
    # then later convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:

        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2,
                                                                        2)))

        # Choose one of methods to blur/pixel-ify image  (or don't and choose identity)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale:  convert to grayscale and add back into color-image with random alpha
        #   Emboss:  Emboss image with random strength
        #   Invert:  Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

        # Randomly try to crop close to top/bottom and left/right of lines
        # For now we are just guessing (up to 5% of ends and up to 10% of tops/bottoms chopped off)

        if args.tight_crop:
            # To keep padding reasonably consistent, we first resize the image to the target line height,
            # then add padding to this version of the image.
            # Below it will get resized again to the target line height.
            line_img_transforms.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))

        else:
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

        #line_img_transforms.append(imagetransforms.Randomize(0.2,
        #                                                     imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
        #                                                     ))

    # Make sure to do resize after degrade step above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only do for grayscale
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # cudnn benchmark autotuning is turned off here (it mainly helps when input sizes are fixed, which variable-width line images are not)
    torch.backends.cudnn.benchmark = False

    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)

        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)

        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        if len(args.datadir) == 1:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
        else:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet

    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set up the dataloaders only after we have had a chance to (maybe!) override the dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset, rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    do_test_write = False
    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Only start decoding the test set once the validation CER suggests the model is no longer near-random
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                if val_cer < 0.5:
                    do_test_write = True

                if args.test_datadir is not None and (
                        iteration % snapshot_every_n_iterations
                        == 0) and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir,
                        "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")