def __init__(self, config):
    """Build the foreground U-Net and its single-channel sigmoid head.

    All hyper-parameters are read from ``config`` (indexed like a mapping)
    and stored on ``self`` before the layers are created.

    Args:
        config: mapping with keys INPUT_SHAPE, OUTPUT_SHAPE, VOXEL_SIZE,
            NUM_FMAPS_FOREGROUND, FMAP_INC_FACTORS_FOREGROUND,
            DOWNSAMPLE_FACTORS, KERNEL_SIZE_UP and ACTIVATION.
    """
    super(ForegroundUnet, self).__init__()

    # CONFIGS
    self.input_shape = tuple(config["INPUT_SHAPE"])
    self.output_shape = tuple(config["OUTPUT_SHAPE"])
    self.voxel_size = tuple(config["VOXEL_SIZE"])
    self.num_fmaps = config["NUM_FMAPS_FOREGROUND"]
    self.fmap_inc_factors = config["FMAP_INC_FACTORS_FOREGROUND"]
    self.downsample_factors = tuple(
        tuple(x) for x in config["DOWNSAMPLE_FACTORS"])
    self.kernel_size_up = config["KERNEL_SIZE_UP"]
    self.activation = config["ACTIVATION"]

    # LAYERS
    # Field of view passed to the U-Net: input shape scaled element-wise by
    # the voxel size.
    fov = tuple(a * b for a, b in zip(self.input_shape, self.voxel_size))
    self.unet = UNet(
        in_channels=2,
        num_fmaps=self.num_fmaps,
        fmap_inc_factors=self.fmap_inc_factors,
        downsample_factors=self.downsample_factors,
        # NOTE(review): the down path reuses KERNEL_SIZE_UP (no separate
        # down-kernel key is read) — confirm this is intended.
        kernel_size_down=self.kernel_size_up,
        kernel_size_up=self.kernel_size_up,
        activation=self.activation,
        fov=fov,
        voxel_size=self.voxel_size,
        num_heads=1,
        constant_upsample=True,
    )
    # 1x1 head producing one sigmoid channel. The kernel rank follows the
    # input dimensionality (as in EmbeddingUnet) instead of being hard-coded
    # to 3-D, so 2-D configurations get a matching kernel.
    self.conv_layer = ConvPass(
        in_channels=self.num_fmaps,
        out_channels=1,
        kernel_sizes=[[1] * len(self.input_shape)],
        activation="Sigmoid",
    )
def __init__(self, config):
    """Build the embedding U-Net, its Tanh embedding head and an optional aux head.

    Every hyper-parameter is read from ``config`` (indexed like a mapping)
    and stored on ``self`` before any layer is constructed.

    Args:
        config: mapping with keys INPUT_SHAPE, OUTPUT_SHAPE, EMBEDDING_DIMS,
            NUM_FMAPS_EMBEDDING, FMAP_INC_FACTORS_EMBEDDING,
            DOWNSAMPLE_FACTORS, KERNEL_SIZE_UP, VOXEL_SIZE, ACTIVATION,
            NORMALIZE_EMBEDDINGS, AUX_TASK and AUX_NEIGHBORHOOD. All keys are
            read unconditionally, including AUX_NEIGHBORHOOD when AUX_TASK
            is falsy.
    """
    super(EmbeddingUnet, self).__init__()

    # ---- hyper-parameters (read in this fixed order) ----
    self.input_shape = tuple(config["INPUT_SHAPE"])
    self.output_shape = tuple(config["OUTPUT_SHAPE"])
    self.embedding_dims = config["EMBEDDING_DIMS"]
    self.num_fmaps = config["NUM_FMAPS_EMBEDDING"]
    self.fmap_inc_factors = config["FMAP_INC_FACTORS_EMBEDDING"]
    self.downsample_factors = tuple(map(tuple, config["DOWNSAMPLE_FACTORS"]))
    self.kernel_size_up = config["KERNEL_SIZE_UP"]
    # Stored as given (not converted to a tuple); not passed to the U-Net here.
    self.voxel_size = config["VOXEL_SIZE"]
    self.activation = config["ACTIVATION"]
    # Stored only; not used while building the layers below.
    self.normalize_embeddings = config["NORMALIZE_EMBEDDINGS"]
    self.aux_task = config["AUX_TASK"]
    self.neighborhood = config["AUX_NEIGHBORHOOD"]

    def pointwise_tanh_head(out_channels):
        # 1x1 convolution over the U-Net features with Tanh activation;
        # the kernel has one entry per input dimension.
        return ConvPass(
            in_channels=self.num_fmaps,
            out_channels=out_channels,
            kernel_sizes=[[1] * len(self.input_shape)],
            activation="Tanh",
        )

    # ---- backbone ----
    # NOTE(review): the down path reuses KERNEL_SIZE_UP — confirm intended.
    self.unet = UNet(
        in_channels=2,
        num_fmaps=self.num_fmaps,
        fmap_inc_factors=self.fmap_inc_factors,
        downsample_factors=self.downsample_factors,
        kernel_size_down=self.kernel_size_up,
        kernel_size_up=self.kernel_size_up,
        activation=self.activation,
        fov=self.input_shape,
        num_heads=1,
        constant_upsample=True,
    )

    # ---- heads (registered after the backbone, embedding head first) ----
    self.conv_layer = pointwise_tanh_head(self.embedding_dims)
    if self.aux_task:
        # The AUX_NEIGHBORHOOD value is used directly as the channel count.
        self.aux_layer = pointwise_tanh_head(self.neighborhood)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# cages and render parameters
cage1 = Cage("../data/example_cage", 1)
psf = GaussianPSF(intensity=0.125, sigma=(1.0, 1.0))
min_density = 1e-5
max_density = 1e-5

# create model, loss, and optimizer
# Feature maps produced by the U-Net. The 1x1 head below consumes exactly
# this many channels, so both read the same constant instead of repeating 24.
NUM_FMAPS = 24
# NOTE(review): earlier comments said num_fmaps "needs to be increased later
# (24)" and fmap_inc_factor "needs to be increased later (3)"; the current
# values are 24 and 4 — confirm the intended settings.
unet = UNet(
    in_channels=1,
    num_fmaps=NUM_FMAPS,
    fmap_inc_factor=4,
    downsample_factors=[
        [1, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    # Three downsampling steps -> four levels on the way down, three up.
    kernel_size_down=[[[3, 3, 3], [3, 3, 3]]] * 4,
    kernel_size_up=[[[3, 3, 3], [3, 3, 3]]] * 3,
    padding='valid')
model = torch.nn.Sequential(
    unet,
    ConvPass(NUM_FMAPS, 1, [(1, 1, 1)], activation='Sigmoid'))
# Sigmoid output paired with BCELoss (expects probabilities, not logits).
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# declare gunpowder arrays
raw = gp.ArrayKey('RAW')
seg = gp.ArrayKey('SEGMENTATION')
out_cage_map = gp.ArrayKey('OUT_CAGE_MAP')
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP')
def train_until(max_iteration):
    """Train a point-detection U-Net with a gunpowder pipeline.

    Builds a U-Net with a single-channel ``Convolve`` head. Batches come from
    positive samples (``pos_samples``, points added by ``AddCenterPoint``)
    with probability 0.9 and from negative samples (``neg_samples``, via
    ``AddNoPoint``) with probability 0.1. Points are rasterized into a
    'peak' ground-truth map and the model is trained against it with MSE.

    Args:
        max_iteration: number of training batches to request.
    """
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels, num_fmaps, fmap_inc_factors, downsample_factors,
                constant_upsample=True)
    # The head consumes the U-Net's feature maps: tie its input width to
    # num_fmaps instead of repeating the literal 12.
    model = Convolve(unet, num_fmaps, 1)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:
    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))
    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)
    # Shapes are given in voxels; the request sizes are scaled by voxel_size.
    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(points, output_size)
    request.add(groundtruth, output_size)
    request.add(prediction, output_size)
    request.add(grad, output_size)

    # Positive samples: random locations constrained by ensure_nonempty=points.
    pos_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddCenterPoint(points, raw) +
        gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    # Negative samples: unconstrained random locations.
    neg_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddNoPoint(points, raw) +
        gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    # Choose the positive branch 90% of the time, the negative one 10%.
    data_sources = (pos_sources, neg_sources)
    data_sources += gp.RandomProvider(probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    train_pipeline = data_sources
    train_pipeline += gp.ElasticAugment(
        control_point_spacing=[4, 40, 40],
        jitter_sigma=[0, 2, 2],
        rotation_interval=[0, math.pi / 2.0],
        prob_slip=0.05,
        prob_shift=0.05,
        max_misalign=10,
        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])
    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                          z_section_wise=True)
    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100),
                                          mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    # Prepend two singleton axes for the torch model ...
    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)
    train_pipeline += gp_torch.Train(
        model=model,
        loss=loss,
        optimizer=optimizer,
        inputs={'x': raw},
        outputs={0: prediction},
        loss_inputs={0: prediction, 1: groundtruth},
        gradients={0: grad},
        save_every=1000,
        log_dir='log')
    # ... and reshape everything back to plain volumes for the snapshot.
    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)
    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient',
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for _ in range(max_iteration):
            train_pipeline.request_batch(request)
def train_until(max_iteration):
    """Train an affinity-prediction U-Net with a gunpowder pipeline.

    Reads raw data, neuron labels and a label mask from the zarr containers
    ``os.path.join(data_dir, sample)`` for each entry of ``samples``. It
    augments them, computes ground-truth affinities for ``neighborhood``
    and trains with MSE loss. A snapshot and profiling stats are written
    every iteration.

    Args:
        max_iteration: number of training batches to request.
    """
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(2, 2, 2), (2, 2, 2), (3, 3, 3)]

    unet = UNet(in_channels, num_fmaps, fmap_inc_factors, downsample_factors)
    # One output channel per neighborhood offset, so the prediction matches
    # the channel count of gt_affs (was hard-coded to 3).
    model = Convolve(unet, num_fmaps, len(neighborhood))
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.5e-4,
                                 betas=(0.95, 0.999))

    test_input_shape = Coordinate((196,) * 3)
    test_output_shape = Coordinate((84,) * 3)

    raw = ArrayKey('RAW')
    labels = ArrayKey('GT_LABELS')
    labels_mask = ArrayKey('GT_LABELS_MASK')
    affs = ArrayKey('PREDICTED_AFFS')
    gt_affs = ArrayKey('GT_AFFS')
    gt_affs_scale = ArrayKey('GT_AFFS_SCALE')
    affs_gradient = ArrayKey('AFFS_GRADIENT')

    voxel_size = Coordinate((8,) * 3)
    # Shapes are given in voxels; the request sizes are scaled by voxel_size.
    input_size = Coordinate(test_input_shape) * voxel_size
    output_size = Coordinate(test_output_shape) * voxel_size

    # max labels padding calculated
    labels_padding = Coordinate((376, 536, 536))

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_mask, output_size)
    request.add(gt_affs, output_size)
    request.add(gt_affs_scale, output_size)

    # Predictions are only requested for the snapshot, over the gt_affs ROI.
    snapshot_request = BatchRequest({affs: request[gt_affs]})

    data_sources = tuple(
        ZarrSource(
            os.path.join(data_dir, sample),
            {
                raw: 'volumes/raw',
                labels: 'volumes/labels/neuron_ids',
                labels_mask: 'volumes/labels/mask',
            },
            {
                raw: ArraySpec(interpolatable=True),
                labels: ArraySpec(interpolatable=False),
                labels_mask: ArraySpec(interpolatable=False),
            }) +
        Normalize(raw) +
        Pad(raw, None) +
        Pad(labels, labels_padding) +
        Pad(labels_mask, labels_padding) +
        RandomLocation(min_masked=0.5, mask=labels_mask)
        for sample in samples)

    train_pipeline = data_sources
    train_pipeline += RandomProvider(probabilities=probabilities)
    # First elastic pass: rotation only (no jitter, slip, shift or misalign).
    train_pipeline += ElasticAugment(
        control_point_spacing=[40, 40, 40],
        jitter_sigma=[0, 0, 0],
        rotation_interval=[0, math.pi / 2.0],
        prob_slip=0,
        prob_shift=0,
        max_misalign=0,
        subsample=8)
    train_pipeline += SimpleAugment()
    # Second elastic pass: small jitter plus rare slip/shift.
    train_pipeline += ElasticAugment(
        control_point_spacing=[40, 40, 40],
        jitter_sigma=[2, 2, 2],
        rotation_interval=[0, math.pi / 2.0],
        prob_slip=0.01,
        prob_shift=0.01,
        max_misalign=1,
        subsample=8)
    train_pipeline += IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
    train_pipeline += GrowBoundary(labels, labels_mask, steps=1)
    train_pipeline += AddAffinities(neighborhood, labels=labels,
                                    affinities=gt_affs)
    # NOTE(review): gt_affs_scale is computed and requested but never passed
    # to the loss, so the MSE below is unweighted — confirm intended.
    train_pipeline += BalanceLabels(gt_affs, gt_affs_scale)
    # Scale raw by 2 and shift by -1; undone by IntensityScaleShift(raw, 0.5,
    # 0.5) after training so the snapshot shows the original intensities.
    train_pipeline += IntensityScaleShift(raw, 2, -1)
    # Only cast gt_affs to float. Without an explicit factor, Normalize picks
    # one from the dtype (1/255 for uint8, which recent gunpowder versions of
    # AddAffinities emit by default). That would shrink the 0/1 targets to
    # 0 and 1/255.
    train_pipeline += Normalize(gt_affs, factor=1.0)
    # raw gets two leading axes, gt_affs one (it presumably already carries
    # one channel per offset) — i.e. (batch, channel, ...) for torch; confirm.
    train_pipeline += Unsqueeze([raw, gt_affs])
    train_pipeline += Unsqueeze([raw])
    train_pipeline += PreCache(cache_size=40, num_workers=10)
    train_pipeline += Train(
        model=model,
        loss=loss,
        optimizer=optimizer,
        inputs={'x': raw},
        loss_inputs={0: affs, 1: gt_affs},
        outputs={0: affs},
        save_every=1000,
        log_dir='log')
    train_pipeline += Squeeze([raw])
    train_pipeline += Squeeze([raw, gt_affs, affs])
    train_pipeline += IntensityScaleShift(raw, 0.5, 0.5)
    # NOTE(review): every=1 writes a zarr snapshot on every iteration, which
    # is expensive outside of debugging.
    train_pipeline += Snapshot(
        {
            raw: 'volumes/raw',
            labels: 'volumes/labels/neuron_ids',
            gt_affs: 'volumes/gt_affinities',
            affs: 'volumes/pred_affinities',
            labels_mask: 'volumes/labels/mask',
        },
        dataset_dtypes={
            labels: np.uint64,
            gt_affs: np.float32,
        },
        every=1,
        output_filename='batch_{iteration}.zarr',
        additional_request=snapshot_request)
    train_pipeline += PrintProfilingStats(every=1)

    with build(train_pipeline) as b:
        for _ in range(max_iteration):
            b.request_batch(request)