def test_segmentation_only_symbols(self):
     """Segmentation of label sequences containing no blank (0) symbols."""
     # A run of identical labels collapses into a single segment.
     result = CTCGreedyDecoder.compute_segmentation([1, 1, 1])
     self.assertEqual([0, 3], result)
     # Distinct consecutive labels each open their own segment.
     result = CTCGreedyDecoder.compute_segmentation([1, 2, 3])
     self.assertEqual([0, 1, 2, 3], result)
# ---- Example 2 ----
 def __init__(
     self,
     decoder: Optional[Callable] = CTCGreedyDecoder(),
     syms: Optional[Union[dict, SymbolsTable]] = None,
     use_symbols: bool = False,
     input_space: str = "<space>",
     output_space: str = " ",
     convert_spaces: bool = False,
     join_string: Optional[str] = None,
     separator: str = " ",
     include_img_ids: bool = True,
 ):
     """Configure decoding and text-formatting options.

     Args:
         decoder: callable mapping model output to hypotheses.
             NOTE(review): the default ``CTCGreedyDecoder()`` is evaluated
             once at class-definition time, so every instance created with
             the default shares the SAME decoder object — confirm the
             decoder is stateless before relying on this.
         syms: mapping from label indices to symbols; required when
             ``use_symbols`` is True (enforced by the assert below).
         use_symbols: if True, map indices through ``syms``.
         input_space: symbol treated as a space in the label alphabet.
         output_space: string emitted in place of ``input_space``.
         convert_spaces: if True, replace ``input_space`` with
             ``output_space``; only valid together with ``use_symbols``.
         join_string: if not None, string used to join symbols.
         separator: separator between output tokens.
         include_img_ids: if True, prepend image ids to the output.
     """
     super().__init__()
     self.decoder = decoder
     self.syms = syms
     self.use_symbols = use_symbols
     # Symbol mapping is impossible without a symbols table.
     if use_symbols:
         assert syms is not None
     self.input_space = input_space
     self.output_space = output_space
     self.convert_spaces = convert_spaces
     # Space conversion operates on symbols, so it requires use_symbols.
     if convert_spaces:
         assert use_symbols
     self.join_string = join_string
     self.separator = separator
     self.include_img_ids = include_img_ids
 def test_prob(self):
     """Validate decoder probabilities against torch's CTC loss."""
     logits = torch.tensor([[[0.3, 0.6, 0.1]], [[0.6, 0.3, 0.2]]]).log()
     out = CTCGreedyDecoder()(logits, segmentation=True)
     expected_hyp = [[1]]
     self.assertEqual(expected_hyp, out["hyp"])
     # Total probability of hypothesis [1] is the sum over all alignments
     # that collapse to it: (blank,1), (1,blank), (1,1).
     alignments = ((0, 1), (1, 0), (1, 1))
     paths = torch.tensor(
         [logits[0, 0, a] + logits[1, 0, b] for a, b in alignments])
     full_loss = torch.nn.functional.ctc_loss(logits,
                                              torch.tensor(expected_hyp),
                                              torch.tensor([2]),
                                              torch.tensor([1]),
                                              reduction="none")
     torch.testing.assert_allclose(full_loss.neg().exp(),
                                   paths.exp().sum())
     # The 1-best probability should match the loss computed with
     # input_length = 1 (a single frame admits a single alignment).
     short_loss = torch.nn.functional.ctc_loss(logits,
                                               torch.tensor(expected_hyp),
                                               torch.tensor([1]),
                                               torch.tensor([1]),
                                               reduction="none")
     torch.testing.assert_allclose(short_loss.neg().exp(),
                                   [p.mean() for p in out["prob"]][0])
 def test_segmentation(self):
     """Segment boundaries for sequences mixing blanks (0) and labels."""
     cases = [
         ([1, 1, 0, 0, 0, 2, 0, 0, 3], [0, 2, 6, 9]),
         ([1, 2, 0, 0, 3, 2, 0, 0, 3], [0, 1, 2, 5, 6, 9]),
         ([0, 0, 1, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0], [0, 4, 8, 11, 13]),
         ([0, 0, 0, 2, 2], [0, 5]),
     ]
     for labels, expected in cases:
         self.assertEqual(expected,
                          CTCGreedyDecoder.compute_segmentation(labels))
 def test(self):
     """Greedy decoding returns the per-frame argmax with repeats collapsed."""
     logits = torch.tensor([
         [[1.0, 3.0, -1.0, 0.0]],
         [[-1.0, 2.0, -2.0, 3.0]],
         [[1.0, 5.0, 9.0, 2.0]],
         [[-1.0, -2.0, -3.0, -4.0]],
     ])
     out = CTCGreedyDecoder()(logits)
     self.assertEqual([[1, 3, 2]], out["hyp"])
# ---- Example 6 ----
    def __init__(
            self,
            train_engine,  # type: Trainer
            valid_engine=None,  # type: Optional[Evaluator]
            check_valid_hook_when=EPOCH_END,  # type: Optional[str]
            valid_hook_condition=None,  # type: Optional[Callable]
            word_delimiters=None,  # type: Optional[Sequence]
            summary_order=(
                "Epoch",
                "TR Loss",
                "VA Loss",
                "TR CER",
                "VA CER",
                "TR WER",
                "VA WER",
                "TR Time",
                "VA Time",
                "Memory",
            ),  # type: Sequence[str]
    ):
        # type: (...) -> None
        """Set up an HTR experiment: CTC criterion, decoder and CER/WER meters.

        Args:
            train_engine: engine driving the training loop.
            valid_engine: optional engine driving validation.
            check_valid_hook_when: event on which the validation hook runs.
            valid_hook_condition: optional predicate gating validation.
            word_delimiters: symbols used to split hypotheses into words.
            summary_order: column order for the experiment summary.
        """
        super(HTRExperiment, self).__init__(
            train_engine,
            valid_engine=valid_engine,
            check_valid_hook_when=check_valid_hook_when,
            valid_hook_condition=valid_hook_condition,
            summary_order=summary_order,
        )
        self._word_delimiters = word_delimiters

        # If the trainer was created without any criterion,
        # or it is not the CTCLoss, set it properly.
        if not self._tr_engine.criterion:
            self._tr_engine.criterion = CTCLoss()
        elif not isinstance(self._tr_engine.criterion, CTCLoss):
            # Logger.warn is a deprecated alias of Logger.warning.
            self.logger.warning(
                "Overriding the criterion of the trainer to CTC.")
            self._tr_engine.criterion = CTCLoss()

        self._ctc_decoder = CTCGreedyDecoder()
        self._tr_cer = SequenceErrorMeter()
        self._tr_wer = SequenceErrorMeter()

        # Update training meters after every iteration.
        self._tr_engine.add_hook(ITER_END, self._train_update_meters)

        if self._va_engine:
            self._va_cer = SequenceErrorMeter()
            self._va_wer = SequenceErrorMeter()

            self._va_engine.add_hook(ITER_END, self._valid_update_meters)
        else:
            # No validation engine: keep meters as None so callers can test.
            self._va_cer = None
            self._va_wer = None
 def test_batch(self):
     """Greedy decoding and per-sample probabilities on a batch of two."""
     logits = torch.tensor([[[0.3, 0.6], [0.5, 0.9]],
                            [[0.6, 0.3], [0.6, 0.9]]]).log()
     out = CTCGreedyDecoder()(logits, segmentation=True)
     hyp = [[1], [1]]
     self.assertEqual(hyp, out["hyp"])
     # note: checking with ctc_loss does not work for every x
     loss = torch.nn.functional.ctc_loss(
         logits,
         torch.tensor(hyp),
         torch.tensor([1, 1]),
         torch.tensor([1, 1]),
         reduction="none",
     )
     probs = torch.tensor([p.mean() for p in out["prob"]])
     torch.testing.assert_allclose(probs, loss.neg().exp())
# ---- Example 8 ----
    # Place the model on the requested device; args.gpu is 1-based
    # (0 means CPU, N means CUDA device N-1).
    if args.gpu > 0:
        model = model.cuda(args.gpu - 1)
    else:
        model = model.cpu()

    # Dataset pairing images with their ground-truth transcripts.
    dataset = TextImageFromTextTableDataset(
        args.gt_file,
        args.img_dir,
        img_transform=ImageToTensor(),
        txt_transform=TextToTensor(syms),
    )
    dataset_loader = ImageDataLoader(dataset=dataset,
                                     image_channels=1,
                                     num_workers=8)

    decoder = CTCGreedyDecoder()
    # NOTE(review): when args.gpu == 0 this is torch.cuda.device(-1),
    # which torch treats as a no-op context — confirm intended.
    with torch.cuda.device(args.gpu - 1):
        for batch in dataset_loader:
            if args.gpu > 0:
                x = batch["img"].data.cuda(args.gpu - 1)
            else:
                x = batch["img"].data.cpu()
            # NOTE(review): torch.autograd.Variable is deprecated since
            # PyTorch 0.4 — tensors can be passed to the model directly.
            y = model(torch.autograd.Variable(x))
            y = decoder(y)
            # y[0]: hypothesis for the first (only) sample in the batch.
            # Emit either the mapped symbols or the raw label indices.
            if args.output_symbols:
                y = list(map(lambda i: syms[i], y[0]))
            else:
                y = list(map(lambda i: str(i), y[0]))
            # One line per image: "<img_id> <tok1> <tok2> ...".
            print("{} {}".format(batch["id"][0], " ".join(y)),
                  file=args.output)
    def __init__(
        self,
        model,  # type: torch.nn.Module
        criterion,  # type: Optional[Callable]
        optimizer,  # type: torch.optim.Optimizer
        data_loader=None,  # type: Optional[Iterable]
        batch_input_fn=None,  # type: Optional[Callable]
        batch_target_fn=None,  # type: Optional[Callable]
        batch_id_fn=None,  # type: Optional[Callable]
        progress_bar=None,  # type: Optional[Union[bool, str]]
        iterations_per_update=1,  # type: int
        cv_number=None,
        use_baseline=None,
        use_cl=None,
        use_transfer=None,
        use_semi_supervised=None,
        threshold_score_semi_supervised=None,
        data_semi_supervised_loader=None,
        epoch_frequency_semi_supervision=None,
        syms=None,
        original_data_loader=None,
    ):
        # type: (...) -> None
        """Initialize the trainer, optionally with semi-supervision support.

        Args:
            model: network to train.
            criterion: loss callable (may be None and set later).
            optimizer: optimizer updating the model parameters.
            data_loader: labeled training data loader; its dataset is
                snapshotted into ``self.original_dataset`` when given.
            batch_input_fn / batch_target_fn / batch_id_fn: batch accessors.
            progress_bar: progress-bar flag or description string.
            iterations_per_update: gradient-accumulation steps per update.
            cv_number: cross-validation fold identifier.
            use_baseline / use_cl / use_transfer: training-mode flags
                forwarded to the base class.
            use_semi_supervised: enable the semi-supervised loop.
            threshold_score_semi_supervised: confidence threshold for
                accepting pseudo-labels.
            data_semi_supervised_loader: loader over unlabeled data.
            epoch_frequency_semi_supervision: run pseudo-labeling every
                N epochs.
            syms: symbols table used when decoding pseudo-labels.
            original_data_loader: loader kept to restore the original data.
        """
        super(Trainer, self).__init__(model=model,
                                      data_loader=data_loader,
                                      batch_input_fn=batch_input_fn,
                                      batch_target_fn=batch_target_fn,
                                      batch_id_fn=batch_id_fn,
                                      progress_bar=progress_bar,
                                      use_baseline=use_baseline,
                                      use_cl=use_cl,
                                      use_transfer=use_transfer)
        self._criterion = criterion
        self._optimizer = optimizer
        self._iterations_per_update = iterations_per_update
        self._updates = 0
        self._cv_number = cv_number
        self._progress_bar = progress_bar

        self.data_loader = data_loader

        self.use_semi_supervised = use_semi_supervised
        self.threshold_score_semi_supervised = threshold_score_semi_supervised
        self.data_semi_supervised_loader = data_semi_supervised_loader
        self.epoch_frequency_semi_supervision = epoch_frequency_semi_supervision
        self.counter_epoch_semi_supervision = 0
        self.semi_supervision_started = False
        # Snapshot the original dataset so it can be restored after
        # pseudo-labels are mixed in. data_loader defaults to None, so
        # guard the dereference instead of crashing with AttributeError.
        if data_loader is not None:
            self.original_dataset = {
                'ids': data_loader.dataset._ids,
                'imgs': data_loader.dataset._imgs,
                'txts': data_loader.dataset._txts
            }
        else:
            self.original_dataset = None
        self.decoder = CTCGreedyDecoder()
        self.syms = syms
        self.original_data_loader = original_data_loader

        # Load Spell Checker used to clean up decoded pseudo-labels.
        self.sym_spell = SymSpell(max_dictionary_edit_distance=5,
                                  prefix_length=7)
        dict_name = 'de_50k.txt'  #"frequency_dictionary_en_82_765.txt"
        # Best-effort: a missing dictionary is reported but not fatal.
        if not self.sym_spell.load_dictionary(
                dict_name, term_index=0, count_index=1, encoding='utf-8-sig'):
            print("error loading spell checker")
 def test_segmentation_only_zeros(self):
     """A sequence of only blanks yields one segment covering the input."""
     blanks = [0, 0, 0]
     result = CTCGreedyDecoder.compute_segmentation(blanks)
     self.assertEqual([0, 3], result)
 def test_segmentation_empty(self):
     """An empty label sequence produces an empty segmentation."""
     self.assertEqual([], CTCGreedyDecoder.compute_segmentation([]))