Example #1
def np_seg_to_aff(seg, nhood=malis.mknhood3d(1)):
    # return lambda seg, nhood: malis.seg_to_affgraph (seg, nhood).astype(np.float32)
    seg = np.squeeze(seg)
    seg = seg.astype(np.int32)
    ret = malis.seg_to_affgraph(seg, nhood)  # seg zyx
    ret = ret.astype(np.float32)
    ret = np.squeeze(ret)  # ret 3zyx
    ret = np.transpose(ret, [1, 2, 3, 0])  # ret zyx3
    return ret
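A minimal usage sketch for np_seg_to_aff on a small, hypothetical labelled volume (it assumes numpy, malis and the function above are importable):

import numpy as np
import malis

# toy 4x4x4 segmentation split into two labels along x
seg = np.ones((4, 4, 4), dtype=np.int32)
seg[:, :, 2:] = 2

aff = np_seg_to_aff(seg)  # zyx labels -> zyx3 affinity graph
print(aff.shape)          # expected: (4, 4, 4, 3)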
Example #2
def mknhood_long():
    ret = malis.mknhood3d(1).tolist()
    for dz in [2, 3, 4]:
        ret.append([-dz, 0, 0])
    for dy in [3, 9, 27]:
        ret.append([0, -dy, 0])
    for dx in [3, 9, 27]:
        ret.append([0, 0, -dx])
    return np.array(ret, dtype=np.int32)
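As a quick, hedged check of the neighborhood this builds (assuming malis.mknhood3d(1) returns the three unit offsets [-1,0,0], [0,-1,0], [0,0,-1], as the other examples rely on):

import numpy as np
import malis

nhood = mknhood_long()
print(nhood.shape)  # expected: (12, 3) -- 3 unit offsets plus 9 long-range offsets
print(nhood[:3])    # the plain mknhood3d(1) offsets come first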
 def get_datasets(dataset, origin):
     # output_shape = (40, 30, 80)
     # input_shape = (100, 110, 120)
     output_shape = (2, 3, 4)
     input_shape = tuple([x + 2 for x in output_shape])
     # use integer division so the slice bounds below stay ints
     borders = tuple([(in_ - out_) // 2
                      for (in_, out_) in zip(input_shape, output_shape)])
     input_slices = tuple(
         [slice(x, x + l) for x, l in zip(origin, input_shape)])
     output_slices = tuple([
         slice(x + b, x + b + l)
         for x, b, l in zip(origin, borders, output_shape)
     ])
     expected_dataset = dict()
     data_slices = [slice(0, l) for l in dataset['data'].shape]
     data_slices[-3:] = input_slices
     data_slices = tuple(data_slices)
     expected_data_array = np.array(dataset['data'][data_slices],
                                    dtype=np.float32)
     expected_data_array = expected_data_array.reshape((1, ) + input_shape)
     expected_data_array /= (2.0**8)
     expected_dataset['data'] = expected_data_array
     components_slices = [slice(0, l) for l in dataset['components'].shape]
     components_slices[-3:] = output_slices
     components_slices = tuple(components_slices)
     expected_components_array = np.array(
         dataset['components'][components_slices]).reshape((1, ) +
                                                           output_shape)
     if type(dataset['components']) is DVIDDataInstance:
         print("Is DVIDDataInstance...")
         print("uniques before:", np.unique(expected_components_array))
         dvid_uuid = dataset['components'].uuid
         body_names_to_exclude = dataset.get('body_names_to_exclude')
         good_bodies = get_good_components(dvid_uuid, body_names_to_exclude)
         expected_components_array = \
             replace_array_except_whitelist(expected_components_array, 0, good_bodies)
         print("uniques after:", np.unique(expected_components_array))
     expected_dataset['components'] = expected_components_array
     components_for_affinity_generation = expected_components_array.reshape(
         output_shape)
     expected_label = malis.seg_to_affgraph(
         components_for_affinity_generation, malis.mknhood3d())
     expected_dataset['label'] = expected_label
     if type(dataset['components']) is DVIDDataInstance:
         expected_mask = np.array(expected_components_array > 0).astype(
             np.uint8)
     else:
         expected_mask = np.ones(shape=(1, ) + output_shape, dtype=np.uint8)
     expected_dataset['mask'] = expected_mask
     numpy_dataset = get_numpy_dataset(dataset, input_slices, output_slices,
                                       True)
     return numpy_dataset, expected_dataset
Example #4
def seg_to_aff_op(seg, nhood=tf.constant(malis.mknhood3d(1)), name='SegToAff'):
	# Squeeze the segmentation to 3D
	seg = tf.squeeze(seg, axis=-1)
	# Define the numpy function to transform segmentation to affinity graph
	np_func = lambda seg, nhood: malis.seg_to_affgraph (seg.astype(np.int32), nhood).astype(np.float32)
	# Convert the numpy function to tensorflow function
	tf_func = tf.py_func(np_func, [tf.cast(seg, tf.int32), nhood], [tf.float32], name=name)
	# Reshape the result; note the layout from malis is (3, dimz, dimy, dimx)
	ret = tf.reshape(tf_func[0], [3, seg.shape[0], seg.shape[1], seg.shape[2]])
	# Transpose the result so that the dimension 3 go to the last channel
	ret = tf.transpose(ret, [1, 2, 3, 0])
	# print ret.get_shape().as_list()
	return ret
Example #5
def aff_to_seg_op(aff, nhood=tf.constant(malis.mknhood3d(1)), threshold=tf.constant(np.array([0.5])), name='AffToSeg'):
	# Define the numpy function to transform affinity to segmentation
	def np_func (aff, nhood, threshold):
		aff = np.transpose(aff, [3, 0, 1, 2]) # zyx3 to 3zyx
		ret = malis.connected_components_affgraph((aff > threshold[0]).astype(np.int32), nhood)[0].astype(np.int32) 
		ret = skimage.measure.label(ret).astype(np.int32)
		return ret
	# print aff.get_shape().as_list()
	# Convert numpy function to tensorflow function
	tf_func = tf.py_func(np_func, [aff, nhood, threshold], [tf.int32], name=name)
	ret = tf.reshape(tf_func[0], [aff.shape[0], aff.shape[1], aff.shape[2]])
	ret = tf.expand_dims(ret, axis=-1)
	# print ret.get_shape().as_list()
	return ret
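A minimal round-trip sketch combining seg_to_aff_op (Example #4) and aff_to_seg_op above; it assumes an older TensorFlow 1.x with tf.py_func and sessions, plus numpy, malis and skimage installed:

import numpy as np
import tensorflow as tf

# hypothetical toy labels with a trailing channel axis, as the ops expect
seg_np = np.ones((4, 4, 4, 1), dtype=np.float32)
seg_np[:, :, 2:, 0] = 2.0

seg = tf.constant(seg_np)
aff = seg_to_aff_op(seg)  # zyx1 labels -> zyx3 affinities
rec = aff_to_seg_op(aff)  # zyx3 affinities -> zyx1 component labels

with tf.Session() as sess:
    aff_val, rec_val = sess.run([aff, rec])
    print(aff_val.shape, rec_val.shape)  # expected: (4, 4, 4, 3) (4, 4, 4, 1)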
Example #6
def train():

    set_verbose()

    affinity_neighborhood = malis.mknhood3d()
    solver_parameters = SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 1e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 10
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage('euclid')

    batch_spec = BatchSpec((84, 268, 268), (56, 56, 56),
                           with_gt=True,
                           with_gt_mask=False,
                           with_gt_affinities=True)

    # create a batch provider by concatenation of filters
    batch_provider = (
        DvidSource(
            hostname='slowpoke3',
            port=32788,
            uuid='341',
            raw_array_name='grayscale',
            gt_array_name='groundtruth_pruned',
        ) + Normalize() + RandomLocation() +
        AddGtAffinities(affinity_neighborhood) +
        PreCache(lambda: batch_spec, cache_size=3, num_workers=2) +
        Train(solver_parameters, use_gpu=0) + Snapshot(every=1))

    n = 20
    print("Training for", n, "iterations")

    with build(batch_provider) as minibatch_maker:
        for i in range(n):
            minibatch_maker.request_batch(batch_spec)

    print("Finished")
def malis_3d(aff, gt):

    assert aff.ndim == 4, "affinitygraph needs to be 4 dimensional"
    assert aff.shape[0] == 3, "affinitygraph channel 0 needs z,y and x affinities"
    assert aff.shape[1:] == gt.shape, "Spatial shapes of affinity graph and gt need to be the same"

    # import all the malis functionality we need
    from malis import mknhood3d, affgraph_to_edgelist, malis_loss_weights

    # need to ravel the gt
    gt = gt.ravel()

    # make the 3d neighborhood
    nhood = mknhood3d()

    # get the node connectors and weights from the affinitygraph
    # extracts the uvIds and weights from the affinity graph for the given neighborhood
    uvs1, uvs2, edge_weights = affgraph_to_edgelist(aff, nhood)

    print "nodes and weights from affinity graph"
    print "connectors shape:", uvs1.shape
    print "edge weights shape:", edge_weights.shape

    # malis loss:
    # calculates number of correct / false merges caused by edge

    # parameters:
    # gt : groundtruth (raveled)
    # connectors1 : uvIds(0)
    # connectors2 : uvIds(1) -> the two nodeconnectors tell the nodes that are connected by the edge
    # edge_weights : raveled edge weights from the affinity graph
    # pos : pseudo-bool that determines whether we count correct (pos = 1) or false (pos = 0) merges per edge
    pos = 0
    malis_loss_false_merges = malis_loss_weights(gt, uvs1, uvs2,
            edge_weights, pos)

    print "Calculated false merges per edge"

    pos = 1
    malis_loss_correct_merges = malis_loss_weights(gt, uvs1, uvs2,
            edge_weights, pos)

    print "Calculated correct merges per edge"

    return malis_loss_false_merges, malis_loss_correct_merges
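A minimal sketch of driving malis_3d on a toy volume (it assumes numpy and malis; the affinity graph here is simply derived from the ground truth with seg_to_affgraph):

import numpy as np
from malis import mknhood3d, seg_to_affgraph

# hypothetical toy ground truth with two labels split along x
gt = np.ones((4, 4, 4), dtype=np.uint64)
gt[:, :, 2:] = 2

aff = seg_to_affgraph(gt.astype(np.int32), mknhood3d()).astype(np.float32)

false_merges, correct_merges = malis_3d(aff, gt)
print(false_merges.shape, correct_merges.shape)  # one count per edge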
Example #8
def compute_malis_counts(affinities, labels):
    nhood = m.mknhood3d(radius=1)
    assert nhood.shape[0] == affinities.shape[0]
    subvolume_shape = labels.shape
    node_idx_1, node_idx_2 = m.nodelist_like(subvolume_shape, nhood)
    node_idx_1, node_idx_2 = node_idx_1.ravel(), node_idx_2.ravel()
    flat_labels = labels[...].ravel()
    flat_affinities = affinities[...].ravel()
    pos_counts = m.malis_loss_weights(flat_labels,
                                      node_idx_1, node_idx_2,
                                      flat_affinities,
                                      1)
    neg_counts = m.malis_loss_weights(flat_labels,
                                      node_idx_1, node_idx_2,
                                      flat_affinities,
                                      0)
    pos_counts = pos_counts.reshape(affinities.shape)
    neg_counts = neg_counts.reshape(affinities.shape)
    return pos_counts, neg_counts
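A minimal sketch of calling compute_malis_counts, assuming the malis package is imported as m (as in the example) and using a toy volume:

import numpy as np
import malis as m

# hypothetical toy labels and a matching (3, z, y, x) affinity volume
labels = np.ones((4, 4, 4), dtype=np.uint64)
labels[:, :, 2:] = 2
affinities = m.seg_to_affgraph(labels.astype(np.int32), m.mknhood3d(1)).astype(np.float32)

pos_counts, neg_counts = compute_malis_counts(affinities, labels)
print(pos_counts.shape, neg_counts.shape)  # both match affinities.shape: (3, 4, 4, 4)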
Example #9
def get_train_dataset(dataset_source_type_, using_in_memory=False):
    train_dataset = []
    for dname in training_dataset_names:
        dataset = dict()
        h5_filenames = dict(
            data=join(path_to_training_datasets, dname, 'im_uint8.h5'),
            components=join(path_to_training_datasets, dname,
                            'groundtruth_seg_thick.h5'),
            # label=join(path_to_training_datasets, dname, 'groundtruth_aff.h5'),
            mask=join(path_to_training_datasets, dname, 'mask.h5'),
        )
        dvid_data_names = dict(
            data='grayscale',
            components='labels',
        )
        dvid_hostname = 'emdata2.int.janelia.org'
        dvid_port = 7000
        dataset['name'] = dname
        dataset['nhood'] = malis.mknhood3d()
        for key in ['data', 'components']:
            if dataset_source_type_ == DVIDDataInstance:
                if key in dvid_data_names:
                    data_name = dvid_data_names[key]
                    dataset[key] = DVIDDataInstance(dvid_hostname, dvid_port,
                                                    dvid_uuid, data_name)
            elif dataset_source_type_ == h5py.File:
                dataset[key] = h5py.File(h5_filenames[key], 'r')['main']
                if using_in_memory:
                    dataset[key] = np.array(dataset[key])
                    if key != 'label':
                        dataset[key] = dataset[key].reshape((1, ) +
                                                            dataset[key].shape)
                    if key == 'data':
                        dataset[key] = dataset[key] / 2.**8
            elif dataset_source_type_ == 'hdf5 file paths':
                if key in h5_filenames:
                    dataset[key] = h5_filenames[key]
        # dataset['transform'] = {}
        # dataset['transform']['scale'] = (0.8, 1.2)
        # dataset['transform']['shift'] = (-0.2, 0.2)
        train_dataset.append(dataset)
    return train_dataset
Example #10
def get_counts(aff, gt, 
               ignore_background=False,
               counting_method=0,
               stochastic_malis_parameter=0,
               z_transform=False):

    if z_transform:
        raise NotImplementedError("z transform not implemented")
    nhood = m.mknhood3d()
    
    node_idx_1, node_idx_2 = m.nodelist_like(gt.shape, nhood)
    pos_pairs, neg_pairs = m.malis_loss_weights( \
        gt.flatten(),
        node_idx_1.flatten(),
        node_idx_2.flatten(),
        aff.flatten().astype(np.float32),
        ignore_background=ignore_background,
        counting_method=counting_method,
        stochastic_malis_parameter=stochastic_malis_parameter)
    return pos_pairs.reshape(aff.shape), neg_pairs.reshape(aff.shape)
 def get_datasets(dataset, origin):
     # output_shape = (40, 30, 80)
     # input_shape = (100, 110, 120)
     output_shape = (2,3,4)
     input_shape = tuple([x + 2 for x in output_shape])
     # use integer division so the slice bounds below stay ints
     borders = tuple([(in_ - out_) // 2 for (in_, out_) in zip(input_shape, output_shape)])
     input_slices = tuple([slice(x, x + l) for x, l in zip(origin, input_shape)])
     output_slices = tuple([slice(x + b, x + b + l) for x, b, l in zip(origin, borders, output_shape)])
     expected_dataset = dict()
     data_slices = [slice(0, l) for l in dataset['data'].shape]
     data_slices[-3:] = input_slices
     data_slices = tuple(data_slices)
     expected_data_array = np.array(dataset['data'][data_slices], dtype=np.float32)
     expected_data_array = expected_data_array.reshape((1,) + input_shape)
     expected_data_array /= (2.0 ** 8)
     expected_dataset['data'] = expected_data_array
     components_slices = [slice(0, l) for l in dataset['components'].shape]
     components_slices[-3:] = output_slices
     components_slices = tuple(components_slices)
     expected_components_array = np.array(dataset['components'][components_slices]).reshape((1,) + output_shape)
     if type(dataset['components']) is DVIDDataInstance:
         print("Is DVIDDataInstance...")
         print("uniques before:", np.unique(expected_components_array))
         dvid_uuid = dataset['components'].uuid
         body_names_to_exclude = dataset.get('body_names_to_exclude')
         good_bodies = get_good_components(dvid_uuid, body_names_to_exclude)
         expected_components_array = \
             replace_array_except_whitelist(expected_components_array, 0, good_bodies)
         print("uniques after:", np.unique(expected_components_array))
     expected_dataset['components'] = expected_components_array
     components_for_affinity_generation = expected_components_array.reshape(output_shape)
     expected_label = malis.seg_to_affgraph(components_for_affinity_generation, malis.mknhood3d())
     expected_dataset['label'] = expected_label
     if type(dataset['components']) is DVIDDataInstance:
         expected_mask = np.array(expected_components_array > 0).astype(np.uint8)
     else:
         expected_mask = np.ones(shape=(1,) + output_shape, dtype=np.uint8)
     expected_dataset['mask'] = expected_mask
     numpy_dataset = get_numpy_dataset(dataset, input_slices, output_slices, True)
     return numpy_dataset, expected_dataset
Example #12
def get_train_dataset(dataset_source_type_, using_in_memory=False):
    train_dataset = []
    for dname in training_dataset_names:
        dataset = dict()
        h5_filenames = dict(
            data=join(path_to_training_datasets, dname, 'im_uint8.h5'),
            components=join(path_to_training_datasets, dname, 'groundtruth_seg_thick.h5'),
            # label=join(path_to_training_datasets, dname, 'groundtruth_aff.h5'),
            mask=join(path_to_training_datasets, dname, 'mask.h5'),
        )
        dvid_data_names = dict(
            data='grayscale',
            components='labels',
        )
        dvid_hostname = 'emdata2.int.janelia.org'
        dvid_port = 7000
        dataset['name'] = dname
        dataset['nhood'] = malis.mknhood3d()
        for key in ['data', 'components']:
            if dataset_source_type_ == DVIDDataInstance:
                if key in dvid_data_names:
                    data_name = dvid_data_names[key]
                    dataset[key] = DVIDDataInstance(dvid_hostname, dvid_port, dvid_uuid, data_name)
            elif dataset_source_type_ == h5py.File:
                dataset[key] = h5py.File(h5_filenames[key], 'r')['main']
                if using_in_memory:
                    dataset[key] = np.array(dataset[key])
                    if key != 'label':
                        dataset[key] = dataset[key].reshape((1,) + dataset[key].shape)
                    if key == 'data':
                        dataset[key] = dataset[key] / 2. ** 8
            elif dataset_source_type_ == 'hdf5 file paths':
                if key in h5_filenames:
                    dataset[key] = h5_filenames[key]
        # dataset['transform'] = {}
        # dataset['transform']['scale'] = (0.8, 1.2)
        # dataset['transform']['shift'] = (-0.2, 0.2)
        train_dataset.append(dataset)
    return train_dataset
Example #13
def ignore_disconnected_components(cells):

    global ignore_label
    if ignore_label is None:
        ignore_label = int(cells.max() + 1)

    print("Relabelling connected components...")
    simple_neighborhood = malis.mknhood3d()
    affinities = malis.seg_to_affgraph(cells, simple_neighborhood)
    relabelled, _ = malis.connected_components_affgraph(
        affinities, simple_neighborhood)

    print("Creating overlay...")
    overlay = np.array([cells.flatten(), relabelled.flatten()])
    print("Finding unique pairs...")
    matches = np.unique(overlay, axis=1)

    print("Finding disconnected labels...")
    orig_to_new = {}
    disconnected = set()
    for orig_id, new_id in zip(matches[0], matches[1]):
        if orig_id == 0 or new_id == 0:
            continue
        if orig_id not in orig_to_new:
            orig_to_new[orig_id] = [new_id]
        else:
            orig_to_new[orig_id].append(new_id)
            disconnected.add(orig_id)

    print("Masking %d disconnected labels..." % len(disconnected))
    ignore_mask = replace(cells, np.array([l for l in disconnected]),
                          np.array([ignore_label], dtype=np.uint64))
    ignore_mask = (ignore_mask == ignore_label).astype(np.uint8)
    print("done.")

    return ignore_mask
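The core check above is that a label which falls apart into several pieces overlaps more than one id after connected_components_affgraph. A small hedged sketch of just that idea, independent of the replace helper and the ignore_label global used in the example:

import numpy as np
import malis

# hypothetical volume in which id 7 consists of two separate blobs
cells = np.zeros((1, 5, 5), dtype=np.int32)
cells[0, 0, 0:2] = 7
cells[0, 4, 3:5] = 7

nhood = malis.mknhood3d()
affs = malis.seg_to_affgraph(cells, nhood)
relabelled, _ = malis.connected_components_affgraph(affs, nhood)
overlay = np.array([cells.flatten(), relabelled.flatten()])
print(np.unique(overlay[1][overlay[0] == 7]))  # two new ids -> id 7 is disconnected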
                for swapxy in range(2):
                    new_dname = '{dname}_z{z}_y{y}_x{x}_xy{swapxy}_angle{angle:05.1f}'.format(
                        dname=dname, z=reflectz, y=reflecty, x=reflectx, swapxy=swapxy,
                        angle=angle
                    )
                    names.append(new_dname)
base_dir = '/data/SNEMI_aug' #'/nobackup/turaga/singhc/SNEMI_aug/rotations'
for name in names:
    dataset = dict()
    dataset['name'] = name
    dataset['data'] = h5py.File(os.path.join(base_dir, 'raw.h5'), 'r')[name]
    dataset['components'] = h5py.File(os.path.join(base_dir, 'seg.h5'), 'r')[name]
    dataset['mask'] = h5py.File(os.path.join(base_dir, 'mask.h5'), 'r')[name]
    train_dataset.append(dataset)
                    
for dataset in train_dataset:
    dataset['nhood'] = malis.mknhood3d()
    dataset['mask_threshold'] = mask_threshold
    dataset['mask_dilation_steps'] = mask_dilation_steps
    dataset['transform'] = {}
    dataset['transform']['scale'] = (0.8, 1.2)
    dataset['transform']['shift'] = (-0.2, 0.2)

print('Training set contains',
      str(len(train_dataset)),
      'volumes:',
      [dataset['name'] for dataset in train_dataset],
      train_dataset)

## Testing datasets
test_dataset = []
Example #15
def np_aff_to_seg(aff, nhood=malis.mknhood3d(1), threshold=np.array([0.5])):
    aff = np.transpose(aff, [3, 0, 1, 2])  # zyx3 to 3zyx
    ret = malis.connected_components_affgraph(
        (aff > threshold[0]).astype(np.int32), nhood)[0].astype(np.int32)
    ret = skimage.measure.label(ret).astype(np.float32)
    return ret
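A minimal round-trip sketch pairing np_aff_to_seg with np_seg_to_aff from Example #1 (it assumes numpy, malis and skimage, and that connected_components_affgraph returns a labelling with the volume's shape):

import numpy as np

seg = np.ones((4, 4, 4), dtype=np.int32)
seg[:, :, 2:] = 2             # two labels side by side

aff = np_seg_to_aff(seg)      # zyx -> zyx3 (Example #1)
rec = np_aff_to_seg(aff)      # back to a component labelling
rec = rec.reshape(seg.shape)
print(np.unique(rec).size)    # the two labels should come back as two components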
        dataset['component_erosion_steps'] = component_erosion_steps
        train_datasets.append(dataset)

# fib25 = dict(
#     name="FIB-25 train",
#     data=DVIDDataInstance("slowpoke3", 32788, "213", "grayscale"),
#     components=DVIDDataInstance("slowpoke3", 32788, "213", "groundtruth_pruned"),
#     image_scaling_factor=1.0 / (2.0 ** 8),
#     component_erosion_steps=component_erosion_steps,
#     bounding_box=((2000, 5006), (1000, 5000), (2000, 6000)),  # train region
#     dvid_body_names_to_exclude=dvid_body_names_to_exclude,
# )
# train_datasets.extend([fib25] * 8)

for dataset in train_datasets:
    dataset['nhood'] = malis.mknhood3d().astype(int)
    dataset['mask_threshold'] = mask_threshold
    dataset['mask_dilation_steps'] = mask_dilation_steps
    dataset['minimum_component_size'] = minimum_component_size
    dataset['simple_augment'] = simple_augmenting
    dataset['transform'] = {}
    dataset['transform']['scale'] = (0.9, 1.1)
    dataset['transform']['shift'] = (-0.1, 0.1)

print('Training set contains', len(train_datasets), 'volumes:',
      [dataset['name'] for dataset in train_datasets], "with dtype/shapes",
      [(array.dtype, array.shape)
       for dataset in train_datasets
       for array in [dataset[key] for key in ("data", "components")]])

## Testing datasets
Example #17
def train(max_iteration, gpu, voxel_size):

    # get most recent training result
    solverstates = [
        int(f.split('.')[0].split('_')[-1])
        for f in glob.glob('net_iter_*.solverstate')
    ]
    if len(solverstates) > 0:
        trained_until = max(solverstates)
        print("Resuming training from iteration " + str(trained_until))
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until < phase_switch and max_iteration > phase_switch:
        # phase switch lies in between, split training into two parts
        train(max_iteration=phase_switch, gpu=gpu, voxel_size=voxel_size)
        trained_until = phase_switch

    # switch from euclidean to malis after "phase_switch" iterations
    if max_iteration <= phase_switch:
        phase = 'euclid'
    else:
        phase = 'malis'
    print("Training until " + str(max_iteration) + " in phase " + phase)

    # define request
    request = BatchRequest()
    shape_input = (132, 132, 132) * np.asarray(voxel_size)
    shape_output = (44, 44, 44) * np.asarray(voxel_size)
    request.add_volume_request(VolumeTypes.RAW, shape_input)
    request.add_volume_request(VolumeTypes.GT_LABELS, shape_output)
    request.add_volume_request(VolumeTypes.GT_AFFINITIES, shape_output)
    if phase == 'malis':
        request.add_volume_request(VolumeTypes.MALIS_COMP_LABEL, shape_output)
    request.add_volume_request(VolumeTypes.LOSS_SCALE, shape_output)
    request.add_volume_request(VolumeTypes.PRED_AFFINITIES, shape_output)
    request.add_points_request(PointsTypes.PRESYN, shape_output)
    request.add_points_request(PointsTypes.POSTSYN, shape_output)
    request.add_volume_request(VolumeTypes.GT_BM_PRESYN, shape_output)
    request.add_volume_request(VolumeTypes.GT_BM_POSTSYN, shape_output)
    request.add_volume_request(VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN,
                               shape_output)
    request.add_volume_request(VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN,
                               shape_output)
    request.add_volume_request(VolumeTypes.LOSS_SCALE_BM_PRESYN, shape_output)
    request.add_volume_request(VolumeTypes.LOSS_SCALE_BM_POSTSYN, shape_output)
    request.add_volume_request(VolumeTypes.PRED_BM_PRESYN, shape_output)
    request.add_volume_request(VolumeTypes.PRED_BM_POSTSYN, shape_output)

    # define settings for the binary mask (rasterization of synapse points to a binary mask)
    volumetypes_to_pointstype = {
        VolumeTypes.GT_BM_PRESYN: PointsTypes.PRESYN,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN: PointsTypes.PRESYN,
        VolumeTypes.GT_BM_POSTSYN: PointsTypes.POSTSYN,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN: PointsTypes.POSTSYN
    }
    rastersetting_bm = RasterizationSetting(marker_size_physical=32)
    rastersetting_mask_syn = RasterizationSetting(marker_size_physical=96,
                                                  donut_inner_radius=32,
                                                  invert_map=True)
    volumetype_to_rastersettings = {
        VolumeTypes.GT_BM_PRESYN: rastersetting_bm,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN: rastersetting_mask_syn,
        VolumeTypes.GT_BM_POSTSYN: rastersetting_bm,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN: rastersetting_mask_syn,
    }

    # define names of datasets in snapshot file
    snapshot_dataset_names = {
        VolumeTypes.RAW: 'volumes/raw',
        VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
        VolumeTypes.GT_AFFINITIES: 'volumes/labels/affs',
        VolumeTypes.PRED_AFFINITIES: 'volumes/predicted_affs',
        VolumeTypes.LOSS_SCALE: 'volumes/loss_scale',
        VolumeTypes.LOSS_GRADIENT: 'volumes/predicted_affs_loss_gradient',
        VolumeTypes.GT_BM_PRESYN: 'volumes/labels/gt_bm_presyn',
        VolumeTypes.GT_BM_POSTSYN: 'volumes/labels/gt_bm_postsyn',
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN:
        'volumes/labels/gt_mask_exclusivezone_presyn',
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN:
        'volumes/labels/gt_mask_exclusivezone_postsyn',
        VolumeTypes.PRED_BM_PRESYN: 'volumes/predicted_bm_presyn',
        VolumeTypes.PRED_BM_POSTSYN: 'volumes/predicted_bm_postsyn',
        VolumeTypes.LOSS_SCALE_BM_PRESYN: 'volumes/loss_scale_presyn',
        VolumeTypes.LOSS_SCALE_BM_POSTSYN: 'volumes/loss_scale_postsyn',
    }

    # set network inputs, outputs and gradients
    train_inputs = {
        VolumeTypes.RAW: 'data',
        VolumeTypes.GT_AFFINITIES: 'aff_label',
        VolumeTypes.GT_BM_PRESYN: 'bm_presyn_label',
        VolumeTypes.GT_BM_POSTSYN: 'bm_postsyn_label',
        VolumeTypes.LOSS_SCALE: 'segm_scale',
        VolumeTypes.LOSS_SCALE_BM_PRESYN: 'bm_presyn_scale',
        VolumeTypes.LOSS_SCALE_BM_POSTSYN: 'bm_postsyn_scale',
    }
    if phase == 'malis':
        train_inputs[VolumeTypes.MALIS_COMP_LABEL] = 'comp_label'
        train_inputs['affinity_neighborhood'] = 'nhood'
    train_outputs = {
        VolumeTypes.PRED_BM_PRESYN: 'bm_presyn_pred',
        VolumeTypes.PRED_BM_POSTSYN: 'bm_postsyn_pred',
        VolumeTypes.PRED_AFFINITIES: 'aff_pred',
    }
    train_gradients = {
        VolumeTypes.LOSS_GRADIENT_PRESYN: 'bm_presyn_pred',
        VolumeTypes.LOSS_GRADIENT_POSTSYN: 'bm_postsyn_pred',
        VolumeTypes.LOSS_GRADIENT: 'aff_pred',
    }

    # set solver parameters
    solver_parameters = SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 1e-4
    solver_parameters.momentum = 0.99
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 10000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    if trained_until > 0:
        solver_parameters.resume_from = 'net_iter_' + str(
            trained_until) + '.solverstate'
    else:
        solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage(phase)

    # set source of data
    data_sources = tuple(
        DvidSource(
            hostname='slowpoke2',
            port=8000,
            uuid='cb7dc',
            volume_array_names={
                VolumeTypes.RAW: 'grayscale',
                VolumeTypes.GT_LABELS: 'labels'
            },
            points_array_names={
                PointsTypes.PRESYN: 'combined_synapses_08302016',
                PointsTypes.POSTSYN: 'combined_synapses_08302016'
            },
            points_rois={
                PointsTypes.PRESYN:
                Roi(offset=(76000, 20000, 64000), shape=(4000, 4000, 16000)),
                PointsTypes.POSTSYN:
                Roi(offset=(76000, 20000, 64000), shape=(4000, 4000, 16000))
            },
            points_voxel_size={
                PointsTypes.PRESYN: voxel_size,
                PointsTypes.POSTSYN: voxel_size
            }) + RandomLocation(focus_points_type=focus_points_type) +
        Normalize()
        for focus_points_type in (3 * [PointsTypes.PRESYN] + 2 * [None]))

    # define pipeline to process batches
    batch_provider_tree = (
        data_sources + RandomProvider() + RasterizePoints(
            volumetypes_to_pointstype=volumetypes_to_pointstype,
            volumetypes_to_rastersettings=volumetype_to_rastersettings) +
        ElasticAugment(control_point_spacing=[40, 40, 40],
                       jitter_sigma=[2, 2, 2],
                       rotation_interval=[0, math.pi / 2.0],
                       prob_slip=0.01,
                       prob_shift=0.01,
                       max_misalign=1,
                       subsample=8) + SimpleAugment(transpose_only_xy=True) +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        GrowBoundary(steps=3, only_xy=True) +
        AddGtAffinities(malis.mknhood3d()) +
        SplitAndRenumberSegmentationLabels() + IntensityScaleShift(2, -1) +
        ZeroOutConstSections() + PrepareMalis() +
        BalanceLabels(labels_to_loss_scale_volume={
            VolumeTypes.GT_BM_PRESYN: VolumeTypes.LOSS_SCALE_BM_PRESYN,
            VolumeTypes.GT_BM_POSTSYN: VolumeTypes.LOSS_SCALE_BM_POSTSYN,
            VolumeTypes.GT_LABELS: VolumeTypes.LOSS_SCALE
        },
                      labels_to_mask_volumes={
                          VolumeTypes.GT_BM_PRESYN:
                          [VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN],
                          VolumeTypes.GT_BM_POSTSYN:
                          [VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN],
                      }) + PreCache(cache_size=40, num_workers=10) +
        Train(solver_parameters,
              inputs=train_inputs,
              outputs=train_outputs,
              gradients=train_gradients,
              use_gpu=gpu) + Snapshot(dataset_names=snapshot_dataset_names,
                                      every=5000,
                                      output_filename='batch_{id}.hdf',
                                      compression_type="gzip"))

    print("Training for", max_iteration, "iterations")
    with build(batch_provider_tree) as minibatch_maker:
        for i in range(max_iteration):
            minibatch_maker.request_batch(request)
    print("Finished")
Example #18
def train():

    gunpowder.set_verbose(False)

    affinity_neighborhood = malis.mknhood3d()
    solver_parameters = gunpowder.caffe.SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 1e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 10000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage('euclid')

    request = BatchRequest()
    request.add_volume_request(VolumeTypes.RAW, constants.input_shape)
    request.add_volume_request(VolumeTypes.GT_LABELS, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_MASK, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_AFFINITIES, constants.output_shape)
    request.add_volume_request(VolumeTypes.LOSS_SCALE, constants.output_shape)

    data_providers = list()
    fibsem_dir = "/groups/turaga/turagalab/data/FlyEM/fibsem_medulla_7col"
    for volume_name in ("tstvol-520-1-h5",):
        h5_filepath = "./{}.h5".format(volume_name)
        path_to_labels = os.path.join(fibsem_dir, volume_name, "groundtruth_seg.h5")
        with h5py.File(path_to_labels, "r") as f_labels:
            mask_shape = f_labels["main"].shape
        with h5py.File(h5_filepath, "w") as h5:
            h5['volumes/raw'] = h5py.ExternalLink(os.path.join(fibsem_dir, volume_name, "im_uint8.h5"), "main")
            h5['volumes/labels/neuron_ids'] = h5py.ExternalLink(path_to_labels, "main")
            h5.create_dataset(
                name="volumes/labels/mask",
                dtype="uint8",
                shape=mask_shape,
                fillvalue=1,
            )
        data_providers.append(
            gunpowder.Hdf5Source(
                h5_filepath,
                datasets={
                    VolumeTypes.RAW: 'volumes/raw',
                    VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                    VolumeTypes.GT_MASK: 'volumes/labels/mask',
                },
                resolution=(8, 8, 8),
            )
        )
    dvid_source = DvidSource(
        hostname='slowpoke3',
        port=32788,
        uuid='341',
        raw_array_name='grayscale',
        gt_array_name='groundtruth',
        gt_mask_roi_name="seven_column_eroded7_z_lt_5024",
        resolution=(8, 8, 8),
    )
    data_providers.extend([dvid_source])
    data_providers = tuple(
        provider +
        RandomLocation() +
        Reject(min_masked=0.5) +
        Normalize()
        for provider in data_providers
    )

    # create a batch provider by concatenation of filters
    batch_provider = (
        data_providers +
        RandomProvider() +
        ElasticAugment([20, 20, 20], [0, 0, 0], [0, math.pi / 2.0]) +
        SimpleAugment(transpose_only_xy=False) +
        GrowBoundary(steps=2, only_xy=False) +
        AddGtAffinities(affinity_neighborhood) +
        BalanceAffinityLabels() +
        SplitAndRenumberSegmentationLabels() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        PreCache(
            request,
            cache_size=11,
            num_workers=10) +
        Train(solver_parameters, use_gpu=0) +
        Typecast(volume_dtypes={
            VolumeTypes.GT_LABELS: np.dtype("uint32"),
            VolumeTypes.GT_MASK: np.dtype("uint8"),
            VolumeTypes.LOSS_SCALE: np.dtype("float32"),
        }, safe=True) +
        Snapshot(every=50, output_filename='batch_{id}.hdf') +
        PrintProfilingStats(every=50)
    )

    n = 500000
    print("Training for", n, "iterations")

    with gunpowder.build(batch_provider) as pipeline:
        for i in range(n):
            pipeline.request_batch(request)

    print("Finished")
Example #19
def train_until(max_iteration, data_sources):
    data_providers = []
    fib25_dir = "/groups/saalfeld/home/funkej/workspace/projects/caffe/run/fib25/01_data/train"
    if 'fib25h5' in data_sources:

        for volume_name in ("tstvol-520-1", "tstvol-520-2", "trvol-250-1",
                            "trvol-250-2"):
            h5_source = Hdf5Source(os.path.join(fib25_dir,
                                                volume_name + '.hdf'),
                                   datasets={
                                       VolumeTypes.RAW: 'volumes/raw',
                                       VolumeTypes.GT_LABELS:
                                       'volumes/labels/neuron_ids',
                                       VolumeTypes.GT_MASK:
                                       'volumes/labels/mask',
                                   },
                                   volume_specs={
                                       VolumeTypes.GT_MASK:
                                       VolumeSpec(interpolatable=False)
                                   })
            data_providers.append(h5_source)

    fib19_dir = "/groups/saalfeld/saalfeldlab/larissa/fib19"
    if 'fib19h5' in data_sources:
        for volume_name in ("trvol-250", "trvol-600"):
            h5_source = prepare_h5source(fib19_dir, volume_name)
            data_providers.append(h5_source)

    #todo: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    register_volume_type('RAW')
    register_volume_type('ALPHA_MASK')
    register_volume_type('GT_LABELS')
    register_volume_type('GT_MASK')
    register_volume_type('GT_SCALE')
    register_volume_type('GT_AFFINITIES')
    register_volume_type('PREDICTED_AFFS')
    register_volume_type('LOSS_GRADIENT')

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate((132, ) * 3) * voxel_size
    output_size = Coordinate((44, ) * 3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(VolumeTypes.RAW, input_size)
    request.add(VolumeTypes.GT_LABELS, output_size)
    request.add(VolumeTypes.GT_MASK, output_size)
    request.add(VolumeTypes.GT_SCALE, output_size)
    request.add(VolumeTypes.GT_AFFINITIES, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider + Normalize() +  # ensures RAW is float in [0, 1]

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size is more or less irrelevant since a Reject node follows
        Pad({
            VolumeTypes.RAW: Coordinate((100, 100, 100)) * voxel_size,
            VolumeTypes.GT_MASK: Coordinate((100, 100, 100)) * voxel_size
        }) + RandomLocation()
        +  # choose a random location inside the provided volumes
        Reject()  # reject batches which contain less than 50% labelled data
        for provider in data_providers)

    snapshot_request = BatchRequest({
        VolumeTypes.PREDICTED_AFFS:
        request[VolumeTypes.GT_LABELS],
        VolumeTypes.LOSS_GRADIENT:
        request[VolumeTypes.GT_AFFINITIES]
    })

    #artifact_source = (
    #    Hdf5Source(
    #        os.path.join(data_dir, 'sample_ABC_padded_20160501.defects.hdf'),
    #        datasets = {
    #            VolumeTypes.RAW: 'defect_sections/raw',
    #            VolumeTypes.ALPHA_MASK: 'defect_sections/mask',
    #        },
    #        volume_specs = {
    #            VolumeTypes.RAW: VolumeSpec(voxel_size=(40, 4, 4)),
    #            VolumeTypes.ALPHA_MASK: VolumeSpec(voxel_size=(40, 4, 4)),
    #        }
    #    ) +
    #    RandomLocation(min_masked=0.05, mask_volume_type=VolumeTypes.ALPHA_MASK) +
    #    Normalize() +
    #    IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
    #    ElasticAugment([4,40,40], [0,2,2], [0,math.pi/2.0], subsample=8) +
    #    SimpleAugment(transpose_only_xy=True)
    #)

    train_pipeline = (
        data_sources + RandomProvider() +
        ElasticAugment([40, 40, 40], [2, 2, 2], [0, math.pi / 2.0],
                       prob_slip=0.01,
                       prob_shift=0.05,
                       max_misalign=1,
                       subsample=8) + SimpleAugment() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1) + IntensityScaleShift(2, -1) +
        ZeroOutConstSections() + GrowBoundary(steps=2) +
        SplitAndRenumberSegmentationLabels() +
        AddGtAffinities(malis.mknhood3d()) +
        BalanceLabels({VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_SCALE},
                      {VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_MASK}) +
        PreCache(cache_size=40, num_workers=10) +
        #DefectAugment(
        #    prob_missing=0.03,
        #    prob_low_contrast=0.01,
        #    prob_artifact=0.03,
        #    artifact_source=artifact_source,
        #    contrast_scale=0.5) +
        Train('unet_adapt',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              inputs={
                  net_io_names['raw']: VolumeTypes.RAW,
                  net_io_names['gt_affs']: VolumeTypes.GT_AFFINITIES,
                  net_io_names['loss_weights']: VolumeTypes.GT_SCALE
              },
              outputs={net_io_names['affs']: VolumeTypes.PREDICTED_AFFS},
              gradients={net_io_names['affs']: VolumeTypes.LOSS_GRADIENT}) +
        Snapshot(
            {
                VolumeTypes.RAW: 'volumes/raw',
                VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                VolumeTypes.GT_AFFINITIES: 'volumes/labels/affinities',
                VolumeTypes.PREDICTED_AFFS: 'volumes/labels/pred_affinities',
                VolumeTypes.LOSS_GRADIENT: 'volumes/loss_gradient',
            },
            every=1000,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=5000))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
Example #20
def train_until(max_iteration, data_sources):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("GT_AFFINITIES")
    ArrayKey("PREDICTED_AFFS")
    ArrayKey("LOSS_GRADIENT")

    data_providers = []
    fib25_dir = "/groups/saalfeld/saalfeldlab/larissa/data/gunpowder/fib25/"
    if "fib25h5" in data_sources:

        for volume_name in (
            "tstvol-520-1",
            "tstvol-520-2",
            "trvol-250-1",
            "trvol-250-2",
        ):
            h5_source = Hdf5Source(
                os.path.join(fib25_dir, volume_name + ".hdf"),
                datasets={
                    ArrayKeys.RAW: "volumes/raw",
                    ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                    ArrayKeys.GT_MASK: "volumes/labels/mask",
                },
                array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
            )
            data_providers.append(h5_source)

    fib19_dir = "/groups/saalfeld/saalfeldlab/larissa/fib19"
    # if 'fib19h5' in data_sources:
    #    for volume_name in ("trvol-250", "trvol-600"):
    #        h5_source = prepare_h5source(fib19_dir, volume_name)
    #        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate((132,) * 3) * voxel_size
    output_size = Coordinate((44,) * 3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, input_size)
    request.add(ArrayKeys.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_AFFINITIES, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider + Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size is more or less irrelevant since a Reject node follows
        Pad(ArrayKeys.RAW, None)
        + Pad(ArrayKeys.GT_MASK, None)
        + RandomLocation()  # choose a random location inside the provided volumes
        + Reject(ArrayKeys.GT_MASK)  # reject batches which contain less than 50% labelled data
        for provider in data_providers
    )

    snapshot_request = BatchRequest(
        {
            ArrayKeys.PREDICTED_AFFS: request[ArrayKeys.GT_LABELS],
            ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_AFFINITIES],
        }
    )

    # artifact_source = (
    #    Hdf5Source(
    #        os.path.join(data_dir, 'sample_ABC_padded_20160501.defects.hdf'),
    #        datasets = {
    #            VolumeTypes.RAW: 'defect_sections/raw',
    #            VolumeTypes.ALPHA_MASK: 'defect_sections/mask',
    #        },
    #        volume_specs = {
    #            VolumeTypes.RAW: VolumeSpec(voxel_size=(40, 4, 4)),
    #            VolumeTypes.ALPHA_MASK: VolumeSpec(voxel_size=(40, 4, 4)),
    #        }
    #    ) +
    #    RandomLocation(min_masked=0.05, mask_volume_type=VolumeTypes.ALPHA_MASK) +
    #    Normalize() +
    #    IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
    #    ElasticAugment([4,40,40], [0,2,2], [0,math.pi/2.0], subsample=8) +
    #    SimpleAugment(transpose_only_xy=True)
    # )

    train_pipeline = (
        data_sources
        + RandomProvider()
        + ElasticAugment(
            [40, 40, 40],
            [2, 2, 2],
            [0, math.pi / 2.0],
            prob_slip=0.01,
            prob_shift=0.05,
            max_misalign=1,
            subsample=8,
        )
        + SimpleAugment()
        + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1)
        + IntensityScaleShift(ArrayKeys.RAW, 2, -1)
        + ZeroOutConstSections(ArrayKeys.RAW)
        + GrowBoundary(ArrayKeys.GT_LABELS, mask=ArrayKeys.GT_MASK, steps=2)
        + RenumberConnectedComponents(ArrayKeys.GT_LABELS)
        + AddAffinities(malis.mknhood3d(), ArrayKeys.GT_LABELS, ArrayKeys.GT_AFFINITIES)
        + BalanceLabels(ArrayKeys.GT_AFFINITIES, ArrayKeys.GT_SCALE)
        + PreCache(cache_size=40, num_workers=10)
        +
        # DefectAugment(
        #    prob_missing=0.03,
        #    prob_low_contrast=0.01,
        #    prob_artifact=0.03,
        #    artifact_source=artifact_source,
        #    contrast_scale=0.5) +
        Train(
            "unet_auto",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names["loss"],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["pred"]: ArrayKeys.GT_MASK,
                net_io_names["gt_affs"]: ArrayKeys.GT_AFFINITIES,
                net_io_names["loss_weights"]: ArrayKeys.GT_SCALE,
            },
            outputs={net_io_names["affs"]: ArrayKeys.PREDICTED_AFFS},
            gradients={net_io_names["affs"]: ArrayKeys.LOSS_GRADIENT},
        )
        + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.GT_AFFINITIES: "volumes/labels/affinities",
                ArrayKeys.PREDICTED_AFFS: "volumes/labels/pred_affinities",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=1000,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=5000)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
Example #21
 def __init__(self, shape, weights, schedule=None, counter=0):
     self.nhood = malis.mknhood3d()
     self.shape = shape
Example #22
def train_until(max_iteration):

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    register_volume_type('RAW')
    register_volume_type('ALPHA_MASK')
    register_volume_type('GT_LABELS')
    register_volume_type('GT_MASK')
    register_volume_type('GT_SCALE')
    register_volume_type('GT_AFFINITIES')
    register_volume_type('PREDICTED_AFFS')
    register_volume_type('LOSS_GRADIENT')

    input_size = Coordinate((84, 268, 268)) * (40, 4, 4)
    output_size = Coordinate((56, 56, 56)) * (40, 4, 4)

    request = BatchRequest()
    request.add(VolumeTypes.RAW, input_size)
    request.add(VolumeTypes.GT_LABELS, output_size)
    request.add(VolumeTypes.GT_MASK, output_size)
    request.add(VolumeTypes.GT_SCALE, output_size)
    request.add(VolumeTypes.GT_AFFINITIES, output_size)

    snapshot_request = BatchRequest({
        VolumeTypes.PREDICTED_AFFS:
        request[VolumeTypes.GT_AFFINITIES],
        VolumeTypes.LOSS_GRADIENT:
        request[VolumeTypes.GT_AFFINITIES]
    })

    data_sources = tuple(
        Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                   datasets={
                       VolumeTypes.RAW: 'volumes/raw',
                       VolumeTypes.GT_LABELS:
                       'volumes/labels/neuron_ids_notransparency',
                       VolumeTypes.GT_MASK: 'volumes/labels/mask',
                   }) + Normalize() + RandomLocation() + Reject()
        for sample in samples)

    artifact_source = (
        Hdf5Source(
            os.path.join(data_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                VolumeTypes.RAW: 'defect_sections/raw',
                VolumeTypes.ALPHA_MASK: 'defect_sections/mask',
            },
            volume_specs={
                VolumeTypes.RAW: VolumeSpec(voxel_size=(40, 4, 4)),
                VolumeTypes.ALPHA_MASK: VolumeSpec(voxel_size=(40, 4, 4)),
            }) + RandomLocation(min_masked=0.05,
                                mask_volume_type=VolumeTypes.ALPHA_MASK) +
        Normalize() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment([4, 40, 40], [0, 2, 2], [0, math.pi / 2.0],
                       subsample=8) + SimpleAugment(transpose_only_xy=True))

    train_pipeline = (
        data_sources + RandomProvider() +
        ElasticAugment([4, 40, 40], [0, 2, 2], [0, math.pi / 2.0],
                       prob_slip=0.05,
                       prob_shift=0.05,
                       max_misalign=10,
                       subsample=8) + SimpleAugment(transpose_only_xy=True) +
        GrowBoundary(steps=4, only_xy=True) +
        SplitAndRenumberSegmentationLabels() +
        AddGtAffinities(malis.mknhood3d()) +
        BalanceLabels({VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_SCALE},
                      {VolumeTypes.GT_AFFINITIES: VolumeTypes.GT_MASK}) +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      contrast_scale=0.5) + IntensityScaleShift(2, -1) +
        ZeroOutConstSections() + PreCache(cache_size=40, num_workers=10) +
        Train('unet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              inputs={
                  net_io_names['raw']: VolumeTypes.RAW,
                  net_io_names['gt_affs']: VolumeTypes.GT_AFFINITIES,
                  net_io_names['loss_weights']: VolumeTypes.GT_SCALE
              },
              outputs={net_io_names['affs']: VolumeTypes.PREDICTED_AFFS},
              gradients={net_io_names['affs']: VolumeTypes.LOSS_GRADIENT}) +
        Snapshot(
            {
                VolumeTypes.RAW: 'volumes/raw',
                VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                VolumeTypes.GT_AFFINITIES: 'volumes/labels/affinities',
                VolumeTypes.PREDICTED_AFFS: 'volumes/labels/pred_affinities',
                VolumeTypes.LOSS_GRADIENT: 'volumes/loss_gradient',
            },
            every=100,
            output_filename='batch_{iteration}.hdf',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=10))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
Example #23
def train_until(max_iteration):

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    register_volume_type('RAW')
    register_volume_type('GT_LABELS')
    register_volume_type('GT_MASK')
    register_volume_type('GT_AFFINITIES')
    register_volume_type('GT_AFFINITIES_MASK')
    register_volume_type('GT_AFFINITIES_SCALE')
    register_volume_type('PREDICTED_AFFS')
    register_volume_type('LOSS_GRADIENT')

    input_size = Coordinate((29, 188, 188))
    output_size = Coordinate((1, 100, 100))

    request = BatchRequest()
    request.add(VolumeTypes.RAW, input_size)
    request.add(VolumeTypes.GT_LABELS, output_size)
    request.add(VolumeTypes.GT_MASK, output_size)
    request.add(VolumeTypes.GT_AFFINITIES, output_size)
    request.add(VolumeTypes.GT_AFFINITIES_MASK, output_size)
    request.add(VolumeTypes.GT_AFFINITIES_SCALE, output_size)

    snapshot_request = BatchRequest({
        VolumeTypes.PREDICTED_AFFS:
        request[VolumeTypes.GT_AFFINITIES],
        VolumeTypes.LOSS_GRADIENT:
        request[VolumeTypes.GT_AFFINITIES]
    })

    data_sources = tuple(
        Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                   datasets={
                       VolumeTypes.RAW: 'volumes/raw',
                       VolumeTypes.GT_LABELS: 'volumes/labels/lineages',
                       VolumeTypes.GT_MASK: 'volumes/labels/ignore',
                   },
                   volume_specs={
                       VolumeTypes.GT_MASK: VolumeSpec(interpolatable=False)
                   }) + IgnoreToMask() + Normalize() + RandomLocation() +
        Reject() for sample in samples)

    train_pipeline = (
        data_sources + RandomProvider() +
        SplitAndRenumberSegmentationLabels() +
        GrowBoundary(steps=1, only_xy=True) +
        ElasticAugment([1, 10, 10], [0, 1, 1], [0, math.pi / 2.0],
                       prob_slip=0.05,
                       prob_shift=0.05,
                       max_misalign=3,
                       subsample=8) + SimpleAugment(transpose_only_xy=True) +
        AddGtAffinities(malis.mknhood3d(), gt_labels_mask=VolumeTypes.GT_MASK)
        + BalanceLabels(labels=VolumeTypes.GT_AFFINITIES,
                        scales=VolumeTypes.GT_AFFINITIES_SCALE,
                        mask=VolumeTypes.GT_AFFINITIES_MASK) +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        IntensityScaleShift(2, -1) + PreCache(cache_size=40, num_workers=10) +
        Train('unet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              inputs={
                  net_io_names['raw']: VolumeTypes.RAW,
                  net_io_names['gt_affs']: VolumeTypes.GT_AFFINITIES,
                  net_io_names['loss_weights']:
                  VolumeTypes.GT_AFFINITIES_SCALE,
              },
              outputs={net_io_names['affs']: VolumeTypes.PREDICTED_AFFS},
              gradients={net_io_names['affs']: VolumeTypes.LOSS_GRADIENT}) +
        IntensityScaleShift(0.5, 0.5) + Snapshot(
            {
                VolumeTypes.RAW: 'volumes/raw',
                VolumeTypes.GT_LABELS: 'volumes/labels/lineages',
                VolumeTypes.GT_MASK: 'volumes/labels/mask',
                VolumeTypes.GT_AFFINITIES: 'volumes/labels/affinities',
                VolumeTypes.PREDICTED_AFFS: 'volumes/labels/pred_affinities',
                VolumeTypes.LOSS_GRADIENT: 'volumes/loss_gradient',
            },
            every=100,
            output_filename='batch_{iteration}.hdf',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=10))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration - trained_until):
            b.request_batch(request)
    print("Training finished")
Example #24
def train_until(max_iteration, data_sources):

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    ArrayKey('RAW')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_MASK')
    ArrayKey('GT_AFFINITIES')
    ArrayKey('GT_SCALE')
    ArrayKey('PREDICTED_AFFS_1')
    ArrayKey('PREDICTED_AFFS_2')
    ArrayKey('LOSS_GRADIENT_1')
    ArrayKey('LOSS_GRADIENT_2')

    data_providers = []
    fib25_dir = "/groups/saalfeld/saalfeldlab/larissa/data/gunpowder/fib25/"
    if 'fib25h5' in data_sources:

        for volume_name in ("tstvol-520-1", "tstvol-520-2", "trvol-250-1",
                            "trvol-250-2"):
            h5_source = Hdf5Source(os.path.join(fib25_dir,
                                                volume_name + '.hdf'),
                                   datasets={
                                       ArrayKeys.RAW: 'volumes/raw',
                                       ArrayKeys.GT_LABELS:
                                       'volumes/labels/neuron_ids',
                                       ArrayKeys.GT_MASK:
                                       'volumes/labels/mask',
                                   },
                                   array_specs={
                                       ArrayKeys.GT_MASK:
                                       ArraySpec(interpolatable=False)
                                   })
            data_providers.append(h5_source)

    fib19_dir = "/groups/saalfeld/saalfeldlab/larissa/fib19"
    #   if 'fib19h5' in data_sources:
    #       for volume_name in ("trvol-250", "trvol-600"):
    #           h5_source = prepare_h5source(fib19_dir, volume_name)
    #           data_providers.append(h5_source)

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate((220, ) * 3) * voxel_size
    output_1_size = Coordinate((132, ) * 3) * voxel_size
    output_2_size = Coordinate((44, ) * 3) * voxel_size

    #input_size = Coordinate((66, 228, 228))*(40,4,4)
    #output_1_size = Coordinate((38, 140, 140))*(40,4,4)
    #output_2_size = Coordinate((10, 52, 52))*(40,4,4)

    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_1_size)
    request.add(ArrayKeys.GT_MASK, output_1_size)
    request.add(ArrayKeys.GT_AFFINITIES, output_1_size)
    request.add(ArrayKeys.GT_SCALE, output_1_size)

    snapshot_request = BatchRequest()
    snapshot_request.add(ArrayKeys.RAW,
                         input_size)  # just to center the rest correctly
    snapshot_request.add(ArrayKeys.PREDICTED_AFFS_1, output_1_size)
    snapshot_request.add(ArrayKeys.PREDICTED_AFFS_2, output_2_size)
    snapshot_request.add(ArrayKeys.LOSS_GRADIENT_1, output_1_size)
    snapshot_request.add(ArrayKeys.LOSS_GRADIENT_2, output_2_size)

    data_sources = tuple(
        provider + Normalize(ArrayKeys.RAW) + Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) + RandomLocation() +
        Reject(ArrayKeys.GT_MASK) for provider in data_providers)

    train_pipeline = (
        data_sources + RandomProvider() +
        ElasticAugment([40, 40, 40], [2, 2, 2], [0, math.pi / 2.0],
                       prob_slip=0.01,
                       prob_shift=0.05,
                       max_misalign=1,
                       subsample=8) + SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, mask=ArrayKeys.GT_MASK, steps=2) +
        RenumberConnectedComponents(ArrayKeys.GT_LABELS) + AddAffinities(
            malis.mknhood3d(), ArrayKeys.GT_LABELS, ArrayKeys.GT_AFFINITIES) +
        BalanceLabels(ArrayKeys.GT_AFFINITIES, ArrayKeys.GT_SCALE) +
        PreCache(cache_size=40, num_workers=10) +
        Train('wnet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              summary=net_io_names['summary'],
              inputs={
                  net_io_names['raw']: ArrayKeys.RAW,
                  net_io_names['gt_affs']: ArrayKeys.GT_AFFINITIES,
                  net_io_names['loss_weights']: ArrayKeys.GT_SCALE,
              },
              outputs={
                  net_io_names['affs_1']: ArrayKeys.PREDICTED_AFFS_1,
                  net_io_names['affs_2']: ArrayKeys.PREDICTED_AFFS_2
              },
              gradients={
                  net_io_names['affs_1']: ArrayKeys.LOSS_GRADIENT_1,
                  net_io_names['affs_2']: ArrayKeys.LOSS_GRADIENT_2
              }) + IntensityScaleShift(ArrayKeys.RAW, 0.5, 0.5) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids',
                ArrayKeys.GT_AFFINITIES: 'volumes/labels/affinities',
                ArrayKeys.PREDICTED_AFFS_1: 'volumes/labels/pred_affinities_1',
                ArrayKeys.PREDICTED_AFFS_2: 'volumes/labels/pred_affinities_2',
                ArrayKeys.LOSS_GRADIENT_1: 'volumes/loss_gradient_1',
                ArrayKeys.LOSS_GRADIENT_2: 'volumes/loss_gradient_2',
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=1000))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration - trained_until):
            b.request_batch(request)
    print("Training finished")
Beispiel #25
0
training_hist = model.fit(data,
                          gt,
                          batch_size=3,
                          nb_epoch=n_epochs,
                          verbose=0)
plt.figure()
plt.plot(training_hist.history['loss'])
plt.xlabel("epochs")
plt.ylabel("training loss")



# predict an affinity graph and compare it with the affinity graph
# created by the true segmentation
plot_sample = 1
from malis import mknhood3d, seg_to_affgraph
pred_aff = model.predict(data)[plot_sample]
aff = seg_to_affgraph(gt[plot_sample,0], mknhood3d())
plt.figure()
plt.subplot(131)
plt.pcolor(data[plot_sample,0,1], cmap="gray")
plt.title("data")
plt.subplot(132)
plt.pcolor(aff[1,1], cmap="gray")
plt.title("aff from gt")
plt.subplot(133)
plt.pcolor(pred_aff[1,1], cmap="gray")
plt.title("predicted aff")

plt.show()
Beispiel #26
0
def predict(roi_synapses, voxel_size):

    # get network and weights to use for inference
    prototxt = 'net.prototxt'
    weights = 'net_iter_200000.caffemodel'

    # create template for chunk
    chunk_spec_template = BatchRequest()
    shape_input_template = [132, 132, 132] * np.asarray(voxel_size)
    shape_output_template = [44, 44, 44] * np.asarray(voxel_size)
    chunk_spec_template.add_volume_request(VolumeTypes.RAW,
                                           shape_input_template)
    chunk_spec_template.add_volume_request(VolumeTypes.GT_LABELS,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.GT_AFFINITIES,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.PRED_AFFINITIES,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.GT_BM_PRESYN,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN, shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.PRED_BM_PRESYN,
                                           shape_output_template)
    chunk_spec_template.add_points_request(PointsTypes.PRESYN,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.GT_BM_POSTSYN,
                                           shape_output_template)
    chunk_spec_template.add_volume_request(
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN, shape_output_template)
    chunk_spec_template.add_volume_request(VolumeTypes.PRED_BM_POSTSYN,
                                           shape_output_template)
    chunk_spec_template.add_points_request(PointsTypes.POSTSYN,
                                           shape_output_template)

    # create the full batch request: shapes and types of all requested volumes and points
    request = BatchRequest()
    shape_outputs = roi_synapses.get_shape()
    request.add_volume_request(VolumeTypes.RAW, shape_outputs)
    request.add_volume_request(VolumeTypes.GT_LABELS, shape_outputs)
    request.add_volume_request(VolumeTypes.GT_AFFINITIES, shape_outputs)
    request.add_volume_request(VolumeTypes.PRED_AFFINITIES, shape_outputs)
    request.add_volume_request(VolumeTypes.GT_BM_PRESYN, shape_outputs)
    request.add_volume_request(VolumeTypes.PRED_BM_PRESYN, shape_outputs)
    request.add_points_request(PointsTypes.PRESYN, shape_outputs)
    request.add_volume_request(VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN,
                               shape_outputs)
    request.add_volume_request(VolumeTypes.GT_BM_POSTSYN, shape_outputs)
    request.add_volume_request(VolumeTypes.PRED_BM_POSTSYN, shape_outputs)
    request.add_points_request(PointsTypes.POSTSYN, shape_outputs)
    request.add_volume_request(VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN,
                               shape_outputs)

    # shift request roi to correct offset
    request_offset = roi_synapses.get_offset()
    for request_type in [request.volumes, request.points]:
        for type in request_type:
            request_type[type] += request_offset

    # set network inputs, outputs and resolutions of output volumes
    net_inputs = {VolumeTypes.RAW: 'data'}
    net_outputs = {
        VolumeTypes.PRED_BM_PRESYN: 'bm_presyn_pred',
        VolumeTypes.PRED_BM_POSTSYN: 'bm_postsyn_pred',
        VolumeTypes.PRED_AFFINITIES: 'aff_pred'
    }
    output_resolutions = {
        VolumeTypes.PRED_BM_PRESYN: voxel_size,
        VolumeTypes.PRED_BM_POSTSYN: voxel_size,
        VolumeTypes.PRED_AFFINITIES: voxel_size
    }

    # define rasterization settings (synapse points to binary masks and exclusive-zone masks)
    volumetypes_to_pointstype = {
        VolumeTypes.GT_BM_PRESYN: PointsTypes.PRESYN,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN: PointsTypes.PRESYN,
        VolumeTypes.GT_BM_POSTSYN: PointsTypes.POSTSYN,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN: PointsTypes.POSTSYN
    }
    rastersetting_bm = RasterizationSetting(marker_size_physical=32)
    rastersetting_mask_syn = RasterizationSetting(marker_size_physical=96,
                                                  donut_inner_radius=32,
                                                  invert_map=True)
    volumetype_to_rastersettings = {
        VolumeTypes.GT_BM_PRESYN: rastersetting_bm,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN: rastersetting_mask_syn,
        VolumeTypes.GT_BM_POSTSYN: rastersetting_bm,
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN: rastersetting_mask_syn,
    }

    # define names of datasets in snapshot file
    snapshot_dataset_names = {
        VolumeTypes.RAW: 'volumes/raw',
        VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
        VolumeTypes.GT_AFFINITIES: 'volumes/labels/affs',
        VolumeTypes.PRED_AFFINITIES: 'volumes/predicted_affs',
        VolumeTypes.LOSS_SCALE: 'volumes/loss_scale',
        VolumeTypes.LOSS_GRADIENT: 'volumes/predicted_affs_loss_gradient',
        VolumeTypes.GT_BM_PRESYN: 'volumes/labels/gt_bm_presyn',
        VolumeTypes.GT_BM_POSTSYN: 'volumes/labels/gt_bm_postsyn',
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_PRESYN:
        'volumes/labels/gt_mask_exclusivezone_presyn',
        VolumeTypes.GT_MASK_EXCLUSIVEZONE_POSTSYN:
        'volumes/labels/gt_mask_exclusivezone_postsyn',
        VolumeTypes.PRED_BM_PRESYN: 'volumes/predicted_bm_presyn',
        VolumeTypes.PRED_BM_POSTSYN: 'volumes/predicted_bm_postsyn',
        VolumeTypes.LOSS_SCALE_BM_PRESYN: 'volumes/loss_scale_presyn',
        VolumeTypes.LOSS_SCALE_BM_POSTSYN: 'volumes/loss_scale_postsyn',
    }

    # set source of data
    data_source = (
        DvidSource(hostname='slowpoke2',
                   port=8000,
                   uuid='cb7dc',
                   volume_array_names={
                       VolumeTypes.RAW: 'grayscale',
                       VolumeTypes.GT_LABELS: 'labels'
                   },
                   points_array_names={
                       PointsTypes.PRESYN: 'combined_synapses_08302016',
                       PointsTypes.POSTSYN: 'combined_synapses_08302016'
                   },
                   points_rois={
                       PointsTypes.PRESYN:
                       roi_synapses.grow((0, 0, 0), Coordinate(
                           (352, 352, 352))),
                       PointsTypes.POSTSYN:
                       roi_synapses.grow((0, 0, 0), Coordinate(
                           (352, 352, 352)))
                   },
                   points_voxel_size={
                       PointsTypes.PRESYN: voxel_size,
                       PointsTypes.POSTSYN: voxel_size
                   }) + Pad({VolumeTypes.RAW: Coordinate(
                       (704, 704, 704))}, {VolumeTypes.RAW: 255}) +
        Normalize())

    # define pipeline to process chunk
    batch_provider_tree = (data_source + RasterizePoints(
        volumetypes_to_pointstype=volumetypes_to_pointstype,
        volumetypes_to_rastersettings=volumetype_to_rastersettings) +
                           AddGtAffinities(malis.mknhood3d()) +
                           IntensityScaleShift(2, -1) +
                           ZeroOutConstSections() + Predict(prototxt,
                                                            weights,
                                                            net_inputs,
                                                            net_outputs,
                                                            output_resolutions,
                                                            use_gpu=0) +
                           Chunk(chunk_spec_template, num_workers=1) +
                           Snapshot(dataset_names=snapshot_dataset_names,
                                    every=1,
                                    output_filename='output',
                                    compression_type="gzip"))

    # request a "batch" of the size of the whole dataset
    with build(batch_provider_tree) as minibatch_maker:
        minibatch_maker.request_batch(request)
    print("Finished")
import matplotlib.pyplot as plt
from img_aug_func import *

from tensorpack.dataflow import (AugmentImageComponent, PrefetchDataZMQ,
                                 BatchData, MultiThreadMapData)

import malis
from funcs import *
import funcs
from funcs import *
from malis_loss import *
import time
# import tf_learn_func
#######################################################################################
input_shape = (16, 256, 256)
nhood = malis.mknhood3d(1)
affs_shape = (len(nhood), ) + input_shape
NB_FILTERS = 32

tf_nhood = tf.constant(nhood)
#######################################################################################

np.random.seed(999)


class MyDataFlow(DataFlow):
    def __init__(self, set_type, X, y):
        self.set_type = set_type
        self.volume = X
        self.gt_seg = y
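
The affs_shape bookkeeping in the snippet above (one affinity channel per neighbourhood offset, prepended to the spatial shape) can be checked in isolation. A small standalone sketch, assuming malis is installed; the random labels are only a stand-in for real ground truth:

import numpy as np
import malis

input_shape = (16, 256, 256)
nhood = malis.mknhood3d(1)
affs_shape = (len(nhood),) + input_shape

toy_seg = np.random.randint(0, 5, size=input_shape).astype(np.int32)
toy_aff = malis.seg_to_affgraph(toy_seg, nhood)

assert toy_aff.shape == affs_shape   # (3, 16, 256, 256)
print(affs_shape)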
Beispiel #28
0
                new_dname = '{dname}_z{z}_y{y}_x{x}_xy{swapxy}_angle{angle:05.1f}'.format(
                    dname=dname,
                    z=reflectz,
                    y=reflecty,
                    x=reflectx,
                    swapxy=swapxy,
                    angle=angle)
                names.append(new_dname)
base_dir = '/data/SNEMI_aug'  #'/nobackup/turaga/singhc/SNEMI_aug/rotations'
for name in names:
    dataset = dict()
    dataset['name'] = name
    dataset['data'] = h5py.File(os.path.join(base_dir, 'raw.h5'), 'r')[name]
    dataset['components'] = h5py.File(os.path.join(base_dir, 'seg.h5'),
                                      'r')[name]
    dataset['mask'] = h5py.File(os.path.join(base_dir, 'mask.h5'), 'r')[name]
    train_dataset.append(dataset)

for dataset in train_dataset:
    dataset['nhood'] = malis.mknhood3d()
    dataset['mask_threshold'] = mask_threshold
    dataset['mask_dilation_steps'] = mask_dilation_steps
    dataset['transform'] = {}
    dataset['transform']['scale'] = (0.8, 1.2)
    dataset['transform']['shift'] = (-0.2, 0.2)

print('Training set contains', str(len(train_dataset)), 'volumes:',
      [dataset['name'] for dataset in train_dataset], train_dataset)

## Testing datasets
test_dataset = []
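
Each training dataset above carries its own intensity 'transform' ranges. A hedged sketch of how a training loop might sample and apply them per volume (illustrative only; the code that actually consumes these dicts is not shown in this snippet, and apply_intensity_transform is a hypothetical helper):

import numpy as np

rng = np.random.RandomState(0)

def apply_intensity_transform(raw, transform, rng):
    """Scale and shift raw intensities with factors drawn from the dataset's ranges."""
    scale = rng.uniform(*transform['scale'])   # e.g. (0.8, 1.2)
    shift = rng.uniform(*transform['shift'])   # e.g. (-0.2, 0.2)
    return raw * scale + shift

raw = rng.rand(8, 64, 64).astype(np.float32)   # stand-in for dataset['data']
augmented = apply_intensity_transform(raw, {'scale': (0.8, 1.2), 'shift': (-0.2, 0.2)}, rng)
print(augmented.mean())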
Beispiel #29
0
def train(solver, data_arrays, label_arrays, mode='malis'):
    losses = []

    net = solver.net
    if mode == 'malis':
        nhood = malis.mknhood3d()
    if mode == 'euclid':
        nhood = malis.mknhood3d()
    if mode == 'malis_aniso':
        nhood = malis.mknhood3d_aniso()
    if mode == 'euclid_aniso':
        nhood = malis.mknhood3d_aniso()

    data_slice_cont = np.zeros((1, 1, 132, 132, 132), dtype=float32)
    label_slice_cont = np.zeros((1, 1, 44, 44, 44), dtype=float32)
    aff_slice_cont = np.zeros((1, 3, 44, 44, 44), dtype=float32)
    nhood_cont = np.zeros((1, 1, 3, 3), dtype=float32)
    error_scale_cont = np.zeros((1, 1, 44, 44, 44), dtype=float32)

    dummy_slice = np.ascontiguousarray([0]).astype(float32)

    # Loop from current iteration to last iteration
    for i in range(solver.iter, solver.max_iter):

        # First pick the dataset to train with
        dataset = randint(0, len(data_arrays) - 1)
        data_array = data_arrays[dataset]
        label_array = label_arrays[dataset]
        # affinity_array = affinity_arrays[dataset]

        offsets = []
        for j in range(0, dims):
            offsets.append(
                randint(
                    0, data_array.shape[j] -
                    (config.output_dims[j] + config.input_padding[j])))

        # These are the raw data elements
        data_slice = slice_data(data_array, offsets, [
            config.output_dims[di] + config.input_padding[di]
            for di in range(0, dims)
        ])

        # These are the labels (connected components)
        label_slice = slice_data(label_array, [
            offsets[di] + int(math.ceil(config.input_padding[di] / float(2)))
            for di in range(0, dims)
        ], config.output_dims)

        # These are the affinity edge values
        # Also recomputing the corresponding labels (connected components)
        aff_slice = malis.seg_to_affgraph(label_slice, nhood)
        label_slice, ccSizes = malis.connected_components_affgraph(
            aff_slice, nhood)

        print(data_slice[None, None, :].shape)
        print(label_slice[None, None, :].shape)
        print(aff_slice[None, :].shape)
        print(nhood.shape)

        if mode == 'malis':
            np.copyto(
                data_slice_cont,
                np.ascontiguousarray(data_slice[None,
                                                None, :]).astype(float32))
            np.copyto(
                label_slice_cont,
                np.ascontiguousarray(label_slice[None,
                                                 None, :]).astype(float32))
            np.copyto(aff_slice_cont,
                      np.ascontiguousarray(aff_slice[None, :]).astype(float32))
            np.copyto(
                nhood_cont,
                np.ascontiguousarray(nhood[None, None, :]).astype(float32))

            net.set_input_arrays(0, data_slice_cont, dummy_slice)
            net.set_input_arrays(1, label_slice_cont, dummy_slice)
            net.set_input_arrays(2, aff_slice_cont, dummy_slice)
            net.set_input_arrays(3, nhood_cont, dummy_slice)

        # We pass the raw and affinity array only
        if mode == 'euclid':
            net.set_input_arrays(
                0,
                np.ascontiguousarray(data_slice[None,
                                                None, :]).astype(float32),
                np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(
                1,
                np.ascontiguousarray(aff_slice[None, :]).astype(float32),
                np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(
                2,
                np.ascontiguousarray(
                    error_scale(aff_slice[None, :], 1.0,
                                0.045)).astype(float32),
                np.ascontiguousarray(dummy_slice).astype(float32))

        if mode == 'softmax':
            net.set_input_arrays(
                0,
                np.ascontiguousarray(data_slice[None,
                                                None, :]).astype(float32),
                np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(
                1,
                np.ascontiguousarray(label_slice[None,
                                                 None, :]).astype(float32),
                np.ascontiguousarray(dummy_slice).astype(float32))

        # Single step
        loss = solver.step(1)

        # Memory clean up and report
        print("Memory usage (before GC): %d MiB" %
              ((resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / (1024)))

        while gc.collect():
            pass

        print("Memory usage (after GC): %d MiB" %
              ((resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / (1024)))

        # m = volume_slicer.VolumeSlicer(data=np.squeeze((net.blobs['Convolution18'].data[0])[0,:,:]))
        # m.configure_traits()

        print("Loss: %s" % loss)
        losses += [loss]
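
Inside the loop above, the cropped label block is first converted to an affinity graph and the connected components are then re-derived from that graph, which keeps labels and affinities consistent within the output block. The same round trip in isolation (a standalone sketch, assuming malis is installed):

import numpy as np
import malis

nhood = malis.mknhood3d()

# toy block: one object in the upper z half, background below
seg = np.zeros((4, 4, 4), dtype=np.int32)
seg[:2] = 7

aff = malis.seg_to_affgraph(seg, nhood)
relabelled, cc_sizes = malis.connected_components_affgraph(aff, nhood)

print(np.unique(seg))          # [0 7]
print(np.unique(relabelled))   # ids may differ, but the number of components matches
print(cc_sizes)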
Beispiel #30
0
import numpy as np
import h5py
import datetime
np.set_printoptions(precision=4)
import malis as m

print("Can we make the `nhood' for an isotropic 3d dataset")
print("corresponding to a 6-connected neighborhood?")
nhood = m.mknhood3d(1)
print(nhood)

print("Can we make the `nhood' for an anisotropic 3d dataset")
print("corresponding to a 4-connected neighborhood in-plane")
print("and 26-connected neighborhood in the previous z-plane?")
nhood = m.mknhood3d_aniso(1, 1.8)
print(nhood)

segTrue = np.array([0, 1, 1, 1, 2, 2, 0, 5, 5, 5, 5], dtype=np.uint64)
node1 = np.arange(segTrue.shape[0] - 1, dtype=np.uint64)
node2 = np.arange(1, segTrue.shape[0], dtype=np.uint64)
print("####################")
print(node1)
print(node2)

nVert = segTrue.shape[0]
edgeWeight = np.array([0, 1, 2, 0, 2, 0, 0, 1, 2, 2.5], dtype=np.float32)
edgeWeight = edgeWeight / edgeWeight.max()
print(segTrue)
print(edgeWeight)

nPairPos = m.malis_loss_weights(segTrue, node1, node2, edgeWeight, 1)
Beispiel #31
0
 def __init__(self, shape):
     self.nhood = malis.mknhood3d()
     self.shape = shape
def train(solver, data_arrays, label_arrays, mode='malis'):
    losses = []
    
    net = solver.net
    if mode == 'malis':
        nhood = malis.mknhood3d()
    if mode == 'euclid':
        nhood = malis.mknhood3d()
    if mode == 'malis_aniso':
        nhood = malis.mknhood3d_aniso()
    if mode == 'euclid_aniso':
        nhood = malis.mknhood3d_aniso()
        
    data_slice_cont = np.zeros((1,1,132,132,132), dtype=float32)
    label_slice_cont  = np.zeros((1,1,44,44,44), dtype=float32)
    aff_slice_cont = np.zeros((1,3,44,44,44), dtype=float32)
    nhood_cont = np.zeros((1,1,3,3), dtype=float32)
    error_scale_cont = np.zeros((1,1,44,44,44), dtype=float32)
    
    dummy_slice = np.ascontiguousarray([0]).astype(float32)
    
    # Loop from current iteration to last iteration
    for i in range(solver.iter, solver.max_iter):
        
        # First pick the dataset to train with
        dataset = randint(0, len(data_arrays) - 1)
        data_array = data_arrays[dataset]
        label_array = label_arrays[dataset]
        # affinity_array = affinity_arrays[dataset]
        
        offsets = []
        for j in range(0, dims):
            offsets.append(randint(0, data_array.shape[j] - (config.output_dims[j] + config.input_padding[j])))
        
        
        # These are the raw data elements
        data_slice = slice_data(data_array, offsets, [config.output_dims[di] + config.input_padding[di] for di in range(0, dims)])
        
        # These are the labels (connected components)
        label_slice = slice_data(label_array, [offsets[di] + int(math.ceil(config.input_padding[di] / float(2))) for di in range(0, dims)], config.output_dims)
        
        # These are the affinity edge values
        # Also recomputing the corresponding labels (connected components)
        aff_slice = malis.seg_to_affgraph(label_slice,nhood)
        label_slice,ccSizes = malis.connected_components_affgraph(aff_slice,nhood)

        print(data_slice[None, None, :].shape)
        print(label_slice[None, None, :].shape)
        print(aff_slice[None, :].shape)
        print(nhood.shape)
        
        if mode == 'malis':
            np.copyto(data_slice_cont, np.ascontiguousarray(data_slice[None, None, :]).astype(float32))
            np.copyto(label_slice_cont, np.ascontiguousarray(label_slice[None, None, :]).astype(float32))
            np.copyto(aff_slice_cont, np.ascontiguousarray(aff_slice[None, :]).astype(float32))
            np.copyto(nhood_cont, np.ascontiguousarray(nhood[None, None, :]).astype(float32))
            
            net.set_input_arrays(0, data_slice_cont, dummy_slice)
            net.set_input_arrays(1, label_slice_cont, dummy_slice)
            net.set_input_arrays(2, aff_slice_cont, dummy_slice)
            net.set_input_arrays(3, nhood_cont, dummy_slice)
            
        # We pass the raw and affinity array only
        if mode == 'euclid':
            net.set_input_arrays(0, np.ascontiguousarray(data_slice[None, None, :]).astype(float32), np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(1, np.ascontiguousarray(aff_slice[None, :]).astype(float32), np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(2, np.ascontiguousarray(error_scale(aff_slice[None, :],1.0,0.045)).astype(float32), np.ascontiguousarray(dummy_slice).astype(float32))

        if mode == 'softmax':
            net.set_input_arrays(0, np.ascontiguousarray(data_slice[None, None, :]).astype(float32), np.ascontiguousarray(dummy_slice).astype(float32))
            net.set_input_arrays(1, np.ascontiguousarray(label_slice[None, None, :]).astype(float32), np.ascontiguousarray(dummy_slice).astype(float32))
        
        # Single step
        loss = solver.step(1)

        # Memory clean up and report
        print("Memory usage (before GC): %d MiB" % ((resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / (1024)))
        
        while gc.collect():
            pass

        print("Memory usage (after GC): %d MiB" % ((resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) / (1024)))


        # m = volume_slicer.VolumeSlicer(data=np.squeeze((net.blobs['Convolution18'].data[0])[0,:,:]))
        # m.configure_traits()

        print("Loss: %s" % loss)
        losses += [loss]
Beispiel #33
0
import numpy as np
import h5py
import datetime
np.set_printoptions(precision=4)
import malis as m


print "Can we make the `nhood' for an isotropic 3d dataset"
print "corresponding to a 6-connected neighborhood?"
nhood = m.mknhood3d(1)
print nhood

print "Can we make the `nhood' for an anisotropic 3d dataset"
print "corresponding to a 4-connected neighborhood in-plane"
print "and 26-connected neighborhood in the previous z-plane?"
nhood = m.mknhood3d_aniso(1,1.8)
print nhood

segTrue = np.array([0, 1, 1, 1, 2, 2, 0, 5, 5, 5, 5],dtype=np.uint64);
node1 = np.arange(segTrue.shape[0]-1,dtype=np.uint64)
node2 = np.arange(1,segTrue.shape[0],dtype=np.uint64)
nVert = segTrue.shape[0]
edgeWeight = np.array([0, 1, 2, 0, 2, 0, 0, 1, 2, 2.5],dtype=np.float32);
edgeWeight = edgeWeight/edgeWeight.max()
print segTrue
print edgeWeight

nPairPos = m.malis_loss_weights(segTrue, node1, node2, edgeWeight, 1)
nPairNeg = m.malis_loss_weights(segTrue, node1, node2, edgeWeight, 0)
print np.vstack((nPairPos,nPairNeg))
# print nPairNeg
Beispiel #34
0
def train_until(max_iteration, gpu, long_range=True):

    # get most recent training result
    solverstates = [int(f.split('.')[0].split('_')[-1]) for f in glob.glob('net_iter_*.solverstate')]
    if len(solverstates) > 0:
        trained_until = max(solverstates)
        print("Resuming training from iteration " + str(trained_until))
    else:
        trained_until = 0
        print("Starting fresh training")

    if trained_until < phase_switch and max_iteration > phase_switch:
        # phase switch lies in-between, split training into two parts
        train_until(phase_switch, gpu)
        trained_until = phase_switch

    if max_iteration <= phase_switch:
        phase = 'euclid'
    else:
        phase = 'malis'

    net_file = 'default_unet.prototxt' if not long_range else 'long_range_unet.prototxt'
    print()
    print("Training until " + str(max_iteration) + " in phase " + phase)
    print("Training with architecture from %s" % net_file)
    print("Using gpu: %i" % gpu)
    print()

    solver_parameters = SolverParameters()
    solver_parameters.train_net = net_file
    solver_parameters.base_lr = 0.5e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 2000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    if trained_until > 0:
        solver_parameters.resume_from = 'net_iter_' + str(trained_until) + '.solverstate'
    else:
        solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage(phase)

    # register new volume type
    register_volume_type('MALIS_COMP_LABEL')
    register_volume_type('LOSS_SCALE')

    # make request with all volume types we need
    request = BatchRequest()
    request.add(VolumeTypes.RAW, Coordinate((84,268,268))*(40,4,4))
    request.add(VolumeTypes.GT_LABELS, Coordinate((56,56,56))*(40,4,4))
    request.add(VolumeTypes.GT_MASK, Coordinate((56,56,56))*(40,4,4))
    request.add(VolumeTypes.GT_AFFINITIES, Coordinate((56,56,56))*(40,4,4))
    request.add(VolumeTypes.LOSS_SCALE, Coordinate((56,56,56))*(40,4,4))
    request.add(VolumeTypes.MALIS_COMP_LABEL, Coordinate((56,56,56))*(40,4,4))

    # additional outputs for snapshots (TODO: add loss gradient)
    additional_request = BatchRequest()
    additional_request.add(VolumeTypes.PRED_AFFINITIES, Coordinate((56,56,56))*(40,4,4))

    data_sources = tuple(
        Hdf5Source(
            os.path.join(data_dir, sample),
            datasets = {
                VolumeTypes.RAW: 'data',
                VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids_notransparency',
                VolumeTypes.GT_MASK: 'volumes/labels/mask',
            }
        ) +
        Normalize() +
        RandomLocation() +
        Reject()
        for sample in samples
    )

    artifact_source = (
        Hdf5Source(
            os.path.join(data_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets = {
                VolumeTypes.RAW: 'defect_sections/raw',
                VolumeTypes.ALPHA_MASK: 'defect_sections/mask',
            }
        ) +
        RandomLocation(min_masked=0.05, mask_volume_type=VolumeTypes.ALPHA_MASK) +
        Normalize() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment([4,40,40], [0,2,2], [0,math.pi/2.0], subsample=8) +
        SimpleAugment(transpose_only_xy=True)
    )

    nhood = malis.mknhood3d() if not long_range else make_long_range_nhood()
    print()
    print(nhood)
    print()

    train_pipeline = (
        data_sources +
        RandomProvider() +
        # first augmentations: add defect augmentations (only raw data)
        DefectAugment(
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            prob_deform=0.03,
            artifact_source=artifact_source,
            contrast_scale=0.5) +
        # next augmentation: elastic + flips in xy
        ElasticAugment(
            [4,40,40], [0,2,2], [0,math.pi/2.0], prob_slip=0.05,prob_shift=0.05,max_misalign=10, subsample=8
        ) +
        SimpleAugment(transpose_only_xy=True) +
        # connected components, grow boundaries and compute affinities
        SplitAndRenumberSegmentationLabels() +
        GrowBoundary(
            steps=1, # we grow less for long range affinities
            only_xy=True) +
        AddGtAffinities(nhood) +
        # intensity augmentations and normalizations
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        IntensityScaleShift(2, -1) +
        ZeroOutConstSections() +
        # prepare component labels for the MALIS loss
        PrepareMalis() +
        # balance the labels
        BalanceLabels(labels_to_loss_scale_volume={VolumeTypes.GT_AFFINITIES: VolumeTypes.LOSS_SCALE},
                      labels_to_mask_volumes={VolumeTypes.GT_AFFINITIES: (VolumeTypes.GT_MASK,)}) +
        # run the actual training
        PreCache(
            cache_size=40,
            num_workers=10) +
        Train(solver_parameters,
              inputs={VolumeTypes.RAW: 'data',
                      VolumeTypes.GT_AFFINITIES: 'aff_label',
                      VolumeTypes.LOSS_SCALE: 'scale',
                      VolumeTypes.MALIS_COMP_LABEL: 'comp_label',
                      'affinity_neighborhood': 'nhood'},
              outputs={VolumeTypes.PRED_AFFINITIES: 'aff_pred'},
              gradients={VolumeTypes.LOSS_GRADIENT: 'aff_pred'},
              use_gpu=gpu) +
        Snapshot(
            {
                VolumeTypes.RAW: 'volumes/raw',
                VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                VolumeTypes.GT_MASK: 'volumes/labels/mask',
                VolumeTypes.GT_AFFINITIES: 'volumes/labels/affinities',
                VolumeTypes.PRED_AFFINITIES: 'volumes/labels/prediction'
            },
            every=100,
            output_filename='final_it={iteration}_id={id}.hdf',
            additional_request=additional_request) +
        PrintProfilingStats(every=100)
    )

    iterations = max_iteration - trained_until
    assert iterations >= 0

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(iterations):
            b.request_batch(request)
    print("Training finished")
Beispiel #35
0
              metrics=[
                  metr_pos_count, metr_neg_count, metr_max_pos_count,
                  metr_max_neg_count, metr_pos_cost, metr_neg_cost
              ])
hist = model.evaluate(data[[0]], gt[[0]])
pdb.set_trace()
training_hist = model.fit(data, gt, batch_size=3, nb_epoch=n_epochs, verbose=0)
plt.figure()
plt.plot(training_hist.history['loss'])
plt.xlabel("epochs")
plt.ylabel("training loss")

# predict an affinity graph and compare it with the affinity graph
# created by the true segmentation
plot_sample = 1
from malis import mknhood3d, seg_to_affgraph
pred_aff = model.predict(data)[plot_sample]
aff = seg_to_affgraph(gt[plot_sample, 0], mknhood3d())
plt.figure()
plt.subplot(131)
plt.pcolor(data[plot_sample, 0, 1], cmap="gray")
plt.title("data")
plt.subplot(132)
plt.pcolor(aff[1, 1], cmap="gray")
plt.title("aff from gt")
plt.subplot(133)
plt.pcolor(pred_aff[1, 1], cmap="gray")
plt.title("predicted aff")

plt.show()
Beispiel #36
0
def train_until(max_iteration, gpu):
    '''Resume training from the last stored network weights and train until ``max_iteration``.'''

    set_verbose(False)

    # get most recent training result
    solverstates = [ int(f.split('.')[0].split('_')[-1]) for f in glob.glob('net_iter_*.solverstate') ]
    if len(solverstates) > 0:
        trained_until = max(solverstates)
        print("Resuming training from iteration " + str(trained_until))
    else:
        trained_until = 0
        print("Starting fresh training")

    if trained_until < phase_switch and max_iteration > phase_switch:
        # phase switch lies in-between, split training into two parts
        train_until(phase_switch, gpu)
        trained_until = phase_switch

    if max_iteration <= phase_switch:
        phase = 'euclid'
    else:
        phase = 'malis'
    print("Traing until " + str(max_iteration) + " in phase " + phase)

    # setup training solver and network
    solver_parameters = SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 0.5e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 2000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    if trained_until > 0:
        solver_parameters.resume_from = 'net_iter_' + str(trained_until) + '.solverstate'
    else:
        solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage(phase)

    # input and output shapes of the network, needed to formulate matching batch 
    # requests
    input_shape = (196,)*3
    output_shape = (92,)*3

    # arrays to request for each batch
    request = BatchRequest()
    request.add_array_request(ArrayKeys.RAW, input_shape)
    request.add_array_request(ArrayKeys.GT_LABELS, output_shape)
    request.add_array_request(ArrayKeys.GT_MASK, output_shape)
    request.add_array_request(ArrayKeys.GT_AFFINITIES, output_shape)
    if phase == 'euclid':
        request.add_array_request(ArrayKeys.LOSS_SCALE, output_shape)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(

        # provide arrays from the given HDF datasets
        Hdf5Source(
            sample,
            datasets = {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids',
                ArrayKeys.GT_MASK: 'volumes/labels/mask',
            }
        ) +

        # ensure RAW is in float in [0,1]
        Normalize() +

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to 
        # the boundary of the available data
        Pad(
            {
                ArrayKeys.RAW: Coordinate((100, 100, 100)),
                ArrayKeys.GT_MASK: Coordinate((100, 100, 100))
            }
        ) +

        # choose a random location inside the provided arrays
        RandomLocation() +

        # reject batches which contain less than 50% labelled data
        Reject()

        for sample in samples
    )

    # attach data sources to training pipeline
    train_pipeline = (

        data_sources +

        # randomly select any of the data sources
        RandomProvider() +

        # elastically deform and rotate
        ElasticAugment([40,40,40], [2,2,2], [0,math.pi/2.0], prob_slip=0.01, max_misalign=1, subsample=8) +

        # randomly mirror and transpose
        SimpleAugment() +

        # grow a 0-boundary between labelled objects
        GrowBoundary(steps=4) +

        # relabel connected label components inside the batch
        SplitAndRenumberSegmentationLabels() +

        # compute ground-truth affinities from labels
        AddGtAffinities(malis.mknhood3d()) +

        # add a LOSS_SCALE array to balance positive and negative classes for 
        # Euclidean training
        BalanceAffinityLabels() +

        # randomly scale and shift intensities
        IntensityAugment(0.9, 1.1, -0.1, 0.1) +

        # ensure RAW is in [-1,1]
        IntensityScaleShift(2,-1) +

        # use 10 workers to pre-cache batches of the above pipeline
        PreCache(
            cache_size=40,
            num_workers=10) +

        # perform one training iteration
        Train(solver_parameters, use_gpu=gpu) +

        # save every 100th batch into an HDF5 file for manual inspection
        Snapshot(
            every=100,
            output_filename='batch_{iteration}.hdf',
            additional_request=BatchRequest({ArrayKeys.LOSS_GRADIENT: request.arrays[ArrayKeys.GT_AFFINITIES]})) +

        # add useful profiling stats to identify bottlenecks
        PrintProfilingStats(every=10)
    )

    iterations = max_iteration - trained_until
    assert iterations >= 0

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(iterations):
            b.request_batch(request)
    print("Training finished")
import numpy as np
import affinities as af
import waterz as wz
import malis as mal

gt = np.load("data/spir_gt.npy")

sample = gt[0:100, 0:400, 0:400]

nhood = mal.mknhood3d(1)

aff = mal.seg_to_affgraph(sample, nhood)

num_act = np.shape(np.unique(sample))[0] - 1

aff = np.asarray(aff, dtype=np.float32)

seg = wz.agglomerate(aff, thresholds=[1])

for segmentation in seg:
    seg = segmentation

num_calc = np.shape(np.unique(seg))[0] - 1

print("Calculated: %i" % num_calc)
print("Actual: %i" % num_act)

#print(np.unique(seg))
#print(np.unique(sample))

if np.equal(seg, sample).all():
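
Note that the fragment ids produced by the agglomeration are arbitrary, so an element-wise comparison like np.equal(seg, sample) is stricter than necessary even when the two segmentations describe the same partition. A hedged, label-invariant alternative (partitions_agree is a hypothetical helper, not part of the original script):

import numpy as np

def partitions_agree(a, b):
    """True if labelings a and b induce the same partition, up to renaming of ids."""
    pairs = set(zip(a.ravel().tolist(), b.ravel().tolist()))
    # the partitions match iff the co-occurring label pairs form a bijection
    return len(pairs) == len({p[0] for p in pairs}) == len({p[1] for p in pairs})

# usage with the arrays from the script above:
# print(partitions_agree(seg, sample))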