示例#1
0
    def __call__(self, **kwargs):
        """Evaluate the model for keyword spotting and log gAP / mAP.

        Embeds every query image and every word (document) image with the
        model (sigmoid over its outputs), computes pairwise Bray-Curtis
        distances, and accumulates them in average-precision meters: one
        global meter (gAP) and one per query (mAP).

        Args:
            **kwargs: must contain ``epoch`` (int), used only for logging.
        """
        self.model.eval()
        input_fn = ImageFeeder(device=self.gpu,
                               keep_padded_tensors=False,
                               parent_feeder=ItemFeeder('img'))

        def _embed_all(loader, desc):
            # Embed every batch in `loader`; returns (N x D array, labels).
            embeddings, labels = [], []
            for batch in tqdm(loader, desc=desc):
                try:
                    # torch.sigmoid replaces the deprecated
                    # torch.nn.functional.sigmoid; no_grad avoids building
                    # an autograd graph during evaluation.
                    with torch.no_grad():
                        output = torch.sigmoid(self.model(input_fn(batch)))
                    # .detach() replaces the deprecated .data access.
                    embeddings.append(output.detach().cpu().numpy())
                    labels.append(batch['txt'][0])
                except Exception:
                    laia.common.logging.error('Exception processing: {!r}',
                                              batch['id'])
                    raise  # bare raise keeps the original traceback
            return np.vstack(embeddings), labels

        queries_embed, queries_gt = _embed_all(self.queries_loader, 'Queries')
        words_embed, words_gt = _embed_all(self.words_loader, 'Docs')

        num_queries = queries_embed.shape[0]
        num_words = words_embed.shape[0]
        gap_meter = AveragePrecisionMeter(desc_sort=False)
        map_meter = [AveragePrecisionMeter(desc_sort=False)
                     for _ in range(num_queries)]
        # Lower distance means a better match, hence desc_sort=False above.
        distances = cdist(queries_embed, words_embed, metric='braycurtis')
        for i in range(num_queries):
            for j in range(num_words):
                # Arguments are presumably (tp, fp, fn, score); a pair is
                # relevant iff the transcripts match — TODO confirm against
                # AveragePrecisionMeter.add.
                if queries_gt[i] == words_gt[j]:
                    gap_meter.add(1, 0, 0, distances[i][j])
                    map_meter[i].add(1, 0, 0, distances[i][j])
                else:
                    gap_meter.add(0, 1, 0, distances[i][j])
                    map_meter[i].add(0, 1, 0, distances[i][j])

        g_ap = gap_meter.value
        # Meters that saw no relevant pair yield None; exclude them from mAP.
        aps = [m.value for m in map_meter if m.value is not None]
        laia.common.logging.info(
            'Epoch {epochs:4d}, '
            'VA gAP = {gap:5.1%}, '
            'VA mAP = {map:5.1%}',  # fixed: stray trailing ', ' removed
            epochs=kwargs['epoch'],
            gap=g_ap,
            map=np.mean(aps) if aps else None)
示例#2
0
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_l2_penalty,
        )
    # Trainer configuration (yapf-style dict: key and value on separate
    # lines). Feeders pull the image ('img'), transcript ('txt') and sample
    # id ('id') out of each batch dict.
    parameters = {
        "model":
        model,
        "criterion":
        None,  # Set automatically by HtrEngineWrapper
        "optimizer":
        optimizer,
        "data_loader":
        tr_ds_loader,
        "batch_input_fn":
        ImageFeeder(device=args.gpu,
                    keep_padded_tensors=args.keep_padded_tensors,
                    parent_feeder=ItemFeeder("img")),
        "batch_target_fn":
        ItemFeeder("txt"),
        "batch_id_fn":
        ItemFeeder("id"),  # Print image ids on exception
        "progress_bar":
        "Train" if args.show_progress_bar else False,
    }
    trainer = Trainer(**parameters)
    # Name suggests gradient accumulation across iterations — see Trainer.
    trainer.iterations_per_update = args.iterations_per_update

    evaluator = laia.engine.Evaluator(
        model=model,
        data_loader=va_ds_loader,
        batch_input_fn=ImageFeeder(
示例#3
0
        args.train_samples_per_epoch)
    # Validation data (queries and word candidates)
    qr_ds, qr_ds_loader = create_dataset_and_loader(args.img_dir,
                                                    args.qry_txt_table,
                                                    laia.utils.ImageToTensor())
    wd_ds, wd_ds_loader = create_dataset_and_loader(args.img_dir,
                                                    args.doc_txt_table,
                                                    laia.utils.ImageToTensor())

    # Trainer: BCE loss against PHOC targets built on the fly from each
    # sample's transcript ('txt') by PHOCFeeder.
    trainer = laia.engine.Trainer(
        model=model,
        criterion=DortmundBCELoss(),
        optimizer=optimizer,
        data_loader=tr_ds_loader,
        batch_input_fn=ImageFeeder(device=args.gpu,
                                   keep_padded_tensors=False,
                                   parent_feeder=ItemFeeder('img')),
        batch_target_fn=VariableFeeder(device=args.gpu,
                                       parent_feeder=PHOCFeeder(
                                           syms=syms,
                                           levels=args.phoc_levels,
                                           parent_feeder=ItemFeeder('txt'))),
        progress_bar='Train' if args.show_progress_bar else False)
    trainer.iterations_per_update = args.iterations_per_update

    # Run the query-vs-words retrieval evaluation at the end of each epoch.
    trainer.add_hook(EPOCH_END,
                     Evaluate(model, qr_ds_loader, wd_ds_loader, args.gpu))

    if args.max_epochs and args.max_epochs > 0:
        trainer.add_hook(
            EPOCH_START,
    # Dataset pairing each line image with its transcript, both converted
    # to tensors by the given transforms.
    dataset = TextImageFromTextTableDataset(
        args.txt_table,
        args.img_dirs,
        img_transform=ImageToTensor(),
        txt_transform=transforms.text.ToTensor(syms),
    )

    dataset_loader = ImageDataLoader(
        dataset=dataset,
        image_channels=1,
        batch_size=args.batch_size,
        num_workers=multiprocessing.cpu_count(),
    )

    # Feeders extract pieces of each batch dict: image, transcript, id.
    batch_input_fn = ImageFeeder(device=device,
                                 parent_feeder=ItemFeeder("img"))
    batch_target_fn = ItemFeeder("txt")
    batch_id_fn = ItemFeeder("id")

    # NOTE(review): score_fn is only assigned when score_function == 'cer';
    # any other value leaves it undefined further down — confirm the
    # allowed choices for args.score_function.
    if args.score_function == 'cer':
        score_fn = cer_score_dict

    decoder = CTCGreedyDecoder()

    # Accumulators for decoded hypotheses, reference transcripts and ids.
    decoded = []
    target = []
    ids = []

    counter = 0

    # Go through all the samples, compute the prediction, get the label
    def __init__(
        self,
        symbols_table,
        phoc_levels,
        train_engine,  # type: Trainer
        valid_engine=None,  # type: Optional[Evaluator]
        check_valid_hook_when=EPOCH_END,  # type: Optional[str]
        valid_hook_condition=None,  # type: Optional[Callable]
        gpu=0,
        exclude_labels=None,
        ignore_missing=False,
        use_new_phoc=False,
        summary_order=(
            "Epoch",
            "TR Loss",
            "VA Loss",
            "VA gAP",
            "VA mAp",
            "TR Time",
            "VA Time",
            "Memory",
        ),  # type: Sequence[str]
    ):
        # type: (...) -> None
        """Experiment wrapper for PHOC-based keyword spotting.

        Configures the training engine with defaults when not already set
        (``BCEWithLogitsLoss`` criterion, image input feeder, PHOC target
        feeder) and hooks loss accumulation into both engines. If a
        validation engine is given, a pairwise average-precision meter
        (Bray-Curtis metric) is attached as well.

        Args:
            symbols_table: symbol table used by ``PHOCFeeder`` to encode
                the transcripts ('txt') into PHOC targets.
            phoc_levels: pyramid levels of the PHOC representation.
            train_engine: training engine to configure.
            valid_engine: optional validation engine.
            check_valid_hook_when: event on which validation runs
                (forwarded to the parent experiment).
            valid_hook_condition: optional predicate gating validation
                (forwarded to the parent experiment).
            gpu: device passed to the default input/target feeders.
            exclude_labels: labels excluded from the validation AP meter.
            ignore_missing: forwarded to ``PHOCFeeder``.
            use_new_phoc: forwarded to ``PHOCFeeder`` as ``new_phoc``.
            summary_order: column order of the experiment summary.
        """
        super(PHOCExperiment, self).__init__(
            train_engine,
            valid_engine=valid_engine,
            check_valid_hook_when=check_valid_hook_when,
            valid_hook_condition=valid_hook_condition,
            summary_order=summary_order,
        )

        # If the trainer was created without any criterion, set it properly.
        if not self._tr_engine.criterion:
            self._tr_engine.criterion = torch.nn.BCEWithLogitsLoss()

        # Set trainer's batch_input_fn and batch_target_fn if not already set.
        if not self._tr_engine.batch_input_fn:
            self._tr_engine.set_batch_input_fn(
                ImageFeeder(
                    device=gpu,
                    keep_padded_tensors=False,
                    parent_feeder=ItemFeeder("img"),
                )
            )
        if not self._tr_engine.batch_target_fn:
            self._tr_engine.set_batch_target_fn(
                TensorFeeder(
                    device=gpu,
                    parent_feeder=PHOCFeeder(
                        syms=symbols_table,
                        levels=phoc_levels,
                        ignore_missing=ignore_missing,
                        new_phoc=use_new_phoc,
                        parent_feeder=ItemFeeder("txt"),
                    ),
                )
            )

        # Accumulate the training loss after every iteration.
        self._tr_engine.add_hook(ITER_END, self._train_accumulate_loss)

        if valid_engine:
            # Set batch_input_fn and batch_target_fn if not already set.
            # Defaults are shared with the training engine's feeders.
            if not self._va_engine.batch_input_fn:
                self._va_engine.set_batch_input_fn(self._tr_engine.batch_input_fn)
            if not self._va_engine.batch_target_fn:
                self._va_engine.set_batch_target_fn(self._tr_engine.batch_target_fn)

            self._va_ap = PairwiseAveragePrecisionMeter(
                metric="braycurtis",
                ignore_singleton=True,
                exclude_labels=exclude_labels,
            )

            # Accumulate the validation loss after every iteration.
            self._va_engine.add_hook(ITER_END, self._valid_accumulate_loss)
        else:
            # No validation engine: no AP meter to report.
            self._va_ap = None
示例#6
0
        model.load_state_dict(model_ckpt)

    # args.gpu is 1-based here: 0 selects CPU, N > 0 selects CUDA device N-1.
    model = model.cuda(args.gpu - 1) if args.gpu > 0 else model.cpu()
    logger.info('Model has {} parameters',
                sum(param.data.numel() for param in model.parameters()))

    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_l2_penalty)
    # Trainer wiring: feeders pull the image ('img'), transcript ('txt')
    # and sample id ('id') out of each batch dict.
    parameters = {
        'model': model,
        'criterion': None,  # Set automatically by HtrEngineWrapper
        'optimizer': optimizer,
        'data_loader': tr_ds_loader,
        'batch_input_fn': ImageFeeder(device=args.gpu,
                                      parent_feeder=ItemFeeder('img')),
        'batch_target_fn': ItemFeeder('txt'),
        'batch_id_fn': ItemFeeder('id'),  # Print image ids on exception
        'progress_bar': 'Train' if args.show_progress_bar else False,
    }
    trainer = Trainer(**parameters)
    trainer.iterations_per_update = args.iterations_per_update

    # Evaluator mirrors the trainer's feeders, on the validation loader.
    evaluator = laia.engine.Evaluator(
        model=model,
        data_loader=va_ds_loader,
        batch_input_fn=ImageFeeder(device=args.gpu,
                                   parent_feeder=ItemFeeder('img')),
        batch_target_fn=ItemFeeder('txt'),
        batch_id_fn=ItemFeeder('id'),  # Print image ids on exception
        progress_bar='Valid' if args.show_progress_bar else False)
示例#7
0
            args.learning_rate,
            args.momentum,
            args.weight_l2_penalty,
        )
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_l2_penalty,
        )
    # Trainer configuration; feeders extract the image ('img'), transcript
    # ('txt') and sample id ('id') from each batch dict.
    parameters = {
        "model": model,
        "criterion": None,  # Set automatically by HtrEngineWrapper
        "optimizer": optimizer,
        "data_loader": tr_ds_loader,
        "batch_input_fn": ImageFeeder(device=args.gpu, parent_feeder=ItemFeeder("img")),
        "batch_target_fn": ItemFeeder("txt"),
        "batch_id_fn": ItemFeeder("id"),  # Print image ids on exception
        "progress_bar": "Train" if args.show_progress_bar else False,
    }
    trainer = Trainer(**parameters)
    trainer.iterations_per_update = args.iterations_per_update

    # Validation engine reuses the same feeder setup on the validation loader.
    evaluator = laia.engine.Evaluator(
        model=model,
        data_loader=va_ds_loader,
        batch_input_fn=ImageFeeder(device=args.gpu, parent_feeder=ItemFeeder("img")),
        batch_target_fn=ItemFeeder("txt"),
        batch_id_fn=ItemFeeder("id"),  # Print image ids on exception
        progress_bar="Valid" if args.show_progress_bar else False,
    )