示例#1
0
    def setup(self, stage: str = None) -> None:
        print(f"IAMSyntheticParagraphs.setup({stage}): Loading trainval IAM paragraph regions and lines...")

        if stage == "fit" or stage is None:
            line_crops, line_labels = load_line_crops_and_labels("trainval", PROCESSED_DATA_DIRNAME)
            X, para_labels = generate_synthetic_paragraphs(line_crops=line_crops, line_labels=line_labels)
            Y = convert_strings_to_labels(strings=para_labels, mapping=self.inverse_mapping, length=self.output_dims[0])
            transform = get_transform(image_shape=self.dims[1:], augment=self.augment)  # type: ignore
            self.data_train = BaseDataset(X, Y, transform=transform)
示例#2
0
    def __init__(self):
        """Load the label mapping, preprocessing transform and model checkpoint.

        Reads the character mapping, the special-token indices to ignore, and a
        non-augmenting transform from an ``IAMParagraphs`` instance. It then builds a
        ``ResnetTransformer`` from the hyper-parameters in
        ``CONFIG_AND_WEIGHTS_DIRNAME / "config.json"`` and loads
        ``CONFIG_AND_WEIGHTS_DIRNAME / "model.pt"`` into it via
        ``TransformerLitModel.load_from_checkpoint``. Finally it calls ``.eval()`` on
        the loaded model.
        """
        data = IAMParagraphs()
        self.mapping = data.mapping
        inv_mapping = data.inverse_mapping
        # Indices of the <S>, <B>, <E>, <P> tokens; presumably stripped from decoded
        # predictions elsewhere -- confirm against the code that reads ignore_tokens.
        self.ignore_tokens = [
            inv_mapping["<S>"], inv_mapping["<B>"], inv_mapping["<E>"],
            inv_mapping["<P>"]
        ]
        # Same image shape as the data module, but never augmented.
        self.transform = get_transform(image_shape=data.dims[1:],
                                       augment=False)

        # JSON text is UTF-8 by spec; don't let the platform's locale encoding decide.
        with open(CONFIG_AND_WEIGHTS_DIRNAME / "config.json", "r", encoding="utf-8") as file:
            config = json.load(file)
        # Both the model and the LitModel below take the config as an argparse.Namespace.
        args = argparse.Namespace(**config)

        model = ResnetTransformer(data_config=data.config(), args=args)
        self.lit_model = TransformerLitModel.load_from_checkpoint(
            checkpoint_path=CONFIG_AND_WEIGHTS_DIRNAME / "model.pt",
            args=args,
            model=model)
        # Put the loaded model into evaluation mode.
        self.lit_model.eval()