def predict(checkpoint, iterations):
    """Predict affinities on HDF5 label volumes with a trained network.

    For every file listed in ``samples`` (read from ``data_dir``), synthetic
    raw data is rendered from the ``volumes/labels`` dataset. The network
    restored from ``setup_dir/train_net_checkpoint_<checkpoint>`` (graph
    ``predict_net.meta``) runs on that raw data. The ground-truth and
    predicted affinities are snapshotted after every batch.

    Args:
        checkpoint: Iteration number of the checkpoint to restore.
        iterations: Number of batches to request.

    Relies on the module-level names ``config``, ``data_dir``, ``samples``,
    ``neighborhood``, ``setup_dir`` and ``setup_name``.
    """
    print("checkpoint: ", checkpoint)

    labels = ArrayKey('LABELS')
    gt_affs = ArrayKey('GT_AFFINITIES')
    raw_affs = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs = ArrayKey('RAW_JOINED_AFFINITIES')
    raw = ArrayKey('RAW')
    pred_affs = ArrayKey('PREDICTED_AFFS')
    sample_z = ArrayKey("SAMPLE_Z")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    # NOTE(review): latent size 6 is hard-coded here; confirm it matches the
    # latent dimensionality the prediction graph was exported with.
    sample_shape = Coordinate((1, 1, 6)) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    request = BatchRequest()
    for key, shape in ((labels, output_shape),
                       (gt_affs, input_shape),
                       (raw_affs, input_shape),
                       (raw_joined_affs, input_shape),
                       (raw, input_shape),
                       (pred_affs, output_shape),
                       (sample_z, sample_shape)):
        request.add(key, shape)

    # One dataset map / spec map shared by every source.
    labels_dataset = {labels: 'volumes/labels'}
    labels_spec = {labels: ArraySpec(interpolatable=False)}

    sources = tuple(
        Hdf5Source(os.path.join(data_dir, sample_name + '.hdf'),
                   datasets=labels_dataset,
                   array_specs=labels_spec) + Pad(labels, None)
        for sample_name in samples)

    processing = (
        SequentialProvider() +
        # Render raw data from the joined affinities of the labels.
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels,
                      affinities=raw_affs) +
        AddJoinedAffinities(input_affinities=raw_affs,
                            joined_affinities=raw_joined_affs) +
        AddRealism(joined_affinities=raw_joined_affs,
                   raw=raw,
                   sp=0.25,
                   sigma=1,
                   contrast=0.7) +
        # Ground-truth affinities come from labels with grown boundaries.
        GrowBoundary(labels, steps=1, only_xy=True) +
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels,
                      affinities=gt_affs) +
        PreCache(cache_size=32, num_workers=8) +
        # Map raw as x * 2 - 1 for the network and undo it afterwards.
        IntensityScaleShift(raw, 2, -1) +
        Predict(
            checkpoint=os.path.join(setup_dir,
                                    'train_net_checkpoint_%d' % checkpoint),
            inputs={config['raw']: raw},
            outputs={
                config['pred_affs']: pred_affs,
                config['sample_z']: sample_z,
            },
            graph=os.path.join(setup_dir, 'predict_net.meta')) +
        IntensityScaleShift(array=raw, scale=0.5, shift=0.5) +
        Snapshot(
            dataset_names={
                gt_affs: 'volumes/gt_affs',
                pred_affs: 'volumes/pred_affs',
            },
            output_filename='prob_unet/' + setup_name +
            '/prediction_A_{id}.hdf',
            every=1,
            dataset_dtypes={
                gt_affs: np.float32,
                pred_affs: np.float32,
            }))

    pipeline = sources + processing

    print("Starting prediction...")
    with build(pipeline) as built:
        for _ in range(iterations):
            built.request_batch(request)
    print("Prediction finished")
# Example no. 2
# 0
def train(iterations):
    """Train on HDF5 samples with merged-label augmentation, up to ``iterations``.

    Iterations up to ``phase_switch`` form the "euclid" phase, which uses
    ``config['loss']`` / ``config['optimizer']``. Later iterations form the
    "malis" phase, which uses ``add_malis_loss``. Training resumes from the
    latest checkpoint in the current directory. If that checkpoint is still
    before ``phase_switch`` while the target lies beyond it, the euclid phase
    is completed first by a recursive call.

    In the euclid phase the picked ground truth is always the original
    labels. In the malis phase it is picked uniformly among the original
    labels and the ``num_merges`` merged-label volumes, and then renumbered
    with ``RenumberConnectedComponents``.

    Args:
        iterations: Total iteration count to train up to.

    Relies on the module-level names ``config``, ``phase_switch``,
    ``data_dir``, ``samples``, ``neighborhood``, ``setup_dir``,
    ``setup_name`` and ``add_malis_loss``.
    """

    def _trained_iterations():
        # Checkpoint names end in "_<iteration>"; 0 when none exists yet.
        latest = tf.train.latest_checkpoint('.')
        return int(latest.split('_')[-1]) if latest else 0

    trained_until = _trained_iterations()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The value read above is stale after phase one; without re-reading,
        # the loop below would run past ``iterations``.
        trained_until = _trained_iterations()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    num_merges = 3

    # define array-keys
    labels_key = ArrayKey('LABELS')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    merged_labels_keys = [ArrayKey('MERGED_LABELS_%i' % (i + 1))
                          for i in range(num_merges)]
    picked_labels_key = ArrayKey('PICKED_RANDOM_LABEL')
    # Candidate ground truths: the original labels followed by every merge.
    candidate_labels_keys = [labels_key] + merged_labels_keys

    gt_affs_out_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    sample_z_key = ArrayKey("SAMPLE_Z")
    broadcast_key = ArrayKey("BROADCAST")
    pred_logits_key = ArrayKey("PRED_LOGITS")
    sample_out_key = ArrayKey("SAMPLE_OUT")
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']]) * voxel_size
    sample_shape = Coordinate((1, 1, config['latent_dims'])) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)
    request.add(raw_key, input_shape)

    for key in merged_labels_keys:
        request.add(key, input_shape)
    request.add(picked_labels_key, output_shape)

    request.add(gt_affs_out_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)
    request.add(pred_affs_gradient_key, output_shape)

    request.add(broadcast_key, output_shape)
    request.add(sample_z_key, sample_shape)
    request.add(pred_logits_key, output_shape)
    request.add(sample_out_key, sample_shape)
    request.add(debug_key, debug_shape)

    dataset_names = {labels_key: 'volumes/labels'}
    array_specs = {labels_key: ArraySpec(interpolatable=False)}
    for i, key in enumerate(merged_labels_keys):
        dataset_names[key] = 'volumes/merged_labels_%i' % (i + 1)
        array_specs[key] = ArraySpec(interpolatable=False)

    def _padded_source(sample):
        # One HDF5 file with a Pad(key, None) node for every label array,
        # built in a loop so it follows ``num_merges``.
        source = Hdf5Source(
            os.path.join(data_dir, sample + '.hdf'),
            datasets=dataset_names,
            array_specs=array_specs)
        for key in candidate_labels_keys:
            source += Pad(key, None)
        return source

    pipeline = tuple(_padded_source(sample) for sample in samples)

    pipeline += RandomProvider()

    pipeline += AddAffinities(
        affinity_neighborhood=neighborhood,
        labels=labels_key,
        affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(
        input_affinities=raw_affs_key,
        joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(
        joined_affinities=raw_joined_affs_key,
        raw=raw_key,
        sp=0.25,
        sigma=1,
        contrast=0.7)

    num_candidates = len(candidate_labels_keys)
    if phase == "euclid":
        # Always pick the original, unmerged labels.
        probabilities = [1] + [0] * num_merges
    else:
        # Pick uniformly among the original and every merged labelling.
        probabilities = [1.0 / num_candidates] * num_candidates

    pipeline += PickRandomLabel(
        input_labels=candidate_labels_keys,
        output_label=picked_labels_key,
        probabilities=probabilities)

    if phase != "euclid":
        pipeline += RenumberConnectedComponents(labels=picked_labels_key)

    pipeline += GrowBoundary(picked_labels_key, steps=1, only_xy=True)

    pipeline += AddAffinities(
        affinity_neighborhood=neighborhood,
        labels=picked_labels_key,
        affinities=gt_affs_in_key)

    pipeline += AddAffinities(
        affinity_neighborhood=neighborhood,
        labels=picked_labels_key,
        affinities=gt_affs_out_key,
        affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(
        labels=gt_affs_out_key,
        scales=gt_affs_scale_key)

    pipeline += DefectAugment(
        intensities=raw_key,
        prob_missing=0.03,
        prob_low_contrast=0.01,
        contrast_scale=0.5,
        axis=0)

    # Map raw as x * 2 - 1 for the network; undone after Train below.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    pipeline += PreCache(
        cache_size=8,
        num_workers=4)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        # add_malis_loss builds the loss itself and needs the segmentation
        # and affinity mask fed by tensor name.
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = picked_labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(
        graph=setup_dir + 'train_net',
        optimizer=train_optimizer,
        loss=train_loss,
        inputs=train_inputs,
        outputs={
            config['pred_affs']: pred_affs_key,
            config['broadcast']: broadcast_key,
            config['sample_z']: sample_z_key,
            config['pred_logits']: pred_logits_key,
            config['sample_out']: sample_out_key,
            config['debug']: debug_key
        },
        gradients={
            config['pred_affs']: pred_affs_gradient_key
        },
        summary=train_summary,
        log_dir='log/prob_unet/' + setup_name,
        save_every=2000)

    pipeline += IntensityScaleShift(
        array=raw_key,
        scale=0.5,
        shift=0.5)

    pipeline += Snapshot(
        dataset_names={
            labels_key: 'volumes/labels',
            picked_labels_key: 'volumes/merged_labels',
            raw_affs_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            gt_affs_in_key: 'volumes/gt_affs_in',
            gt_affs_out_key: 'volumes/gt_affs_out',
            pred_affs_key: 'volumes/pred_affs',
            pred_logits_key: 'volumes/pred_logits'
        },
        output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
        every=2000,
        dataset_dtypes={
            labels_key: np.uint16,
            picked_labels_key: np.uint16,
            raw_affs_key: np.float32,
            raw_key: np.float32,
            gt_affs_in_key: np.float32,
            gt_affs_out_key: np.float32,
            pred_affs_key: np.float32,
            pred_logits_key: np.float32
        })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
def generate_data(num_batches):
    """Generate toy-neuron label volumes rendered with three realism settings.

    Three raw variants are rendered from the joined affinities of the labels:

    - RAW1: sigma=0, contrast=1
    - RAW2: sigma=1, contrast=1
    - RAW3: sigma=1, contrast=0.7

    Only the labels are snapshotted, to ``results/data_gen/seg_standard.hdf``.

    Args:
        num_batches: Number of batches to generate.

    Returns:
        list: The sum of the label values of each batch, in batch order. Each
        sum is also printed; it serves as a cheap per-batch fingerprint.
    """
    labels_key = ArrayKey('LABELS')
    gt_affs_key = ArrayKey('GT_AFFINITIES')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    raw_key1 = ArrayKey('RAW1')
    raw_key2 = ArrayKey('RAW2')
    raw_key3 = ArrayKey('RAW3')

    voxel_size = Coordinate((1, 1, 1))
    input_size = Coordinate((132, 132, 132)) * voxel_size
    output_size = Coordinate((44, 44, 44)) * voxel_size

    print("input_size: ", input_size)
    print("output_size: ", output_size)

    request = BatchRequest()
    for key in (labels_key, gt_affs_key, joined_affs_key,
                raw_key1, raw_key2, raw_key3):
        request.add(key, input_size)

    pipeline = (
        ToyNeuronSegmentationGenerator(array_key=labels_key,
                                       n_objects=20,
                                       points_per_skeleton=8,
                                       smoothness=3,
                                       noise_strength=1,
                                       interpolation="random",
                                       seed=None) +
        AddAffinities(affinity_neighborhood=[[-1, 0, 0], [0, -1, 0],
                                             [0, 0, -1]],
                      labels=labels_key,
                      affinities=gt_affs_key) +
        AddJoinedAffinities(input_affinities=gt_affs_key,
                            joined_affinities=joined_affs_key) +
        AddRealism(joined_affinities=joined_affs_key,
                   raw=raw_key1,
                   sp=0.25,
                   sigma=0,
                   contrast=1) +
        AddRealism(joined_affinities=joined_affs_key,
                   raw=raw_key2,
                   sp=0.25,
                   sigma=1,
                   contrast=1) +
        AddRealism(joined_affinities=joined_affs_key,
                   raw=raw_key3,
                   sp=0.25,
                   sigma=1,
                   contrast=0.7) +
        # NOTE(review): the filename has no {id}/{iteration} placeholder, so
        # with every=1 each batch targets the same file — confirm intended.
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
            },
            output_filename="results/data_gen/seg_standard.hdf",
            every=1,
            dataset_dtypes={
                labels_key: np.uint64,
            }))

    hashes = []
    with build(pipeline) as p:
        for i in range(num_batches):
            print("\nDATA POINT: ", i)
            batch = p.request_batch(request)
            # Previously this rebound ``hashes`` on every batch instead of
            # collecting one value per batch.
            label_sum = np.sum(batch[labels_key].data)
            hashes.append(label_sum)
            print(label_sum)
    return hashes
# Example no. 4
# 0
def predict(checkpoint, iterations):
    """Predict affinities on synthetic toy-neuron data and snapshot them.

    Labels come from ``ToyNeuronSegmentationGenerator`` (50 objects,
    seed 0). Raw data is rendered from the joined affinities of those labels.
    The network restored from ``setup_dir/train_net_checkpoint_<checkpoint>``
    (graph ``predict_net.meta``) predicts affinities from the raw data. Every
    batch is snapshotted using the filename template
    ``prob_unet/prediction_{id}.hdf``.

    Args:
        checkpoint: Iteration number of the checkpoint to restore.
        iterations: Number of batches to request.

    Relies on the module-level names ``config``, ``neighborhood`` and
    ``setup_dir``.
    """
    print("checkpoint: ", checkpoint)

    labels_key = ArrayKey('GT_LABELS')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_affinities_key = ArrayKey('RAW_AFFINITIES_KEY')
    raw_key = ArrayKey('RAW')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    request = BatchRequest()
    # NOTE(review): labels are listed in the snapshot below but are not
    # requested here — confirm they are meant to be omitted/optional.
    request.add(joined_affinities_key, input_shape)
    request.add(raw_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(pred_affinities_key, output_shape)

    pipeline = (
        ToyNeuronSegmentationGenerator(array_key=labels_key,
                                       n_objects=50,
                                       points_per_skeleton=8,
                                       smoothness=3,
                                       noise_strength=1,
                                       interpolation="random",
                                       seed=0) +
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels_key,
                      affinities=raw_affinities_key) +
        AddJoinedAffinities(input_affinities=raw_affinities_key,
                            joined_affinities=joined_affinities_key) +
        AddRealism(joined_affinities=joined_affinities_key,
                   raw=raw_key,
                   sp=0.65,
                   sigma=1,
                   contrast=0.7) +
        # Map raw as x * 2 - 1 for the network and undo it afterwards.
        IntensityScaleShift(raw_key, 2, -1) +
        Predict(checkpoint=os.path.join(setup_dir, 'train_net_checkpoint_%d' %
                                        checkpoint),
                inputs={config['raw']: raw_key},
                outputs={config['pred_affs']: pred_affinities_key},
                graph=os.path.join(setup_dir, 'predict_net.meta')) +
        IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5) +
        Snapshot(dataset_names={
            labels_key: 'volumes/labels',
            raw_affinities_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            pred_affinities_key: 'volumes/pred_affs',
        },
                 output_filename='prob_unet/prediction_{id}.hdf',
                 every=1,
                 # Only keys defined in this function: the former SAMPLE_Z
                 # entry referenced an undefined name and raised NameError.
                 dataset_dtypes={
                     labels_key: np.uint16,
                     raw_key: np.float32,
                     pred_affinities_key: np.float32,
                 })
    )

    print("Starting prediction...")
    with build(pipeline) as p:
        for _ in range(iterations):
            p.request_batch(request)
    print("Prediction finished")
def generate_data(num_batches):
    """Generate toy-neuron volumes plus a merged-label variant and snapshot them.

    Each batch contains the following arrays:

    - A synthetic label volume (15 objects, linear interpolation).
    - Raw data rendered from the joined affinities of those labels.
    - A ``MergeLabels`` (every=2) variant of the labels, driven by the joined
      affinities along the negative and positive unit offsets.
    - Affinities and an affinity mask computed on the merged labels.

    Every batch is snapshotted to ``test_edges.hdf``.

    Args:
        num_batches: Number of batches to request.
    """
    # Unit offsets along each axis, in negative and positive direction.
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    neighborhood_opp = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # define array-keys
    labels_key = ArrayKey('LABELS')

    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_key = ArrayKey('AFFINITIES')
    affs_opp_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_opp_key = ArrayKey('JOINED_AFFINITIES_OPP')
    merged_labels_key = ArrayKey('MERGED_LABELS')

    gt_affs_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')

    # Sizes are stated once; the affinity arrays are one voxel larger per axis.
    input_size = (132, 132, 132)
    output_size = (44, 44, 44)

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(input_size) * voxel_size
    input_affs_shape = Coordinate([s + 1 for s in input_size]) * voxel_size
    output_shape = Coordinate(output_size) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    request = BatchRequest()
    request.add(labels_key, output_shape)

    request.add(raw_key, input_shape)
    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)

    request.add(affs_key, input_affs_shape)
    request.add(affs_opp_key, input_affs_shape)
    request.add(joined_affs_key, input_affs_shape)
    request.add(joined_affs_opp_key, input_affs_shape)
    request.add(merged_labels_key, output_shape)

    request.add(gt_affs_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)

    pipeline = ()

    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=15,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Render raw data from the labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    # Joined affinities in both directions feed MergeLabels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=affs_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood_opp,
                              labels=labels_key,
                              affinities=affs_opp_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_key,
                                    joined_affinities=joined_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_opp_key,
                                    joined_affinities=joined_affs_opp_key)

    pipeline += MergeLabels(labels=labels_key,
                            joined_affinities=joined_affs_key,
                            joined_affinities_opp=joined_affs_opp_key,
                            merged_labels=merged_labels_key,
                            every=2)

    # Ground-truth affinities (and mask) on the merged labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += PreCache(cache_size=4, num_workers=2)

    pipeline += Snapshot(dataset_names={
        labels_key: 'volumes/labels',
        merged_labels_key: 'volumes/merged_labels',
        raw_key: 'volumes/raw',
        raw_affs_key: 'volumes/raw_affs',
        gt_affs_key: 'volumes/gt_affs',
        gt_affs_in_key: 'volumes/gt_affs_in'
    },
                         output_filename='test_edges.hdf',
                         every=1,
                         dataset_dtypes={
                             merged_labels_key: np.uint64,
                             labels_key: np.uint64,
                             raw_key: np.float32
                         })

    pipeline += PrintProfilingStats(every=100)

    with build(pipeline) as p:
        for i in range(num_batches):
            print("iteration: ", i)
            p.request_batch(request)
# Example no. 6
# 0
def train(iterations):
    """Train on synthetic toy-neuron data, in two phases, up to ``iterations``.

    Iterations up to ``phase_switch`` form the "euclid" phase, which uses
    ``config['loss']`` / ``config['optimizer']``. Later iterations form the
    "malis" phase, which uses ``add_malis_loss`` and additionally feeds the
    labels (cropped to the output ROI and renumbered) and the affinity mask.
    Training resumes from the latest checkpoint in the current directory. If
    that checkpoint is still before ``phase_switch`` while the target lies
    beyond it, the euclid phase is completed first by a recursive call.

    Args:
        iterations: Total iteration count to train up to.

    Relies on the module-level names ``config``, ``phase_switch``,
    ``neighborhood``, ``setup_dir``, ``setup_name`` and ``add_malis_loss``.
    """

    def _trained_iterations():
        # Checkpoint names end in "_<iteration>"; 0 when none exists yet.
        latest = tf.train.latest_checkpoint('.')
        return int(latest.split('_')[-1]) if latest else 0

    trained_until = _trained_iterations()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The value read above is stale after phase one; without re-reading,
        # the loop below would run past ``iterations``.
        trained_until = _trained_iterations()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)
    # Loss weights are fed in both phases (see train_inputs below).
    request.add(input_affinities_scale_key, output_shape)

    # Output ROI centred in the input; integer division keeps the offset
    # integral instead of relying on Coordinate to truncate floats.
    offset = Coordinate((in_size - out_size) // 2
                        for in_size, out_size in zip(input_shape, output_shape))
    crop_roi = Roi(offset, output_shape)

    pipeline = ()

    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Input affinities come from the labels before boundaries are grown.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)

    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask)

    pipeline += AddJoinedAffinities(input_affinities=gt_affs_in_key,
                                    joined_affinities=joined_affinities_key)

    pipeline += AddRealism(joined_affinities=joined_affinities_key,
                           raw=raw_key,
                           sp=0.65,
                           sigma=1,
                           contrast=0.7)

    pipeline += BalanceLabels(labels=gt_affs_out_key,
                              scales=input_affinities_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # Map raw as x * 2 - 1 for the network; undone after Train below.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    if phase == 'malis':
        pipeline += Crop(key=labels_key, roi=crop_roi)
        pipeline += RenumberConnectedComponents(labels=labels_key)

    pipeline += PreCache(cache_size=32, num_workers=8)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: input_affinities_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        # add_malis_loss builds the loss itself and needs the segmentation
        # and affinity mask fed by tensor name.
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(
        graph=setup_dir + 'train_net',
        optimizer=train_optimizer,
        loss=train_loss,
        inputs=train_inputs,
        outputs={config['pred_affs']: pred_affinities_key},
        gradients={config['pred_affs']: pred_affinities_gradient_key},
        summary=train_summary,
        log_dir='log/prob_unet/' + setup_name,
        save_every=2000)

    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(dataset_names={
        labels_key: 'volumes/labels',
        raw_key: 'volumes/raw',
        pred_affinities_key: 'volumes/pred_affs',
        gt_affs_out_key: 'volumes/output_affs'
    },
                         output_filename='prob_unet/' + setup_name +
                         '/batch_{iteration}.hdf',
                         every=4000,
                         dataset_dtypes={
                             labels_key: np.uint64,
                             raw_key: np.float32
                         })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
# Example no. 7
# 0
def predict(checkpoint, iterations):
    """Run a trained U-Net checkpoint on freshly generated toy neuron data.

    Every batch is synthesised on the fly:
    - a toy segmentation is generated;
    - it is converted to affinities and joined affinities;
    - ``AddRealism`` renders it into a raw volume;
    - the raw volume is rescaled and passed through ``Predict``.

    Each batch is written to ``unet/prediction.hdf`` by ``Snapshot``, and the
    network's ``DEBUG`` output is printed.

    Args:
        checkpoint: Iteration number of the checkpoint to restore; it is
            formatted into ``<setup_dir>/train_net_checkpoint_<checkpoint>``.
        iterations: Number of batches to request from the pipeline.
    """
    print("iterations:", iterations)

    # array keys
    labels_key = ArrayKey('GT_LABELS')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_affinities_key = ArrayKey('RAW_AFFINITIES_KEY')
    raw_key = ArrayKey('RAW')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape, output_shape, debug_shape = (
        Coordinate(dims) * voxel_size
        for dims in (config['input_shape'], config['output_shape'], (1, 1, 1))
    )

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    # NOTE: labels_key is deliberately not requested; the original author
    # observed that requesting it duplicated generations (cause unknown).
    request = BatchRequest()
    for requested_key, requested_shape in (
        (joined_affinities_key, input_shape),
        (raw_affinities_key, input_shape),
        (raw_key, input_shape),
        (pred_affinities_key, output_shape),
        (debug_key, debug_shape),
    ):
        request.add(requested_key, requested_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(
        array_key=labels_key,
        n_objects=7,
        points_per_skeleton=8,
        smoothness=3,
        noise_strength=1,
        interpolation="random",
    )
    pipeline += AddAffinities(
        affinity_neighborhood=neighborhood,
        labels=labels_key,
        affinities=raw_affinities_key,
    )
    pipeline += AddJoinedAffinities(
        input_affinities=raw_affinities_key,
        joined_affinities=joined_affinities_key,
    )
    pipeline += AddRealism(
        joined_affinities=joined_affinities_key,
        raw=raw_key,
        sp=0.65,
        sigma=1,
        contrast=0.7,
    )
    # raw -> 2 * raw - 1 before the network; inverted again after Predict.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    checkpoint_path = os.path.join(setup_dir,
                                   'train_net_checkpoint_%d' % checkpoint)
    predict_node = Predict(
        checkpoint=checkpoint_path,
        inputs={config['raw']: raw_key},
        outputs={
            config['pred_affs']: pred_affinities_key,
            config['debug']: debug_key,
        },
    )
    pipeline += predict_node

    # raw -> 0.5 * raw + 0.5 undoes the rescaling above for the snapshot.
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)
    # NOTE(review): labels_key is not part of the request, so it may be
    # absent from the written snapshot — TODO confirm.
    pipeline += Snapshot(
        dataset_names={
            labels_key: 'volumes/labels/labels',
            raw_affinities_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            pred_affinities_key: 'volumes/pred_affs',
        },
        output_filename='unet/prediction.hdf',
        every=1,
        dataset_dtypes={
            raw_key: np.float32,
            pred_affinities_key: np.float32,
            labels_key: np.uint64,
        },
    )
    pipeline += PrintProfilingStats(every=20)

    print("Starting prediction...")
    with build(pipeline) as provider:
        for iteration in range(iterations):
            print("iteration: ", iteration)
            batch = provider.request_batch(request)
            print("debug", batch[debug_key].data)

    print("Prediction finished")
Ejemplo n.º 8
0
def train(iterations):
    """Train the probabilistic U-Net on synthetic toy neuron data.

    Training resumes from the newest checkpoint recorded in the current
    directory. It runs until ``iterations`` total iterations are reached.
    Nothing happens if that many iterations have already been trained.

    Args:
        iterations: Total number of training iterations to reach.
    """
    tf.reset_default_graph()
    latest_checkpoint = tf.train.latest_checkpoint('.')
    # Assumes the checkpoint path ends in "_<iteration>".
    trained_until = (int(latest_checkpoint.split('_')[-1])
                     if latest_checkpoint else 0)
    if trained_until >= iterations:
        return

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(input_affinities_scale_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)

    # Centre an output-sized ROI inside the input ROI. Floor division keeps
    # the offset integral under Python 3, where "/" would produce floats.
    offset = Coordinate((in_dim - out_dim) // 2
                        for in_dim, out_dim in zip(input_shape, output_shape))
    crop_roi = Roi(offset, output_shape)

    pipeline = (
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=50,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="linear") +
        # Disabled augmentations:
        # ElasticAugment(
        #     control_point_spacing=[4, 40, 40],
        #     jitter_sigma=[0, 2, 2],
        #     rotation_interval=[0, math.pi / 2.0],
        #     prob_slip=0.05,
        #     prob_shift=0.05,
        #     max_misalign=10,
        #     subsample=8) +
        # SimpleAugment(transpose_only=[1, 2]) +
        # IntensityAugment(labels, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        # Input affinities are computed before the boundary is grown ...
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=gt_affs_in_key) +
        GrowBoundary(labels_key, steps=1, only_xy=True) +
        # ... target affinities (and their mask) after.
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=gt_affs_out_key,
            affinities_mask=gt_affs_mask) +
        AddJoinedAffinities(
            input_affinities=gt_affs_in_key,
            joined_affinities=joined_affinities_key) +
        AddRealism(
            joined_affinities=joined_affinities_key,
            raw=raw_key,
            sp=0.25,
            sigma=1,
            contrast=0.7) +
        BalanceLabels(
            labels=gt_affs_out_key,
            scales=input_affinities_scale_key) +
        DefectAugment(
            intensities=raw_key,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            contrast_scale=0.5,
            axis=0) +
        # raw -> 2 * raw - 1 for the network; inverted after Train.
        IntensityScaleShift(raw_key, 2, -1) +
        Crop(
            key=labels_key,
            roi=crop_roi) +
        RenumberConnectedComponents(labels=labels_key) +
        PreCache(
            cache_size=32,
            num_workers=8) +
        # NOTE(review): the graph is loaded from setup_1 while logs and
        # snapshots go to setup_2 — confirm this is intended.
        Train(
            graph='train/prob_unet/setup_1/train_net',
            optimizer=add_loss,
            loss=None,
            inputs={
                "gt_seg:0": labels_key,
                "gt_affs_mask:0": gt_affs_mask,
                config['raw']: raw_key,
                config['gt_affs_in']: gt_affs_in_key,
                config['gt_affs_out']: gt_affs_out_key,
                config['pred_affs_loss_weights']: input_affinities_scale_key
            },
            outputs={
                config['pred_affs']: pred_affinities_key
            },
            gradients={
                config['pred_affs']: pred_affinities_gradient_key
            },
            summary="Merge/MergeSummary:0",
            log_dir='log/prob_unet/setup_2',
            save_every=500) +
        IntensityScaleShift(
            array=raw_key,
            scale=0.5,
            shift=0.5) +
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
                raw_key: 'volumes/raw',
                pred_affinities_key: 'volumes/pred_affs',
                gt_affs_out_key: 'volumes/output_affs'
            },
            output_filename='prob_unet/setup_2/batch_{iteration}.hdf',
            every=1000,
            dataset_dtypes={
                labels_key: np.uint64,
                raw_key: np.float32
            }) +
        PrintProfilingStats(every=20)
    )

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
Ejemplo n.º 9
0
def train(iterations):
    """Train the probabilistic U-Net in two phases on synthetic neuron data.

    Phases:
    - "euclid" (up to ``phase_switch`` iterations): uses the loss, optimizer
      and summary named in ``config``.
    - "malis" (beyond ``phase_switch``): uses ``add_malis_loss`` as the
      optimizer factory and feeds the extra ``gt_seg``/``gt_affs_mask``
      inputs.

    If the requested run spans both phases while the checkpoint is still in
    the first phase, the function first recurses to finish the euclid phase.

    Training resumes from the newest checkpoint recorded in the current
    directory and runs until ``iterations`` total iterations are reached.

    Args:
        iterations: Total number of training iterations to reach.
    """

    def latest_trained_iteration():
        # Assumes the checkpoint path ends in "_<iteration>"; no checkpoint
        # means nothing has been trained yet.
        latest = tf.train.latest_checkpoint('.')
        return int(latest.split('_')[-1]) if latest else 0

    # tf.reset_default_graph()
    trained_until = latest_trained_iteration()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch < iterations:
        train(phase_switch)
        # The recursive call advanced the checkpoint. Re-read it so the loop
        # below runs only the remaining iterations. Otherwise the stale value
        # would make this phase over-train by the whole euclid span.
        trained_until = latest_trained_iteration()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_neg_key = ArrayKey('AFFINITIES')
    affs_pos_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_neg_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_pos_key = ArrayKey('JOINED_AFFINITIES_OPP')

    gt_affs_out_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    sample_z_key = ArrayKey("SAMPLE_Z")
    broadcast_key = ArrayKey("BROADCAST")
    pred_logits_key = ArrayKey("PRED_LOGITS")
    sample_out_key = ArrayKey("SAMPLE_OUT")
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size
    sample_shape = Coordinate((1, 1, config['latent_dims'])) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)
    request.add(raw_key, input_shape)

    request.add(gt_affs_out_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)
    request.add(pred_affs_gradient_key, output_shape)

    request.add(broadcast_key, output_shape)
    request.add(sample_z_key, sample_shape)
    request.add(pred_logits_key, output_shape)
    request.add(sample_out_key, sample_shape)
    request.add(debug_key, debug_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Raw data is rendered from affinities of the un-grown labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    pipeline += RenumberConnectedComponents(labels=labels_key)

    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)

    # Ground-truth affinities come from the boundary-grown labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(labels=gt_affs_out_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # raw -> 2 * raw - 1 for the network; inverted after Train.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    pipeline += PreCache(cache_size=8, num_workers=4)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={
                          config['pred_affs']: pred_affs_key,
                          config['broadcast']: broadcast_key,
                          config['sample_z']: sample_z_key,
                          config['pred_logits']: pred_logits_key,
                          config['sample_out']: sample_out_key,
                          config['debug']: debug_key
                      },
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(dataset_names={
        labels_key: 'volumes/labels',
        raw_affs_key: 'volumes/raw_affs',
        raw_key: 'volumes/raw',
        gt_affs_in_key: 'volumes/gt_affs_in',
        gt_affs_out_key: 'volumes/gt_affs_out',
        pred_affs_key: 'volumes/pred_affs',
        pred_logits_key: 'volumes/pred_logits'
    },
                         output_filename='prob_unet/' + setup_name +
                         '/batch_{iteration}.hdf',
                         every=2000,
                         dataset_dtypes={
                             labels_key: np.uint16,
                             raw_affs_key: np.float32,
                             raw_key: np.float32,
                             gt_affs_in_key: np.float32,
                             gt_affs_out_key: np.float32,
                             pred_affs_key: np.float32,
                             pred_logits_key: np.float32
                         })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
Ejemplo n.º 10
0
def train(iterations):
    """Train the U-Net with the MALIS loss on synthetic toy neuron data.

    Training resumes from the newest checkpoint recorded in the current
    directory and runs until ``iterations`` total iterations are reached.
    After every batch the network's ``DEBUG`` output and the predicted
    affinities are printed.

    Args:
        iterations: Total number of training iterations to reach.
    """
    tf.reset_default_graph()
    latest_checkpoint = tf.train.latest_checkpoint('.')
    # Assumes the checkpoint path ends in "_<iteration>".
    trained_until = (int(latest_checkpoint.split('_')[-1])
                     if latest_checkpoint else 0)
    if trained_until >= iterations:
        return

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(input_affinities_scale_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)
    request.add(debug_key, debug_shape)

    # Centre an output-sized ROI inside the input ROI. Floor division keeps
    # the offset integral under Python 3, where "/" would produce floats.
    offset = Coordinate((in_dim - out_dim) // 2
                        for in_dim, out_dim in zip(input_shape, output_shape))
    crop_roi = Roi(offset, output_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")
    # Disabled augmentations: ElasticAugment, SimpleAugment, IntensityAugment.
    # Input affinities are computed before the boundary is grown ...
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)
    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)
    # ... target affinities (and their mask) after.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask)
    pipeline += AddJoinedAffinities(input_affinities=gt_affs_in_key,
                                    joined_affinities=joined_affinities_key)
    pipeline += AddRealism(joined_affinities=joined_affinities_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)
    pipeline += BalanceLabels(labels=gt_affs_out_key,
                              scales=input_affinities_scale_key)
    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)
    # raw -> 2 * raw - 1 for the network; inverted after Train.
    pipeline += IntensityScaleShift(raw_key, 2, -1)
    pipeline += PreCache(cache_size=32, num_workers=8)
    # Crop the labels to the centred output ROI computed above.
    pipeline += Crop(key=labels_key, roi=crop_roi)
    pipeline += RenumberConnectedComponents(labels=labels_key)
    # Named train_node so the local does not shadow this function's own name.
    train_node = Train(
        graph='train/unet/train_net',
        optimizer=add_malis_loss,
        loss=None,
        inputs={
            config['raw']: raw_key,
            "gt_seg:0": labels_key,
            "gt_affs_mask:0": gt_affs_mask,
            config['gt_affs']: gt_affs_out_key,
            config['pred_affs_loss_weights']: input_affinities_scale_key,
        },
        outputs={
            config['pred_affs']: pred_affinities_key,
            config['debug']: debug_key,
        },
        gradients={config['pred_affs']: pred_affinities_gradient_key},
        summary="malis_loss:0",
        log_dir='log/unet',
        # NOTE(review): a checkpoint every iteration looks like a debug
        # setting; raise for real training runs.
        save_every=1)
    pipeline += train_node
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)
    pipeline += PrintProfilingStats(every=8)

    print("Starting training... COOL BEANS")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            req = p.request_batch(request)
            pred_affs = req[pred_affinities_key].data
            debug = req[debug_key].data
            print("debug", debug)
            print('pred_affs', pred_affs)
            print("name of pred_adds: ", req[pred_affinities_key])
    print("Training finished")
Ejemplo n.º 11
0
def generate_data(num_batches):
    """Render synthetic raw volumes from the labels in ``seg_standard.hdf``.

    For each batch:
    - the labels are read from ``data_dir`` and padded with ``Pad``;
    - they are converted to affinities and joined affinities;
    - ``AddRealism`` renders them into a raw volume;
    - ``Snapshot`` writes the raw volume to
      ``results/data_gen/raw_synth/contrast_07.hdf``.

    The sum of each batch's label array is printed and collected as a cheap
    fingerprint, e.g. to spot repeated batches.

    Args:
        num_batches: Number of batches to request.

    Returns:
        list: The per-batch label sums, in request order.
    """
    labels_key = ArrayKey('LABELS')
    gt_affs_key = ArrayKey('GT_AFFINITIES')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    raw_key1 = ArrayKey('RAW1')
    # RAW2/RAW3 are not used by the pipeline below. They are kept in case
    # ArrayKey construction has registration side effects elsewhere.
    raw_key2 = ArrayKey('RAW2')
    raw_key3 = ArrayKey('RAW3')

    voxel_size = Coordinate((1, 1, 1))
    input_size = Coordinate((133, 133, 133)) * voxel_size
    affs_size = Coordinate((131, 131, 131)) * voxel_size
    output_size = Coordinate((44, 44, 44)) * voxel_size

    print("input_size: ", input_size)
    print("output_size: ", output_size)

    request = BatchRequest()
    request.add(labels_key, input_size)
    request.add(gt_affs_key, affs_size)
    request.add(joined_affs_key, affs_size)
    request.add(raw_key1, affs_size)

    pipeline = (
        Hdf5Source(os.path.join(data_dir, 'seg_standard.hdf'),
                   datasets={labels_key: "volumes/labels"},
                   array_specs={labels_key: ArraySpec(interpolatable=False)}) +
        Pad(labels_key, None) + AddAffinities(
            affinity_neighborhood=[[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            labels=labels_key,
            affinities=gt_affs_key) +
        AddJoinedAffinities(input_affinities=gt_affs_key,
                            joined_affinities=joined_affs_key) +
        AddRealism(joined_affinities=joined_affs_key,
                   raw=raw_key1,
                   sp=0.25,
                   sigma=1,
                   contrast=0.7) +
        # NOTE(review): output_filename has no {iteration} placeholder, so
        # presumably each batch overwrites the same file — TODO confirm.
        Snapshot(
            dataset_names={
                raw_key1: 'volumes/raw',
            },
            output_filename="results/data_gen/raw_synth/contrast_07.hdf",
            every=1,
            dataset_dtypes={
                raw_key1: np.float32,
            }))

    hashes = []
    with build(pipeline) as p:
        for i in range(num_batches):
            print("\nDATA POINT: ", i)
            req = p.request_batch(request)
            label_sum = np.sum(req[labels_key].data)
            # Accumulate one fingerprint per batch instead of overwriting.
            hashes.append(label_sum)
            print(label_sum)
    return hashes