def generate_data(num_batches):
    """Assemble the toy-neuron pipeline and request ``num_batches`` batches.

    The pipeline starts from ``ToyNeuronSegmentationGenerator``, chains
    affinity, joined-affinity, realism and label-merging nodes, and ends in a
    ``Snapshot`` node configured with ``output_filename='test_edges.hdf'``
    and ``every=1`` followed by ``PrintProfilingStats``.

    Args:
        num_batches: How many batches to request from the built pipeline.

    Returns:
        None. The requested batches are not kept.
    """
    # Affinity offsets: one step along each negative axis, and the mirrored
    # (positive) set.
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    neighborhood_opp = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # --- array keys -------------------------------------------------------
    labels_key = ArrayKey('LABELS')

    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_key = ArrayKey('AFFINITIES')
    affs_opp_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_opp_key = ArrayKey('JOINED_AFFINITIES_OPP')
    merged_labels_key = ArrayKey('MERGED_LABELS')

    gt_affs_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')

    # --- request shapes ---------------------------------------------------
    voxel_size = Coordinate((1, 1, 1))
    in_extent = (132, 132, 132)
    out_extent = (44, 44, 44)
    input_shape = Coordinate(in_extent) * voxel_size
    # The "affs" shapes are one voxel larger than the plain ones per axis.
    input_affs_shape = Coordinate([extent + 1
                                   for extent in in_extent]) * voxel_size
    output_shape = Coordinate(out_extent) * voxel_size
    # Not used further down in this function.
    output_affs_shape = Coordinate([extent + 1
                                    for extent in out_extent]) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # (key, shape) pairs in the order they are added to the request.
    requested = [
        (labels_key, output_shape),
        (raw_key, input_shape),
        (raw_affs_key, input_shape),
        (raw_joined_affs_key, input_shape),
        (affs_key, input_affs_shape),
        (affs_opp_key, input_affs_shape),
        (joined_affs_key, input_affs_shape),
        (joined_affs_opp_key, input_affs_shape),
        (merged_labels_key, output_shape),
        (gt_affs_key, output_shape),
        (gt_affs_in_key, input_shape),
        (gt_affs_mask_key, output_shape),
    ]
    request = BatchRequest()
    for key, shape in requested:
        request.add(key, shape)

    # --- pipeline stages, upstream first ----------------------------------
    stages = [
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=15,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="linear"),
        # Affinities of the original labels, joined and fed to AddRealism,
        # which is given raw_key as its ``raw`` argument.
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=raw_affs_key),
        AddJoinedAffinities(
            input_affinities=raw_affs_key,
            joined_affinities=raw_joined_affs_key),
        AddRealism(
            joined_affinities=raw_joined_affs_key,
            raw=raw_key,
            sp=0.25,
            sigma=1,
            contrast=0.7),
        # Affinities in both directions, joined, as inputs to MergeLabels.
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=affs_key),
        AddAffinities(
            affinity_neighborhood=neighborhood_opp,
            labels=labels_key,
            affinities=affs_opp_key),
        AddJoinedAffinities(
            input_affinities=affs_key,
            joined_affinities=joined_affs_key),
        AddJoinedAffinities(
            input_affinities=affs_opp_key,
            joined_affinities=joined_affs_opp_key),
        MergeLabels(
            labels=labels_key,
            joined_affinities=joined_affs_key,
            joined_affinities_opp=joined_affs_opp_key,
            merged_labels=merged_labels_key,
            every=2),
        # Ground-truth affinities computed on the merged labels.
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=merged_labels_key,
            affinities=gt_affs_in_key),
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=merged_labels_key,
            affinities=gt_affs_key,
            affinities_mask=gt_affs_mask_key),
        PreCache(cache_size=4, num_workers=2),
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
                merged_labels_key: 'volumes/merged_labels',
                raw_key: 'volumes/raw',
                raw_affs_key: 'volumes/raw_affs',
                gt_affs_key: 'volumes/gt_affs',
                gt_affs_in_key: 'volumes/gt_affs_in'
            },
            output_filename='test_edges.hdf',
            every=1,
            dataset_dtypes={
                merged_labels_key: np.uint64,
                labels_key: np.uint64,
                raw_key: np.float32
            }),
        PrintProfilingStats(every=100),
    ]

    # Chain the stages with ``+`` starting from an empty tuple.
    pipeline = ()
    for stage in stages:
        pipeline += stage

    hashes = []  # never filled in the visible code
    with build(pipeline) as built:
        for batch_index in range(num_batches):
            print("iteration: ", batch_index)
            batch = built.request_batch(request)
# Example no. 2
# 0
def predict(checkpoint, iterations):
    """Build a prediction pipeline and request ``iterations`` batches.

    The pipeline generates toy labels, derives affinities and a raw array
    from them, runs a ``Predict`` node on the checkpoint
    ``train_net_checkpoint_<checkpoint>`` inside ``setup_dir`` and ends in a
    ``Snapshot`` node configured with
    ``output_filename='unet/prediction.hdf'``. After every request the
    contents of the debug array are printed.

    Relies on the module-level names ``config``, ``neighborhood`` and
    ``setup_dir``.

    Args:
        checkpoint: Integer formatted into the checkpoint file name.
        iterations: Number of batches to request.
    """
    print("iterations:", iterations)

    # --- array keys -------------------------------------------------------
    labels_key = ArrayKey('GT_LABELS')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_affinities_key = ArrayKey('RAW_AFFINITIES_KEY')
    raw_key = ArrayKey('RAW')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    debug_key = ArrayKey("DEBUG")

    # --- request shapes ---------------------------------------------------
    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    debug_shape = Coordinate((1, 1, 1)) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    # TODO: labels_key is deliberately not requested; the original author
    # noted that adding it causes a duplication of generations.
    requested = [
        (joined_affinities_key, input_shape),
        (raw_affinities_key, input_shape),
        (raw_key, input_shape),
        (pred_affinities_key, output_shape),
        (debug_key, debug_shape),
    ]
    request = BatchRequest()
    for key, shape in requested:
        request.add(key, shape)

    checkpoint_path = os.path.join(setup_dir,
                                   'train_net_checkpoint_%d' % checkpoint)

    # --- pipeline stages, upstream first ----------------------------------
    stages = [
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=7,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="random"),
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=raw_affinities_key),
        AddJoinedAffinities(
            input_affinities=raw_affinities_key,
            joined_affinities=joined_affinities_key),
        AddRealism(
            joined_affinities=joined_affinities_key,
            raw=raw_key,
            sp=0.65,
            sigma=1,
            contrast=0.7),
        IntensityScaleShift(raw_key, 2, -1),
        Predict(
            checkpoint=checkpoint_path,
            inputs={config['raw']: raw_key},
            outputs={
                config['pred_affs']: pred_affinities_key,
                config['debug']: debug_key,
            },
        ),
        # Presumably undoes the (2, -1) scale/shift applied before Predict,
        # so the snapshot holds the unscaled raw data -- TODO confirm.
        IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5),
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels/labels',
                raw_affinities_key: 'volumes/raw_affs',
                raw_key: 'volumes/raw',
                pred_affinities_key: 'volumes/pred_affs'
            },
            output_filename='unet/prediction.hdf',
            every=1,
            dataset_dtypes={
                raw_key: np.float32,
                pred_affinities_key: np.float32,
                labels_key: np.uint64
            }),
        PrintProfilingStats(every=20),
    ]

    # Chain the stages with ``+`` starting from an empty tuple.
    pipeline = ()
    for stage in stages:
        pipeline += stage

    print("Starting prediction...")
    with build(pipeline) as built:
        for batch_index in range(iterations):
            print("iteration: ", batch_index)
            batch = built.request_batch(request)
            print("debug", batch[debug_key].data)

    print("Prediction finished")
# Example no. 3
# 0
def train(iterations):
    """Train the network until ``iterations`` total iterations are reached.

    Training resumes from the newest checkpoint in the current directory;
    the iteration count is parsed from the text after the last ``_`` in the
    checkpoint path. Targets up to the module-level ``phase_switch`` are
    trained in the ``'euclid'`` phase, larger targets in the ``'malis'``
    phase. When a call spans both phases, the euclid part is run first
    through a recursive call.

    Relies on the module-level names ``config``, ``phase_switch``,
    ``neighborhood``, ``neighborhood_opp``, ``setup_dir``, ``setup_name``
    and ``add_malis_loss``.

    Args:
        iterations: Target total iteration count. The function returns
            immediately when the newest checkpoint is already at or past
            this value.
    """

    def latest_iteration():
        # Iteration encoded in the newest checkpoint name, 0 without one.
        checkpoint = tf.train.latest_checkpoint('.')
        return int(checkpoint.split('_')[-1]) if checkpoint else 0

    trained_until = latest_iteration()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The recursive call advanced training. Re-read the checkpoint so
        # the loop below only runs the iterations that are still missing,
        # instead of counting from the stale pre-recursion value.
        trained_until = latest_iteration()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')

    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_key = ArrayKey('AFFINITIES')
    affs_opp_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_opp_key = ArrayKey('JOINED_AFFINITIES_OPP')
    merged_labels_key = ArrayKey('MERGED_LABELS')

    gt_affs_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    # The "affs" shapes are one voxel larger than the plain ones per axis.
    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, output_shape)

    request.add(raw_key, input_shape)
    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)

    request.add(affs_key, input_affs_shape)
    request.add(affs_opp_key, input_affs_shape)
    request.add(joined_affs_key, input_affs_shape)
    request.add(joined_affs_opp_key, input_affs_shape)
    request.add(merged_labels_key, output_shape)

    request.add(gt_affs_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Affinities of the original labels, joined and fed to AddRealism,
    # which is given raw_key as its ``raw`` argument.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.65,
                           sigma=1,
                           contrast=0.7)

    # Affinities in both directions, joined, as inputs to MergeLabels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=affs_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood_opp,
                              labels=labels_key,
                              affinities=affs_opp_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_key,
                                    joined_affinities=joined_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_opp_key,
                                    joined_affinities=joined_affs_opp_key)

    pipeline += MergeLabels(labels=labels_key,
                            joined_affinities=joined_affs_key,
                            joined_affinities_opp=joined_affs_opp_key,
                            merged_labels=merged_labels_key,
                            every=2)

    # Ground-truth affinities computed on the merged labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_key,
                              affinities_mask=gt_affs_mask_key)

    # Added in both phases (the phase guard was disabled by the author).
    pipeline += BalanceLabels(labels=gt_affs_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    pipeline += IntensityScaleShift(raw_key, 2, -1)

    if phase == 'malis':
        pipeline += RenumberConnectedComponents(labels=merged_labels_key)

    pipeline += PreCache(cache_size=32, num_workers=8)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        # In the malis phase the optimizer argument is the add_malis_loss
        # callable and two extra tensors are fed by name.
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = merged_labels_key  # XXX question
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    # NOTE(review): plain concatenation assumes setup_dir ends with a path
    # separator -- confirm against where setup_dir is defined.
    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={config['pred_affs']: pred_affs_key},
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    # Presumably undoes the (2, -1) scale/shift applied before Train, so
    # the snapshot holds the unscaled raw data -- TODO confirm.
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(dataset_names={
        merged_labels_key: 'volumes/labels',
        raw_key: 'volumes/raw',
        pred_affs_key: 'volumes/pred_affs',
        gt_affs_key: 'volumes/gt_affs'
    },
                         output_filename='prob_unet/' + setup_name +
                         '/batch_{iteration}.hdf',
                         every=4000,
                         dataset_dtypes={
                             merged_labels_key: np.uint64,
                             raw_key: np.float32
                         })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        # range() is empty when nothing remains to be trained.
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
# Example no. 4
# 0
def train(iterations):
    """Train the network until ``iterations`` total iterations are reached.

    Resets the default TensorFlow graph, then resumes from the newest
    checkpoint in the current directory; the iteration count is parsed from
    the text after the last ``_`` in the checkpoint path. The remaining
    ``iterations - trained_until`` batches are requested from the pipeline.

    Relies on the module-level names ``config``, ``neighborhood`` and
    ``add_loss``.

    Args:
        iterations: Target total iteration count. The function returns
            immediately when the newest checkpoint is already at or past
            this value.
    """
    tf.reset_default_graph()
    checkpoint = tf.train.latest_checkpoint('.')
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0
    if trained_until >= iterations:
        return

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(input_affinities_scale_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)

    # Offset of an output-sized ROI centred in the input. Floor division
    # keeps the coordinates integral on Python 3, where ``/`` yields floats.
    offset = Coordinate(
        (in_extent - out_extent) // 2
        for in_extent, out_extent in zip(input_shape, output_shape))
    crop_roi = Roi(offset, output_shape)

    pipeline = (
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=50,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="linear") +
        # Input affinities are computed before GrowBoundary is applied to
        # the labels, output affinities (with a mask) after it.
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=gt_affs_in_key) +
        GrowBoundary(labels_key, steps=1, only_xy=True) +
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=gt_affs_out_key,
            affinities_mask=gt_affs_mask) +
        AddJoinedAffinities(
            input_affinities=gt_affs_in_key,
            joined_affinities=joined_affinities_key) +
        AddRealism(
            joined_affinities=joined_affinities_key,
            raw=raw_key,
            sp=0.25,
            sigma=1,
            contrast=0.7) +
        BalanceLabels(
            labels=gt_affs_out_key,
            scales=input_affinities_scale_key) +
        DefectAugment(
            intensities=raw_key,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            contrast_scale=0.5,
            axis=0) +
        IntensityScaleShift(raw_key, 2, -1) +
        Crop(
            key=labels_key,
            roi=crop_roi) +
        RenumberConnectedComponents(labels=labels_key) +
        PreCache(
            cache_size=32,
            num_workers=8) +
        # NOTE(review): the graph lives under setup_1 while log_dir and the
        # snapshot path use setup_2 -- confirm this mix is intended.
        Train(
            graph='train/prob_unet/setup_1/train_net',
            optimizer=add_loss,
            loss=None,
            inputs={
                "gt_seg:0": labels_key,
                "gt_affs_mask:0": gt_affs_mask,
                config['raw']: raw_key,
                config['gt_affs_in']: gt_affs_in_key,
                config['gt_affs_out']: gt_affs_out_key,
                config['pred_affs_loss_weights']: input_affinities_scale_key
            },
            outputs={
                config['pred_affs']: pred_affinities_key
            },
            gradients={
                config['pred_affs']: pred_affinities_gradient_key
            },
            summary="Merge/MergeSummary:0",
            log_dir='log/prob_unet/setup_2',
            save_every=500) +
        # Presumably undoes the (2, -1) scale/shift applied before Train,
        # so the snapshot holds the unscaled raw data -- TODO confirm.
        IntensityScaleShift(
            array=raw_key,
            scale=0.5,
            shift=0.5) +
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
                raw_key: 'volumes/raw',
                pred_affinities_key: 'volumes/pred_affs',
                gt_affs_out_key: 'volumes/output_affs'
            },
            output_filename='prob_unet/setup_2/batch_{iteration}.hdf',
            every=1000,
            dataset_dtypes={
                labels_key: np.uint64,
                raw_key: np.float32
            }) +
        PrintProfilingStats(every=20)
    )

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
def generate_full_samples(num_batches):
    """Request ``num_batches`` batches of labels plus three merged variants.

    The pipeline generates toy labels, computes affinities in both
    directions, joins them, and appends ``num_merges`` (3) ``MergeLabels``
    nodes that each write to their own merged-labels key. A ``Snapshot``
    node is configured with
    ``output_filename='gt_1_merge_3/batch_{id}.hdf'`` and ``every=1``.

    Args:
        num_batches: How many batches to request from the built pipeline.

    Returns:
        None.
    """
    # Affinity offsets: one step along each negative axis, and the mirrored
    # (positive) set.
    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    neighborhood_opp = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # --- array keys -------------------------------------------------------
    labels_key = ArrayKey('LABELS')
    # The next three keys are created but not used by any active node below.
    gt_affs_key = ArrayKey('RAW_AFFINITIES')
    gt_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_neg_key = ArrayKey('AFFINITIES')
    affs_pos_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_neg_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_pos_key = ArrayKey('JOINED_AFFINITIES_OPP')

    # One (labels, affinities) key pair per merge, numbered from 1.
    num_merges = 3
    merge_key_pairs = [
        (ArrayKey('MERGED_LABELS_%i' % number),
         ArrayKey('MERGED_AFFINITIES_IN_%i' % number))
        for number in range(1, num_merges + 1)
    ]
    merged_labels_key = [labels for labels, _ in merge_key_pairs]
    merged_affs_key = [affs for _, affs in merge_key_pairs]  # unused below

    # --- request shapes ---------------------------------------------------
    voxel_size = Coordinate((1, 1, 1))
    in_extent = (132, 132, 132)
    out_extent = (44, 44, 44)
    input_shape = Coordinate(in_extent) * voxel_size
    # The two "affs" shapes are computed but not used in this function.
    input_affs_shape = Coordinate([extent + 1
                                   for extent in in_extent]) * voxel_size
    output_shape = Coordinate(out_extent) * voxel_size
    output_affs_shape = Coordinate([extent + 1
                                    for extent in out_extent]) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # Everything is requested at the input shape.
    request = BatchRequest()
    requested_keys = [
        labels_key,
        affs_neg_key,
        affs_pos_key,
        joined_affs_neg_key,
        joined_affs_pos_key,
    ] + merged_labels_key
    for key in requested_keys:
        request.add(key, input_shape)

    # --- pipeline stages, upstream first ----------------------------------
    stages = [
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=50,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="linear"),
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=affs_neg_key),
        AddAffinities(
            affinity_neighborhood=neighborhood_opp,
            labels=labels_key,
            affinities=affs_pos_key),
        AddJoinedAffinities(
            input_affinities=affs_neg_key,
            joined_affinities=joined_affs_neg_key),
        AddJoinedAffinities(
            input_affinities=affs_pos_key,
            joined_affinities=joined_affs_pos_key),
    ]

    # One MergeLabels node per merged-labels key, all reading the same
    # original labels and joined affinities.
    for merged_key in merged_labels_key:
        stages.append(
            MergeLabels(
                labels=labels_key,
                joined_affinities=joined_affs_neg_key,
                joined_affinities_opp=joined_affs_pos_key,
                merged_labels=merged_key,
                cropped_roi=None,
                every=1))

    stages.append(PreCache(cache_size=32, num_workers=8))

    # Snapshot layout: the original labels plus one dataset per merge.
    dataset_names = {labels_key: 'volumes/labels'}
    dataset_dtypes = {labels_key: np.uint16}
    for number, merged_key in enumerate(merged_labels_key, start=1):
        dataset_names[merged_key] = 'volumes/merged_labels_%i' % number
        dataset_dtypes[merged_key] = np.uint16

    stages.append(
        Snapshot(
            dataset_names=dataset_names,
            output_filename='gt_1_merge_3/batch_{id}.hdf',
            every=1,
            dataset_dtypes=dataset_dtypes))
    stages.append(PrintProfilingStats(every=10))

    # Chain the stages with ``+`` starting from an empty tuple.
    pipeline = ()
    for stage in stages:
        pipeline += stage

    hashes = []  # never filled in the visible code
    with build(pipeline) as built:
        for _ in range(num_batches):
            batch = built.request_batch(request)
            # Sum over the label array; not used further in the visible code.
            label_hash = np.sum(batch[labels_key].data)
# Example no. 6
# 0
def predict(checkpoint, iterations):
    """Build a prediction pipeline and request ``iterations`` batches.

    The pipeline generates toy labels (``seed=0``), derives affinities and
    a raw array from them, runs a ``Predict`` node on the checkpoint
    ``train_net_checkpoint_<checkpoint>`` and the graph
    ``predict_net.meta`` inside ``setup_dir``, and ends in a ``Snapshot``
    node configured with
    ``output_filename='prob_unet/prediction_{id}.hdf'``.

    Relies on the module-level names ``config``, ``neighborhood`` and
    ``setup_dir``.

    Args:
        checkpoint: Integer formatted into the checkpoint file name.
        iterations: Number of batches to request.
    """
    print("checkpoint: ", checkpoint)

    labels_key = ArrayKey('GT_LABELS')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_affinities_key = ArrayKey('RAW_AFFINITIES_KEY')
    raw_key = ArrayKey('RAW')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    # labels_key is not requested here, although the Snapshot node below
    # lists it in dataset_names.
    request = BatchRequest()
    request.add(joined_affinities_key, input_shape)
    request.add(raw_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(pred_affinities_key, output_shape)

    pipeline = (
        ToyNeuronSegmentationGenerator(
            array_key=labels_key,
            n_objects=50,
            points_per_skeleton=8,
            smoothness=3,
            noise_strength=1,
            interpolation="random",
            seed=0) +
        AddAffinities(
            affinity_neighborhood=neighborhood,
            labels=labels_key,
            affinities=raw_affinities_key) +
        AddJoinedAffinities(
            input_affinities=raw_affinities_key,
            joined_affinities=joined_affinities_key) +
        AddRealism(
            joined_affinities=joined_affinities_key,
            raw=raw_key,
            sp=0.65,
            sigma=1,
            contrast=0.7) +
        IntensityScaleShift(raw_key, 2, -1) +
        Predict(
            checkpoint=os.path.join(
                setup_dir, 'train_net_checkpoint_%d' % checkpoint),
            inputs={
                config['raw']: raw_key
            },
            outputs={
                config['pred_affs']: pred_affinities_key
            },
            graph=os.path.join(setup_dir, 'predict_net.meta')
        ) +
        # Presumably undoes the (2, -1) scale/shift applied before Predict,
        # so the snapshot holds the unscaled raw data -- TODO confirm.
        IntensityScaleShift(
            array=raw_key,
            scale=0.5,
            shift=0.5) +
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
                raw_affinities_key: 'volumes/raw_affs',
                raw_key: 'volumes/raw',
                pred_affinities_key: 'volumes/pred_affs',
            },
            output_filename='prob_unet/prediction_{id}.hdf',
            every=1,
            # Only keys defined in this function appear here; the former
            # ``sample_z_key`` entry referenced an undefined name and had
            # no matching dataset_names entry.
            dataset_dtypes={
                labels_key: np.uint16,
                raw_key: np.float32,
                pred_affinities_key: np.float32,
            })
    )

    print("Starting prediction...")
    with build(pipeline) as p:
        for _ in range(iterations):
            p.request_batch(request)
    print("Prediction finished")
# Example no. 7
# 0
def train(iterations):
    """Train the probabilistic U-Net on generated toy data up to ``iterations``.

    The number of iterations already trained is parsed from the newest
    TensorFlow checkpoint found in the current working directory (the text
    after the last ``_`` of the checkpoint path); without a checkpoint it is
    taken to be 0. If that number already reaches ``iterations`` the function
    returns without doing anything.

    Training is split into two phases at the module-level ``phase_switch``:

    * ``'euclid'`` (``iterations <= phase_switch``): loss, optimizer and
      summary are the ones named in ``config``.
    * ``'malis'`` (``iterations > phase_switch``): ``add_malis_loss`` is
      passed as the optimizer, no explicit loss is given, and the label and
      affinity-mask arrays are fed as additional inputs.

    When the requested range crosses ``phase_switch``, the euclid phase is
    run first through a recursive call, then the malis phase continues from
    ``phase_switch``.

    Args:
        iterations: Total iteration count to train up to.
    """
    # tf.reset_default_graph()
    latest_checkpoint = tf.train.latest_checkpoint('.')
    trained_until = (int(latest_checkpoint.split('_')[-1])
                     if latest_checkpoint else 0)
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The recursive call has just trained up to phase_switch. Without
        # this update the loop below would still count from the checkpoint
        # iteration read above and request too many batches.
        trained_until = phase_switch

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    # These four keys are created but not referenced again in this function.
    affs_neg_key = ArrayKey('AFFINITIES')
    affs_pos_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_neg_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_pos_key = ArrayKey('JOINED_AFFINITIES_OPP')

    gt_affs_out_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    sample_z_key = ArrayKey("SAMPLE_Z")
    broadcast_key = ArrayKey("BROADCAST")
    pred_logits_key = ArrayKey("PRED_LOGITS")
    sample_out_key = ArrayKey("SAMPLE_OUT")
    debug_key = ArrayKey("DEBUG")

    # request sizes, all expressed in units of voxel_size
    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size
    sample_shape = Coordinate((1, 1, config['latent_dims'])) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)
    request.add(raw_key, input_shape)

    request.add(gt_affs_out_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)
    request.add(pred_affs_gradient_key, output_shape)

    request.add(broadcast_key, output_shape)
    request.add(sample_z_key, sample_shape)
    request.add(pred_logits_key, output_shape)
    request.add(sample_out_key, sample_shape)
    request.add(debug_key, debug_shape)

    # Data generation: labels -> affinities -> joined affinities -> raw.
    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    # Ground truth: both affinity arrays below are derived from the labels
    # after they have been renumbered and passed through GrowBoundary.
    pipeline += RenumberConnectedComponents(labels=labels_key)

    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(labels=gt_affs_out_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # Scale/shift raw by (2, -1) for the network; undone after Train below.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    pipeline += PreCache(cache_size=8, num_workers=4)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        # malis phase: the optimizer callback supplies the loss itself and
        # needs the segmentation and the affinity mask as extra inputs.
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={
                          config['pred_affs']: pred_affs_key,
                          config['broadcast']: broadcast_key,
                          config['sample_z']: sample_z_key,
                          config['pred_logits']: pred_logits_key,
                          config['sample_out']: sample_out_key,
                          config['debug']: debug_key
                      },
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    # Inverse of the (2, -1) scale/shift applied before training.
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(
        dataset_names={
            labels_key: 'volumes/labels',
            raw_affs_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            gt_affs_in_key: 'volumes/gt_affs_in',
            gt_affs_out_key: 'volumes/gt_affs_out',
            pred_affs_key: 'volumes/pred_affs',
            pred_logits_key: 'volumes/pred_logits'
        },
        output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
        every=2000,
        dataset_dtypes={
            labels_key: np.uint16,
            raw_affs_key: np.float32,
            raw_key: np.float32,
            gt_affs_in_key: np.float32,
            gt_affs_out_key: np.float32,
            pred_affs_key: np.float32,
            pred_logits_key: np.float32
        })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            # Debugging aid (disabled): bind the returned batch and inspect
            # e.g. batch[sample_z_key].data, batch[sample_out_key].data,
            # batch[debug_key].data, or np.unique() over the channels of
            # batch[broadcast_key].data.
            p.request_batch(request)
    print("Training finished")
# Example no. 8
# 0
def train(iterations):
    """Train the U-Net with the MALIS loss on generated toy data.

    Resets the default TensorFlow graph, then parses the number of iterations
    already trained from the newest checkpoint found in the current working
    directory (the text after the last ``_`` of the checkpoint path; 0 when
    no checkpoint exists). Returns immediately if that number already reaches
    ``iterations``; otherwise requests ``iterations - trained_until`` batches
    from the pipeline and prints the debug and predicted-affinity arrays of
    every batch.

    Args:
        iterations: Total iteration count to train up to.
    """
    tf.reset_default_graph()
    latest_checkpoint = tf.train.latest_checkpoint('.')
    trained_until = (int(latest_checkpoint.split('_')[-1])
                     if latest_checkpoint else 0)
    if trained_until >= iterations:
        return

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')
    debug_key = ArrayKey("DEBUG")

    # request sizes, all expressed in units of voxel_size
    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(input_affinities_scale_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)
    request.add(debug_key, debug_shape)

    # Centre an output-sized ROI inside the input volume. Floor division
    # keeps the offset integral; true division would yield floats under
    # Python 3.
    offset = Coordinate((in_size - out_size) // 2
                        for in_size, out_size in zip(input_shape,
                                                     output_shape))
    crop_roi = Roi(offset, output_shape)
    # print("crop_roi: ", crop_roi)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")
    # Disabled augmentations:
    # pipeline += ElasticAugment(
    #     control_point_spacing=[4, 40, 40],
    #     jitter_sigma=[0, 2, 2],
    #     rotation_interval=[0, math.pi / 2.0],
    #     prob_slip=0.05,
    #     prob_shift=0.05,
    #     max_misalign=10,
    #     subsample=8)
    # pipeline += SimpleAugment(transpose_only=[1, 2])
    # pipeline += IntensityAugment(labels, 0.9, 1.1, -0.1, 0.1,
    #                              z_section_wise=True)

    # Input affinities come from the labels as generated; output (target)
    # affinities and their mask are computed after GrowBoundary.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)
    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask)
    pipeline += AddJoinedAffinities(input_affinities=gt_affs_in_key,
                                    joined_affinities=joined_affinities_key)
    pipeline += AddRealism(joined_affinities=joined_affinities_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)
    pipeline += BalanceLabels(labels=gt_affs_out_key,
                              scales=input_affinities_scale_key)
    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)
    # Scale/shift raw by (2, -1) for the network; undone after training.
    pipeline += IntensityScaleShift(raw_key, 2, -1)
    pipeline += PreCache(cache_size=32, num_workers=8)
    pipeline += Crop(key=labels_key, roi=crop_roi)
    pipeline += RenumberConnectedComponents(labels=labels_key)

    # Named train_node so the local does not shadow this function's name.
    train_node = Train(
        graph='train/unet/train_net',
        # optimizer=config['optimizer'],
        optimizer=add_malis_loss,
        # loss=config['loss'],
        loss=None,
        inputs={
            config['raw']: raw_key,
            "gt_seg:0": labels_key,
            "gt_affs_mask:0": gt_affs_mask,
            config['gt_affs']: gt_affs_out_key,
            config['pred_affs_loss_weights']: input_affinities_scale_key,
        },
        outputs={
            config['pred_affs']: pred_affinities_key,
            config['debug']: debug_key,
        },
        gradients={config['pred_affs']: pred_affinities_gradient_key},
        summary="malis_loss:0",
        log_dir='log/unet',
        save_every=1)
    pipeline += train_node

    # Inverse of the (2, -1) scale/shift applied before training.
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)
    # Disabled snapshot:
    # pipeline += Snapshot(
    #     dataset_names={
    #         labels_key: 'volumes/labels',
    #         raw_key: 'volumes/raw',
    #         gt_affs_out_key: 'volumes/gt_affs',
    #         pred_affinities_key: 'volumes/pred_affs'
    #     },
    #     output_filename='unet/train/batch_{iteration}.hdf',
    #     every=100,
    #     dataset_dtypes={
    #         raw_key: np.float32,
    #         labels_key: np.uint64
    #     })
    pipeline += PrintProfilingStats(every=8)

    print("Starting training... COOL BEANS")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            req = p.request_batch(request)
            pred_affs = req[pred_affinities_key].data
            debug = req[debug_key].data
            print("debug", debug)
            print('pred_affs', pred_affs)
            print("name of pred_adds: ", req[pred_affinities_key])
            # Further debugging aids (disabled):
            # print("train session: ", train_node.session)
            # print("labels: ", req[labels_key].data.shape)
            # print("affinities_out: ", req[gt_affs_out_key].data.shape)
            # print("affinities_joined: ", req[joined_affinities_key].data.shape)
            # print("raw: ", req[raw_key].data.shape)
            # print("affinities_in_scale: ", req[input_affinities_scale_key].data.shape)
    print("Training finished")