def setup(self, stage: str = None) -> None: print(f"IAMSyntheticParagraphs.setup({stage}): Loading trainval IAM paragraph regions and lines...") if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels("trainval", PROCESSED_DATA_DIRNAME) X, para_labels = generate_synthetic_paragraphs(line_crops=line_crops, line_labels=line_labels) Y = convert_strings_to_labels(strings=para_labels, mapping=self.inverse_mapping, length=self.output_dims[0]) transform = get_transform(image_shape=self.dims[1:], augment=self.augment) # type: ignore self.data_train = BaseDataset(X, Y, transform=transform)
def __init__(self):
    """Prepare the label mapping, image transform, and trained model.

    Reads the character mapping from ``IAMParagraphs``, records the indices
    of the special tokens to ignore, builds a non-augmenting transform, and
    loads the model described by ``config.json`` with weights from
    ``model.pt`` (both under ``CONFIG_AND_WEIGHTS_DIRNAME``). The loaded
    Lightning module is put into eval mode.
    """
    data = IAMParagraphs()
    self.mapping = data.mapping

    # Indices of the special tokens (presumably start / blank / end /
    # padding -- confirm against IAMParagraphs) that are to be ignored.
    inv_mapping = data.inverse_mapping
    self.ignore_tokens = [inv_mapping[token] for token in ("<S>", "<B>", "<E>", "<P>")]

    self.transform = get_transform(image_shape=data.dims[1:], augment=False)

    # JSON is UTF-8 by specification; decode explicitly so loading does not
    # depend on the platform's default locale encoding.
    with open(CONFIG_AND_WEIGHTS_DIRNAME / "config.json", "r", encoding="utf-8") as file:
        config = json.load(file)
    args = argparse.Namespace(**config)

    model = ResnetTransformer(data_config=data.config(), args=args)
    self.lit_model = TransformerLitModel.load_from_checkpoint(
        checkpoint_path=CONFIG_AND_WEIGHTS_DIRNAME / "model.pt",
        args=args,
        model=model,
    )
    self.lit_model.eval()