def create_cell_dataset(
        folders,
        scale=0.125,
        crop=(128, 128),
        n_streams=12,
        batch_size=16,
        cell_prominence_min=0.4,
        cell_prominence_max=float('inf'),
):
    """Return a DataLoader of single-channel crops whose summed cell-mask
    prominence lies between cell_prominence_min and cell_prominence_max."""

    def validate_crop(tensor):
        # Accept the crop only if the mean of the summed mask channels is in
        # range; keep just the image channel.
        mean = tensor[1:].sum(dim=0).mean()
        if cell_prominence_min < mean < cell_prominence_max:
            return tensor[:1]
        return None

    images = Dataset.ImageIterableDataset(
        folders,
        transforms.Compose((
            transforms.GaussianBlur(3, (.01, 1.)),
            data_aug.flip_scale_pipeline(scale),
            #data_aug.pipeline(scale, degrees=0, noise_p=0.01),
        )),
        Dataset.CropGenerator(crop, validate_crop=validate_crop),
        n_streams=n_streams,  # large memory impact
        indices=[*range(0, 204), *range(306, 2526)],
    )
    return DataLoader(images, batch_size=batch_size, drop_last=True)
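# Usage sketch (not part of the original pipeline): the folder pair passed in
# is an assumption, mirroring the folders used by the helpers below. Each
# yielded batch holds single-channel crops that passed the prominence check.
#
# example_loader = create_cell_dataset(
#     [dirs.images.images, dirs.images.modified_cell])
# example_batch = next(iter(example_loader))  # shape: (batch_size, 1, *crop)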
def create_ground_truth_dataset(lower_bound=.4, upper_bound=1.):
    validate_crop = lambda t: lower_bound <= t[1].mean() <= upper_bound
    images = Dataset.ImageIterableDataset(
        [dirs.images.images, dirs.images.cell],
        data_aug.pipeline(0.25),
        Dataset.CropGenerator((256, 256), validate_crop=validate_crop),
        n_streams=6,  # large memory impact
        indices=[*range(600, 2040)],
    )
    return images
def create_validation_dataloader(batch_size=8, lower_bound=.6, upper_bound=.99):
    validate_crop = lambda t: lower_bound <= t[1].mean() <= upper_bound
    val = Dataset.ImageIterableDataset(
        [dirs.images.images, dirs.images.cell],
        data_aug.flip_scale_pipeline(0.25),
        Dataset.CropGenerator((256, 256), validate_crop=validate_crop),
        n_streams=6,  # large memory impact
        indices=list(range(350, 550)),
    )
    return DataLoader(val, batch_size=batch_size, drop_last=True)
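# Validation sketch, assuming the loader yields image and mask stacked along
# the channel dimension (image in channel 0, cell mask in channel 1) and that
# `net` maps an image batch to a mask prediction; both are assumptions about
# the surrounding project, not taken from this file.
#
# val_loader = create_validation_dataloader()
# with torch.no_grad():
#     for batch in val_loader:
#         imgs, masks = batch[:, :1], batch[:, 1:]
#         preds = net(imgs)
#         ...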
def create_pseudolabel_dataset(parents, add_original=True, prominence_filtering=True):
    images = Dataset.ImageIterableDataset(
        [dirs.images.images, dirs.images.modified_cell],
        data_aug.pipeline(0.25, degrees=0),
        Dataset.CropGenerator((256, 256)),
        n_streams=4,  # large memory impact
        indices=[*range(0, 204), *range(650, 2526)],
    )
    if prominence_filtering:
        images = Dataset.IterableDatasetFilter(  # filter result before parents
            images, lambda t: t if (.1 < t[1].mean()) else None)
    return create_self_training_dataset(images, parents, add_original=add_original)
def create_full_size_dataset():
    return Dataset.ImageIterableDataset(
        [dirs.images.images, dirs.images.modified_cell],
        data_aug.scale_pipeline(0.25),
        lambda t: iter((t, )),  # dummy cropper
        n_streams=1,
        indices=[*range(100), *range(306, 600), *range(2300, 2525)])
def create_generative_dataset(
        parents,
        im_size=256,
        generator=dirs.models / "SelfTrainingGenerator.pt",
):
    if isinstance(generator, (str, Path)):
        generator = torch.load(generator)
        #generator.eval()
    dataset = Dataset.GenerativeDataset(generator, 1)
    dataset = Dataset.IterableDatasetFilter(dataset, lambda t: t[0])  # no batch
    dataset = Dataset.IterableDatasetFilter(dataset, lambda t: t[:1])  # drop mask
    dataset = Dataset.IterableDatasetFilter(dataset, transforms.Resize(im_size))
    dataset = create_self_training_dataset(dataset, parents, add_original=False)
    return dataset
def create_generic_dataloader(crop=True, modified_cell=True):
    if crop:
        cropper = Dataset.CropGenerator((256, 256))
    else:
        cropper = lambda img: iter((img, ))
    folders = [dirs.images.images]
    if modified_cell:
        folders.append(dirs.images.modified_cell)
    else:
        folders.append(dirs.images.cell)
    images = Dataset.ImageIterableDataset(
        folders,
        data_aug.scale_pipeline(0.25),
        cropper,
        n_streams=1,  # large memory impact
        indices=[*range(350, 2526)],
    )
    return DataLoader(images, batch_size=1)
def raw_data(
        scale=0.125,
        crop=(128, 128),
        n_streams=12,
        batch_size=16,  # note: not used here; the unbatched dataset is returned
        cell_prominence_min=0.3,
        cell_prominence_max=0.9,
):
    def validate_crop(tensor):
        # Accept the crop only if the mean of the summed mask channels is in
        # range; keep just the image channel.
        mean = tensor[1:].sum(dim=0).mean()
        if cell_prominence_min < mean < cell_prominence_max:
            return tensor[:1]
        return None

    images = Dataset.ImageIterableDataset(
        [dirs.images.images, dirs.images.modified_cell],
        data_aug.flip_scale_pipeline(scale),
        Dataset.CropGenerator(crop, validate_crop=validate_crop),
        n_streams=n_streams,  # large memory impact
        indices=[*range(0, 204), *range(306, 2526)],
    )
    return images
def create_self_training_dataset(
        data,
        parents,
        add_original=False,
        prominence_filtering=True,
):
    """Attach parent pseudo-labels to `data`, optionally filter crops by mask
    prominence, apply the image entropy filter, binarise the mask channels at
    0.6, and rebuffer samples for repeated drawing."""
    dataset = Dataset.SelfTrainingIterableDatasetWrapper(
        data, parents, add_original=add_original)
    if prominence_filtering:
        dataset = Dataset.IterableDatasetFilter(  # filter result after parents
            dataset, lambda t: t if (.1 < t[1].mean() < .99) else None)
    dataset = Dataset.IterableDatasetFilter(dataset, Dataset.ImageEntropyFilter())
    dataset = Dataset.IterableDatasetFilter(
        dataset, lambda t: torch.cat((t[:1], (t[1:] > .6).float())))
    dataset = Dataset.RepeatBufferIterableDataset(dataset,
                                                  buffer_size=2048,
                                                  repeat_prob=.8)
    return dataset
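# Sketch of wiring the unbatched raw_data() dataset through the self-training
# wrapper; load_parents() and the DataLoader settings mirror the script below,
# while using raw_data() as the unlabeled source is an assumption.
#
# parents = SelfTraining.Parents.load_parents()
# pseudo = create_self_training_dataset(raw_data(), parents, add_original=True)
# pseudo_loader = DataLoader(pseudo, batch_size=16, drop_last=True)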
## Create Dataset ##
data = []
if use_raw_data:
    data.append(create_ground_truth_dataset())
if use_parents or use_generative:
    parents = SelfTraining.Parents.load_parents()
if use_parents:  # use parents on unlabeled data
    data.append(
        create_pseudolabel_dataset(parents,
                                   add_original=True,
                                   prominence_filtering=False))
if use_generative:  # use parents on generative data
    data.append(create_generative_dataset(parents))

if len(data) > 1:
    data = Dataset.IterableDatasetMerger(data)
else:
    data = data[0]
dataloader = DataLoader(data, batch_size=16, drop_last=True)

## Create Main Network ##
save_name = dirs.models / "SegmentationUNet.pt"
net = create_unet(scale_size=2, dropout_p=.4, save_name=save_name)
print("Training Network", net)
train(net, dataloader, save_name, epochs=400, decay=.1)
print("exit.")