# NOTE(review): tail of a character-mapping builder whose `def` (and the
# `characters` / `iam_characters` lists it closes over) lies above this chunk.
    # Punctuation characters appended to the letter/digit alphabet.
    " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?",
]
# Also add special tokens:
# - CTC blank token at index 0
# - Start token at index 1
# - End token at index 2
# - Padding token at index 3
# NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters]


if __name__ == "__main__":
    load_and_print_info(EMNIST)
para_width = crop_shapes[:, 1].max() para_image = Image.new(mode="L", size=(para_width, para_height), color=0) current_height = 0 for line_crop in line_crops: para_image.paste(line_crop, box=(0, current_height)) current_height += line_crop.height return para_image def generate_random_batches(values: List[Any], min_batch_size: int, max_batch_size: int) -> List[List[Any]]: """ Generate random batches of elements in values without replacement and return the list of all batches. Batch sizes can be anything between min_batch_size and max_batch_size including the end points. """ shuffled_values = values.copy() random.shuffle(shuffled_values) start_id = 0 grouped_values_list = [] while start_id < len(shuffled_values): num_values = random.randint(min_batch_size, max_batch_size) grouped_values_list.append(shuffled_values[start_id : start_id + num_values]) start_id += num_values assert sum([len(_) for _ in grouped_values_list]) == len(values) return grouped_values_list if __name__ == "__main__": load_and_print_info(IAMSyntheticParagraphs)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Sample N sentences and render each into a fixed-size image tensor.

    Returns a tuple (images, labels): a (N, dims[1], dims[2]) float tensor and
    the list of the N generated label strings.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    while len(labels) < N:
        sentence = sentence_generator.generate()
        images[len(labels)] = construct_image_from_string(
            sentence, samples_by_char, min_overlap, max_overlap, dims[-1]
        )
        labels.append(sentence)
    return images, labels


def convert_strings_to_labels(
    strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
    """Convert N strings to a (N, length) uint8 ndarray of token indices.

    Each string is optionally wrapped with the <S> and <E> tokens and the row
    is padded out to ``length`` with the <P> token.
    """
    # Start from an all-<P> matrix, then overwrite the leading positions per row.
    labels = np.full((len(strings), length), mapping["<P>"], dtype=np.uint8)
    for row, string in enumerate(strings):
        tokens = ["<S>", *string, "<E>"] if with_start_end_tokens else list(string)
        for col, token in enumerate(tokens):
            labels[row, col] = mapping[token]
    return labels


if __name__ == "__main__":
    load_and_print_info(EMNISTLines)
crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x, y = 28, 0 # if augment: # x = random.randint(0, (image_width - new_crop_width)) # y = random.randint(0, (IMAGE_HEIGHT - new_crop_height)) image.paste(crop_resized, (x, y)) return image transforms_list = [transforms.Lambda(embed_crop)] if augment: transforms_list += [ transforms.ColorJitter(brightness=(0.8, 1.6)), transforms.RandomAffine( degrees=1, shear=(-30, 20), resample=Image.BILINEAR, ), ] transforms_list += [ transforms.ToTensor(), # transforms.Lambda(lambda x: x - 0.5) ] return transforms.Compose(transforms_list) if __name__ == "__main__": load_and_print_info(IAMLines)
# https://pytorch-lightning.readthedocs.io/en/latest/advanced/multiple_loaders.html#multiple-training-dataloaders # def train_dataloader(self): # return DataLoader( # self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True # ) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member f"Num classes: {len(self.mapping)}\n" f"Dims: {self.dims}\n" f"Output dims: {self.output_dims}\n") if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data if __name__ == "__main__": load_and_print_info(IAMOriginalAndSyntheticParagraphs)
# NOTE(review): continuation of the MNIST datamodule's `__init__`, whose `def`
# (and enclosing class header) lie above this chunk.
        self.data_dir = DOWNLOADED_DATA_DIRNAME
        # ToTensor scales pixels to [0, 1]; Normalize then applies the
        # conventional MNIST mean/std constants.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        self.dims = (
            1, 28, 28
        )  # dims are returned when calling `.size()` on this object.
        self.output_dims = (1, )
        # Class mapping: the ten digit labels 0-9.
        self.mapping = list(range(10))

    def prepare_data(self):
        """Download train and test MNIST data from PyTorch canonical source."""
        TorchMNIST(self.data_dir, train=True, download=True)
        TorchMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Split into train, val, test, and set dims.

        ``stage`` is unused here; it is accepted to satisfy the
        ``setup(stage)`` hook signature.
        """
        mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform)
        # 60k training images -> 55k train / 5k validation.
        self.data_train, self.data_val = random_split(mnist_full, [55000, 5000])
        self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform)


if __name__ == "__main__":
    load_and_print_info(cast(Type[BaseDataModule], MNIST))
# NOTE(review): the first statements below are the tail of a dataset-properties
# function defined above this chunk; `_get_property_values` is a helper from
# that enclosing scope.
    crop_shapes = np.array(_get_property_values("crop_shape"))
    # Shapes are stored (height, width); aspect ratio = width / height.
    aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0]
    return {
        "label_length": {
            "min": min(_get_property_values("label_length")),
            "max": max(_get_property_values("label_length")),
        },
        "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))},
        "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)},
        "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()},
    }


def _labels_filename(split: str) -> Path:
    """Return filename of processed labels for the given split."""
    return PROCESSED_DATA_DIRNAME / split / "_labels.json"


def _crop_filename(id_: str, split: str) -> Path:
    """Return filename of the processed crop with the given id in the given split."""
    return PROCESSED_DATA_DIRNAME / split / f"{id_}.png"


def _num_lines(label: str) -> int:
    """Return number of lines of text in label (newline count + 1)."""
    return label.count("\n") + 1


if __name__ == "__main__":
    load_and_print_info(IAMParagraphs)
xml_line_elements = xml_root_element.findall("handwritten-part/line") return [_get_line_region_from_xml_element(el) for el in xml_line_elements] def _get_line_region_from_xml_element(xml_line) -> Dict[str, int]: """ Parameters ---------- xml_line xml element that has x, y, width, and height attributes """ word_elements = xml_line.findall("word/cmp") x1s = [int(el.attrib["x"]) for el in word_elements] y1s = [int(el.attrib["y"]) for el in word_elements] x2s = [ int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements ] y2s = [ int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements ] return { "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, } if __name__ == "__main__": load_and_print_info(IAM)