Example #1
import os

import numpy as np
import tensorflow as tf

from gunpowder import *
from gunpowder.tensorflow import Train

# Assumed to be defined elsewhere in this repository: the custom nodes
# (AddJoinedAffinities, AddRealism, PickRandomLabel), add_malis_loss, and
# the module-level settings config, data_dir, samples, setup_dir,
# setup_name, phase_switch, and neighborhood.

def train(iterations):
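	"""Train the network in two phases: Euclidean-loss pretraining up to
	`phase_switch` iterations, then MALIS fine-tuning, resuming from the
	latest checkpoint in the current directory."""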

	# tf.reset_default_graph()
	checkpoint = tf.train.latest_checkpoint('.')
	if checkpoint:
		trained_until = int(checkpoint.split('_')[-1])
	else:
		trained_until = 0
	if trained_until >= iterations:
		return

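	# If asked to train past the phase switch while the Euclidean phase is
	# unfinished, complete that phase first, then continue below with MALIS.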
	if trained_until < phase_switch and iterations > phase_switch:
		train(phase_switch)

	phase = 'euclid' if iterations <= phase_switch else 'malis'
	print("Training in phase %s until %i"%(phase, iterations))
		
	# define array-keys
	labels_key = ArrayKey('LABELS')
	raw_affs_key = ArrayKey('RAW_AFFINITIES')
	raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
	raw_key = ArrayKey('RAW')
	
	merged_labels_keys = []
	# merged_affs_keys = []
	picked_labels_key = ArrayKey('PICKED_RANDOM_LABEL')

	affs_neg_key = ArrayKey('AFFINITIES')
	affs_pos_key = ArrayKey('AFFINITIES_OPP')
	joined_affs_neg_key = ArrayKey('JOINED_AFFINITIES')
	joined_affs_pos_key = ArrayKey('JOINED_AFFINITIES_OPP')

	num_merges = 3
	for i in range(num_merges):
		merged_labels_keys.append(ArrayKey('MERGED_LABELS_%i'%(i+1)))


	gt_affs_out_key = ArrayKey('GT_AFFINITIES')
	gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
	gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
	gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
	
	pred_affs_key = ArrayKey('PRED_AFFS')
	pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

	sample_z_key = ArrayKey("SAMPLE_Z")
	broadcast_key = ArrayKey("BROADCAST")
	pred_logits_key = ArrayKey("PRED_LOGITS")
	sample_out_key = ArrayKey("SAMPLE_OUT")
	debug_key = ArrayKey("DEBUG")

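	# The *_affs shapes are one voxel larger along each axis than the
	# corresponding label shapes.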
	voxel_size = Coordinate((1, 1, 1))
	input_shape = Coordinate(config['input_shape']) * voxel_size
	input_affs_shape = Coordinate([i + 1 for i in config['input_shape']]) * voxel_size
	output_shape = Coordinate(config['output_shape']) * voxel_size
	output_affs_shape = Coordinate([i + 1 for i in config['output_shape']]) * voxel_size
	sample_shape = Coordinate((1, 1, config['latent_dims'])) * voxel_size
	debug_shape = Coordinate((1, 1, 5)) * voxel_size

	print ("input_shape: ", input_shape)
	print ("input_affs_shape: ", input_affs_shape)
	print ("output_shape: ", output_shape)
	print ("output_affs_shape: ", output_affs_shape)

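	# Request every array the pipeline must provide for one batch.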
	request = BatchRequest()
	request.add(labels_key, input_shape)

	request.add(raw_affs_key, input_shape)
	request.add(raw_joined_affs_key, input_shape)
	request.add(raw_key, input_shape)

	for i in range(num_merges):
		request.add(merged_labels_keys[i], input_shape)
	request.add(picked_labels_key, output_shape)

	request.add(gt_affs_out_key, output_shape)
	request.add(gt_affs_in_key, input_shape)
	request.add(gt_affs_mask_key, output_shape)
	request.add(gt_affs_scale_key, output_shape)

	request.add(pred_affs_key, output_shape)
	request.add(pred_affs_gradient_key, output_shape)

	request.add(broadcast_key, output_shape)
	request.add(sample_z_key, sample_shape)
	request.add(pred_logits_key, output_shape)
	request.add(sample_out_key, sample_shape)
	request.add(debug_key, debug_shape)

	dataset_names = {
		labels_key: 'volumes/labels',
	}

	array_specs = {
		labels_key: ArraySpec(interpolatable=False)
	}

	for i in range(num_merges):
		dataset_names[merged_labels_keys[i]] = 'volumes/merged_labels_%i'%(i+1)
		array_specs[merged_labels_keys[i]] = ArraySpec(interpolatable=False)

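	# One Hdf5Source per sample; RandomProvider below picks one of them at
	# random for each batch.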
	pipeline = tuple(
		Hdf5Source(
			os.path.join(data_dir, sample + '.hdf'),
			datasets=dataset_names,
			array_specs=array_specs) +
		Pad(labels_key, None) +
		Pad(merged_labels_keys[0], None) +
		Pad(merged_labels_keys[1], None) +
		Pad(merged_labels_keys[2], None)
		# `Pad(merged_labels_keys[i], None) for i in range(num_merges)` cannot
		# be spliced in here: it would be parsed as an extra clause of this
		# generator expression (one pipeline per (i, sample) pair), not as
		# chained `+` terms.
		for sample in samples
	)

	pipeline += RandomProvider()

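	# Synthesize raw data from the labels: affinities, joined into a single
	# boundary volume, then made to look like imaging data by AddRealism.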
	pipeline += AddAffinities(
			affinity_neighborhood=neighborhood,
			labels=labels_key,
			affinities=raw_affs_key)

	pipeline += AddJoinedAffinities(
			input_affinities=raw_affs_key,
			joined_affinities=raw_joined_affs_key)

	pipeline += AddRealism(
			joined_affinities = raw_joined_affs_key,
			raw = raw_key,
			sp=0.25,
			sigma=1,
			contrast=0.7)

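	# Euclidean phase: always train on the unmerged ground-truth labels.
	# MALIS phase: pick the original or one of the merged label volumes
	# with equal probability.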
	if phase == "euclid":

		pipeline += PickRandomLabel(
				input_labels = [labels_key]+ merged_labels_keys,
				output_label=picked_labels_key,
				probabilities=[1, 0, 0, 0])

	else: 

		pipeline += PickRandomLabel(
				input_labels = [labels_key] + merged_labels_keys,
				output_label=picked_labels_key,
				probabilities=[0.25, 0.25, 0.25, 0.25])

		pipeline += RenumberConnectedComponents(
				labels=picked_labels_key)


	pipeline += GrowBoundary(picked_labels_key, steps=1, only_xy=True)

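	# Ground-truth affinities at input size (fed to the network) and at
	# output size, the latter with a mask for the loss.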
	pipeline += AddAffinities(
			affinity_neighborhood=neighborhood,
			labels=picked_labels_key,
			affinities=gt_affs_in_key)

	pipeline += AddAffinities(
			affinity_neighborhood=neighborhood,
			labels=picked_labels_key,
			affinities=gt_affs_out_key,
			affinities_mask=gt_affs_mask_key)

	pipeline += BalanceLabels(
			labels=gt_affs_out_key,
			scales=gt_affs_scale_key)

	pipeline += DefectAugment(
			intensities=raw_key,
			prob_missing=0.03,
			prob_low_contrast=0.01,
			contrast_scale=0.5,
			axis=0)

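	# Scale raw intensities from [0, 1] to [-1, 1] for training; this is
	# undone below before writing snapshots.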
	pipeline += IntensityScaleShift(raw_key, 2, -1)

	pipeline += PreCache(
			cache_size=8,
			num_workers=4)

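	# Map tensor names in the serialized training graph to gunpowder arrays.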
	train_inputs = {
		config['raw']: raw_key,
		config['gt_affs_in']: gt_affs_in_key,
		config['gt_affs_out']: gt_affs_out_key,
		config['pred_affs_loss_weights']: gt_affs_scale_key
	}

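	# The Euclidean phase uses the loss and optimizer baked into the graph;
	# the MALIS phase builds its loss at runtime via add_malis_loss, which
	# needs the ground-truth segmentation and affinity mask as extra inputs.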
	if phase == 'euclid':
		train_loss = config['loss']
		train_optimizer = config['optimizer']
		train_summary = config['summary']
	else:
		train_loss = None
		train_optimizer = add_malis_loss
		train_inputs['gt_seg:0'] = picked_labels_key
		train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
		train_summary = 'Merge/MergeSummary:0'

	pipeline += Train(
			graph=setup_dir + 'train_net',
			optimizer=train_optimizer,
			loss=train_loss,
			inputs=train_inputs,
			outputs={
				config['pred_affs']: pred_affs_key,
				config['broadcast']: broadcast_key,
				config['sample_z']: sample_z_key,
				config['pred_logits']: pred_logits_key,
				config['sample_out']: sample_out_key,
				config['debug']: debug_key
			},
			gradients={
				config['pred_affs']: pred_affs_gradient_key
			},
			summary=train_summary,
			log_dir='log/prob_unet/' + setup_name,
			save_every=2000)

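	# Map raw back from [-1, 1] to [0, 1] for the snapshots.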
	pipeline += IntensityScaleShift(
			array=raw_key,
			scale=0.5,
			shift=0.5)

	pipeline += Snapshot(
			dataset_names={
				labels_key: 'volumes/labels',
				picked_labels_key: 'volumes/merged_labels',
				raw_affs_key: 'volumes/raw_affs',
				raw_key: 'volumes/raw',
				gt_affs_in_key: 'volumes/gt_affs_in',
				gt_affs_out_key: 'volumes/gt_affs_out',
				pred_affs_key: 'volumes/pred_affs',
				pred_logits_key: 'volumes/pred_logits'
			},
			output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
			every=2000,
			dataset_dtypes={
				labels_key: np.uint16,
				picked_labels_key: np.uint16,
				raw_affs_key: np.float32,
				raw_key: np.float32,
				gt_affs_in_key: np.float32,
				gt_affs_out_key: np.float32,
				pred_affs_key: np.float32,
				pred_logits_key: np.float32
			})

	pipeline += PrintProfilingStats(every=20)

	print("Starting training...")
	with build(pipeline) as p:
		for i in range(iterations - trained_until):
			req = p.request_batch(request)
			# sample_z = req[sample_z_key].data
			# broadcast_sample = req[broadcast_key].data
			# sample_out = req[sample_out_key].data
			# debug = req[debug_key].data
			# print("debug", debug)

			# print("sample_z: ", sample_z)
			# print("sample_out:", sample_out)
			# print("Z - 0")
			# print(np.unique(broadcast_sample[0, 0, :, :, :]))
			# print("Z - 1")
			# print(np.unique(broadcast_sample[0, 1, :, :, :]))
			# print("Z - 2")
			# print(np.unique(broadcast_sample[0, 2, :, :, :]))
			# print("Z - 3")
			# print(np.unique(broadcast_sample[0, 3, :, :, :]))
			# print("Z - 4")
			# print(np.unique(broadcast_sample[0, 4, :, :, :]))
			# print("Z - 5")
			# print(np.unique(broadcast_sample[0, 5, :, :, :]))
	print("Training finished")
Example #2
import numpy as np

from gunpowder import *

# Assumed to be defined elsewhere in this repository: the custom nodes
# ToyNeuronSegmentationGenerator, AddJoinedAffinities, MergeLabels, and
# PickRandomLabel.

def generate_full_samples(num_batches):
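    """Generate `num_batches` synthetic batches of toy segmentations,
    their affinities, and artificially merged label variants, and write
    them to snapshot files."""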

    neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
    neighborhood_opp = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    # define array-keys
    labels_key = ArrayKey('LABELS')
    gt_affs_key = ArrayKey('RAW_AFFINITIES')
    gt_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    merged_labels_key = []
    merged_affs_key = []

    affs_neg_key = ArrayKey('AFFINITIES')
    affs_pos_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_neg_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_pos_key = ArrayKey('JOINED_AFFINITIES_OPP')

    picked_labels_key = ArrayKey('PICKED_LABELS')

    num_merges = 3
    for i in range(num_merges):
        merged_labels_key.append(ArrayKey('MERGED_LABELS_%i' % (i + 1)))
        merged_affs_key.append(ArrayKey('MERGED_AFFINITIES_IN_%i' % (i + 1)))

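    # Affinity shapes are one voxel larger along each axis than the label
    # shapes.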
    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate((132, 132, 132)) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in (132, 132, 132)]) * voxel_size
    output_shape = Coordinate((44, 44, 44)) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in (44, 44, 44)]) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    # request.add(gt_affs_key, input_shape)
    # request.add(gt_joined_affs_key, input_shape)
    # request.add(raw_key, input_shape)

    request.add(affs_neg_key, input_affs_shape)
    request.add(affs_pos_key, input_affs_shape)
    request.add(joined_affs_neg_key, input_affs_shape)
    request.add(joined_affs_pos_key, input_affs_shape)

    for i in range(num_merges):
        request.add(merged_labels_key[i], input_shape)
        request.add(merged_affs_key[i], input_shape)

    request.add(picked_labels_key, input_shape)

    pipeline = ()

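    # Synthetic ground truth: randomly generated toy neuron segmentations.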
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # pipeline += AddAffinities(
    # 		affinity_neighborhood=neighborhood,
    # 		labels=labels_key,
    # 		affinities=gt_affs_key)

    # pipeline += AddJoinedAffinities(
    # 		input_affinities=gt_affs_key,
    # 		joined_affinities=gt_joined_affs_key)

    # pipeline += AddRealism(
    # 		joined_affinities = gt_joined_affs_key,
    # 		raw = raw_key,
    # 		sp=0.25,
    # 		sigma=1,
    # 		contrast=0.7)

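    # Affinities in both directions (negative and positive neighborhoods);
    # MergeLabels below consumes the joined versions of both.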
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=affs_neg_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood_opp,
                              labels=labels_key,
                              affinities=affs_pos_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_neg_key,
                                    joined_affinities=joined_affs_neg_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_pos_key,
                                    joined_affinities=joined_affs_pos_key)

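    # Create num_merges label variants with artificially merged objects,
    # plus affinities for each variant.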
    for i in range(num_merges):

        pipeline += MergeLabels(labels=labels_key,
                                joined_affinities=joined_affs_neg_key,
                                joined_affinities_opp=joined_affs_pos_key,
                                merged_labels=merged_labels_key[i],
                                cropped_roi=None,
                                every=1)

        pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                                  labels=merged_labels_key[i],
                                  affinities=merged_affs_key[i])

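    # With probabilities [0, 1, 0, 0], always pick the first merged volume.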
    pipeline += PickRandomLabel(input_labels=[labels_key] + merged_labels_key,
                                output_label=picked_labels_key,
                                probabilities=[0, 1, 0, 0])

    pipeline += RenumberConnectedComponents(labels=merged_labels_key[0])

    pipeline += GrowBoundary(merged_labels_key[0], steps=1, only_xy=True)

    # pipeline += PreCache(
    # 	cache_size=32,
    # 	num_workers=8)

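    # Arrays (and dtypes) to write to the snapshot file.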
    dataset_names = {
        labels_key: 'volumes/labels',
        picked_labels_key: 'volumes/picked'
    }

    dataset_dtypes = {
        labels_key: np.uint16,
        picked_labels_key: np.uint16,
    }

    for i in range(num_merges):
        dataset_names[merged_labels_key[i]] = 'volumes/merged_labels_%i' % (i + 1)
        dataset_dtypes[merged_labels_key[i]] = np.uint64

    pipeline += Snapshot(dataset_names=dataset_names,
                         output_filename='debug_labels.hdf',
                         every=1,
                         dataset_dtypes=dataset_dtypes)

    # pipeline += PrintProfilingStats(every=10)

    hashes = []
    with build(pipeline) as p:
        for i in range(num_batches):
            req = p.request_batch(request)
            # print ("labels shape: ", req[labels_key].data.shape)
            picked_labels = len(np.unique(req[picked_labels_key].data))
            print("\nDATA POINT:", i, ", num labels:", picked_labels)