def data_gen(self, lmdb_path, data_root, annotation, masked_language_model=True, transform=None):
    """Build a ``DataLoader`` over a single LMDB-backed OCR dataset.

    Args:
        lmdb_path: path to the LMDB database holding the samples.
        annotation: path to the annotation file for the dataset.
        data_root: root directory the annotation's image paths are relative to.
        masked_language_model: forwarded to ``Collator`` to control MLM-style
            target masking during collation.
        transform: optional image transform applied per sample.

    Returns:
        A ``DataLoader`` that yields collated batches; batching order is driven
        by ``ClusterRandomSampler`` (groups samples of similar width), so the
        loader itself uses ``shuffle=False``.
    """
    ds_cfg = self.config['dataset']

    dataset = OCRDataset(
        lmdb_path=lmdb_path,
        root_dir=data_root,
        annotation_path=annotation,
        vocab=self.vocab,
        transform=transform,
        image_height=ds_cfg['image_height'],
        image_min_width=ds_cfg['image_min_width'],
        image_max_width=ds_cfg['image_max_width'],
    )

    # Sampler handles randomization; DataLoader must not shuffle on top of it.
    sampler = ClusterRandomSampler(dataset, self.batch_size, True)
    collate_fn = Collator(masked_language_model)

    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
        **self.config['dataloader'],
    )
def data_gen(self, lmdb_paths, data_root, annotation, masked_language_model=True, transform=None, is_train=False):
    """Build a ``DataLoader`` over one or more LMDB-backed OCR datasets.

    Args:
        lmdb_paths: iterable of LMDB database paths; multiple paths are
            concatenated into a single dataset.
        annotation: path to the annotation file for the dataset(s).
        data_root: root directory the annotation's image paths are relative to.
        masked_language_model: forwarded to ``Collator`` to control MLM-style
            target masking during collation.
        transform: optional image transform applied per sample.
        is_train: when True and no custom sampler is in use, the loader
            shuffles samples each epoch.

    Returns:
        A ``DataLoader`` yielding collated batches over the combined datasets.

    Raises:
        IndexError: if ``lmdb_paths`` is empty (no dataset can be built).
    """
    ds_cfg = self.config['dataset']

    datasets = [
        OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=ds_cfg['image_height'],
            image_min_width=ds_cfg['image_min_width'],
            image_max_width=ds_cfg['image_max_width'],
            separate=ds_cfg['separate'],
            batch_size=self.batch_size,
            is_padding=self.is_padding,
        )
        for lmdb_path in lmdb_paths
    ]

    # BUG FIX: was `if len(self.train_lmdb) > 1:`, which tied the concat
    # decision to the *training* configuration even when this method builds a
    # validation loader — and silently kept only the last dataset (or raised
    # NameError on an empty loop) whenever the condition was false.
    # Decide based on the datasets actually built here.
    if len(datasets) > 1:
        dataset = torch.utils.data.ConcatDataset(datasets)
    else:
        dataset = datasets[0]

    if self.is_padding:
        # Fixed-size padding: no width bucketing needed, let the DataLoader
        # handle sample order itself.
        sampler = None
    else:
        # Bucket samples of similar width into batches.
        sampler = ClusterRandomSampler(dataset, self.batch_size, True)

    collate_fn = Collator(masked_language_model)

    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        # BUG FIX: DataLoader raises ValueError when shuffle=True is combined
        # with a custom sampler — only shuffle when no sampler is in use.
        shuffle=is_train and sampler is None,
        # CRNN presumably needs uniformly sized batches (e.g. for CTC) — keep
        # the original drop_last policy; TODO confirm against the model code.
        drop_last=self.model.seq_modeling == 'crnn',
        **self.config['dataloader'],
    )