def SkeletonCandidateGenerator(prefix, network_distance, candidates, width, augment):
    # read in all relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    world_res = dataIO.Resolution(prefix)

    # get the radii for the bounding box in grid coordinates
    radii = (network_distance / world_res[IB_Z],
             network_distance / world_res[IB_Y],
             network_distance / world_res[IB_X])

    index = 0
    start_time = time.time()
    continue_printing = True

    # continue indefinitely
    while True:
        # this prevents overflow on the queue - the repeated samples are never used
        if index >= len(candidates):
            continue_printing = False
            index = 0

        # get the current candidate
        candidate = candidates[index]

        # increment the index
        index += 1

        # print progress roughly every tenth of the candidate list
        # (max avoids a zero modulus when there are fewer than ten candidates)
        if continue_printing and not (index % max(1, len(candidates) / 10)):
            print '{}/{}: {}'.format(index, len(candidates), time.time() - start_time)

        # rotation equals 0
        yield ExtractFeature(segmentation, candidate, width, radii, augment=augment)
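# Usage sketch (illustrative, not part of the repository): feeding the
# inference generator above to a trained Keras model. The model paths and
# candidate arguments are hypothetical placeholders, and the call assumes a
# Keras 2 style predict_generator that consumes one yielded example per step.
#
#   from keras.models import model_from_json
#
#   model = model_from_json(open('architectures/skeleton.json', 'r').read())
#   model.load_weights('architectures/skeleton.h5')
#
#   candidates = FindCandidates(prefix, threshold, maximum_distance,
#                               network_distance, inference=True)
#   probabilities = model.predict_generator(
#       SkeletonCandidateGenerator(prefix, network_distance, candidates,
#                                  width, augment=False),
#       steps=len(candidates))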
def SaveFeatures(prefix_one, prefix_two, threshold, maximum_distance):
    # read in both segmentation and image files
    segmentations = (dataIO.ReadSegmentationData(prefix_one), dataIO.ReadSegmentationData(prefix_two))
    assert (segmentations[0].shape == segmentations[1].shape)
    images = (dataIO.ReadImageData(prefix_one), dataIO.ReadImageData(prefix_two))
    assert (images[0].shape == images[1].shape)
    bboxes = (dataIO.GetWorldBBox(prefix_one), dataIO.GetWorldBBox(prefix_two))
    world_res = dataIO.Resolution(prefix_one)
    assert (world_res == dataIO.Resolution(prefix_two))

    # get the radii for this feature
    radii = (maximum_distance / world_res[IB_Z],
             maximum_distance / world_res[IB_Y],
             maximum_distance / world_res[IB_X])
    width = (2 * radii[IB_Z], 2 * radii[IB_Y], 2 * radii[IB_X], 3)

    # get all of the candidates for these prefixes
    candidates = FindCandidates(prefix_one, prefix_two, threshold, maximum_distance, True)
    ncandidates = len(candidates)

    # iterate over all candidates
    for iv, candidate in enumerate(candidates):
        # get the example with zero rotation
        example = ExtractFeature(segmentations, images, bboxes, candidate, width, radii, 0)

        # compress the channels
        compressed_output = np.zeros((width[IB_Z], width[IB_Y], width[IB_X]), dtype=np.uint8)
        compressed_output[example[0, :, :, :, 0] == 1] = 1
        compressed_output[example[0, :, :, :, 1] == 1] = 2
        # both candidates are present at this location
        compressed_output[np.logical_and(example[0, :, :, :, 0] == 1, example[0, :, :, :, 1] == 1)] = 3

        # save the output file
        filename = 'features/ebro/{}-{}/{}-{}nm-{:05d}.h5'.format(
            prefix_one, prefix_two, threshold, maximum_distance, iv)
        dataIO.WriteH5File(compressed_output, filename, 'main')
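# Decoding sketch (illustrative): recovering the two binary channels from the
# compressed uint8 feature written above. A voxel equal to 1 belongs only to
# the first candidate label, 2 only to the second, and 3 to both. The
# ReadH5File helper and the filename are assumptions for illustration.
#
#   compressed = dataIO.ReadH5File('features/ebro/one-two/0.5-600nm-00000.h5', 'main')
#   channel_one = np.logical_or(compressed == 1, compressed == 3)
#   channel_two = np.logical_or(compressed == 2, compressed == 3)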
def SkeletonCandidateGenerator(prefix, network_distance, positive_candidates, negative_candidates, parameters, width):
    # get the number of channels for the data
    nchannels = width[0]
    npositive_candidates = len(positive_candidates)
    nnegative_candidates = len(negative_candidates)

    # read in all relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    world_res = dataIO.Resolution(prefix)

    # get the radii for the relevant region
    radii = (network_distance / world_res[IB_Z],
             network_distance / world_res[IB_Y],
             network_distance / world_res[IB_X])

    # get the batch size for training
    batch_size = parameters['batch_size']

    examples = np.zeros((batch_size, nchannels, width[IB_Z + 1], width[IB_Y + 1], width[IB_X + 1]), dtype=np.float32)
    labels = np.zeros(batch_size, dtype=np.float32)

    random.shuffle(positive_candidates)
    random.shuffle(negative_candidates)

    positive_index = 0
    negative_index = 0

    while True:
        # choose elements for the batch - half positive, half negative
        for iv in range(batch_size / 2):
            positive_candidate = positive_candidates[positive_index]
            negative_candidate = negative_candidates[negative_index]

            examples[2 * iv, :, :, :, :] = ExtractFeature(segmentation, positive_candidate, width, radii)
            labels[2 * iv] = positive_candidate.ground_truth
            examples[2 * iv + 1, :, :, :, :] = ExtractFeature(segmentation, negative_candidate, width, radii)
            labels[2 * iv + 1] = negative_candidate.ground_truth

            # reshuffle once either candidate list is exhausted
            positive_index += 1
            if positive_index == npositive_candidates:
                random.shuffle(positive_candidates)
                positive_index = 0
            negative_index += 1
            if negative_index == nnegative_candidates:
                random.shuffle(negative_candidates)
                negative_index = 0

        yield (examples, labels)
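# Usage sketch (illustrative, not part of the repository): training a
# compiled Keras model with the balanced generator above. Every concrete
# value is a hypothetical placeholder; each yielded tuple is one batch, so
# steps_per_epoch controls how many batches make up an epoch.
#
#   parameters = {'batch_size': 20}
#   generator = SkeletonCandidateGenerator(prefix, network_distance,
#                                          positive_candidates, negative_candidates,
#                                          parameters, width)
#   model.fit_generator(generator, steps_per_epoch=500, epochs=10)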
def EbroCandidateGenerator(prefix_one, prefix_two, maximum_distance, candidates, width):
    # read in all relevant information
    segmentations = (dataIO.ReadSegmentationData(prefix_one), dataIO.ReadSegmentationData(prefix_two))
    assert (segmentations[0].shape == segmentations[1].shape)
    images = (dataIO.ReadImageData(prefix_one), dataIO.ReadImageData(prefix_two))
    assert (images[0].shape == images[1].shape)
    bboxes = (dataIO.GetWorldBBox(prefix_one), dataIO.GetWorldBBox(prefix_two))
    world_res = dataIO.Resolution(prefix_one)
    assert (world_res == dataIO.Resolution(prefix_two))

    # get the radii for the relevant region
    radii = (maximum_distance / world_res[IB_Z],
             maximum_distance / world_res[IB_Y],
             maximum_distance / world_res[IB_X])

    index = 0
    start_time = time.time()
    while True:
        # prevent overflow
        if index >= len(candidates):
            index = 0

        candidate = candidates[index]
        index += 1

        # rotation equals 0
        yield ExtractFeature(segmentations, images, bboxes, candidate, width, radii, 0)
def NuclearCandidateGenerator(prefix, network_distance, candidates, parameters, width):
    # get the number of channels for the data
    nchannels = width[0]

    # read in all relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    world_res = dataIO.Resolution(prefix)

    # get the radii for the relevant region
    radii = (network_distance / world_res[IB_Z],
             network_distance / world_res[IB_Y],
             network_distance / world_res[IB_X])

    # determine the total number of batches
    if parameters['augment']:
        rotations = 16
    else:
        rotations = 1
    ncandidates = len(candidates)
    batch_size = parameters['batch_size']
    if rotations * ncandidates % batch_size:
        nbatches = (rotations * ncandidates / batch_size) + 1
    else:
        nbatches = (rotations * ncandidates / batch_size)

    examples = np.zeros((batch_size, nchannels, width[IB_Z + 1], width[IB_Y + 1], width[IB_X + 1]), dtype=np.float32)
    labels = np.zeros(batch_size, dtype=np.float32)

    while True:
        index = 0
        for _ in range(nbatches):
            for iv in range(batch_size):
                # get the candidate index and the rotation
                rotation = index / ncandidates
                candidate = candidates[index % ncandidates]

                # get the example and label
                examples[iv, :, :, :, :] = ExtractFeature(segmentation, candidate, width, radii, rotation)
                labels[iv] = candidate.ground_truth

                # provide overflow relief
                index += 1
                if index >= ncandidates * rotations:
                    index = 0

            yield (examples, labels)
def Oracle(prefix, threshold, maximum_distance, endpoint_distance, network_distance, filtersize=0):
    # get all of the candidates
    positive_candidates = FindCandidates(prefix, threshold, maximum_distance, endpoint_distance, network_distance, 'positive')
    negative_candidates = FindCandidates(prefix, threshold, maximum_distance, endpoint_distance, network_distance, 'negative')
    candidates = positive_candidates + negative_candidates

    # read in all relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    seg2gold_mapping = seg2gold.Mapping(segmentation, gold)

    # create the union find data structure
    max_value = np.amax(segmentation) + 1
    union_find = [unionfind.UnionFindElement(iv) for iv in range(max_value)]

    # iterate over all candidates and collapse edges
    for candidate in candidates:
        label_one = candidate.labels[0]
        label_two = candidate.labels[1]

        # skip labels with no gold correspondence
        if not seg2gold_mapping[label_one] or not seg2gold_mapping[label_two]:
            continue

        if (seg2gold_mapping[label_one] == seg2gold_mapping[label_two]):
            unionfind.Union(union_find[label_one], union_find[label_two])

    # create a mapping for the labels
    mapping = np.zeros(max_value, dtype=np.int64)
    for iv in range(max_value):
        mapping[iv] = unionfind.Find(union_find[iv]).label

    segmentation = seg2seg.MapLabels(segmentation, mapping)
    comparestacks.CremiEvaluate(segmentation, gold, dilate_ground_truth=1,
                                mask_ground_truth=True, mask_segmentation=False,
                                filtersize=filtersize)
def SaveFeatures(prefix, threshold, maximum_distance, network_distance):
    # make sure the folder for this model prefix exists
    output_folder = 'features/skeleton/{}'.format(prefix)
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # read in relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    grid_size = segmentation.shape
    world_res = dataIO.Resolution(prefix)

    # get the radii for the bounding box
    radii = (maximum_distance / world_res[IB_Z],
             maximum_distance / world_res[IB_Y],
             maximum_distance / world_res[IB_X])
    width = (2 * radii[IB_Z], 2 * radii[IB_Y], 2 * radii[IB_X], 3)

    # read all candidates
    candidates = FindCandidates(prefix, threshold, maximum_distance, network_distance, inference=True)
    ncandidates = len(candidates)

    for iv, candidate in enumerate(candidates):
        # get an example with zero rotation
        example = ExtractFeature(segmentation, candidate, width, radii, 0)

        # compress the channels
        compressed_output = np.zeros((width[IB_Z], width[IB_Y], width[IB_X]), dtype=np.uint8)
        compressed_output[example[0, :, :, :, 0] == 1] = 1
        compressed_output[example[0, :, :, :, 1] == 1] = 2

        # save the output file
        filename = 'features/skeleton/{}/{}-{}nm-{}nm-{:05d}.h5'.format(
            prefix, threshold, maximum_distance, network_distance, iv)
        dataIO.WriteH5File(compressed_output, filename, 'main')
def Train(prefix_one, prefix_two, model_prefix, threshold, maximum_distance, width, parameters):
    # identify convenient variables
    nchannels = width[3]
    starting_epoch = parameters['starting_epoch']
    iterations = parameters['iterations']
    batch_size = parameters['batch_size']
    initial_learning_rate = parameters['initial_learning_rate']
    decay_rate = parameters['decay_rate']

    # architecture parameters
    activation = parameters['activation']
    double_conv = parameters['double_conv']
    normalization = parameters['normalization']
    optimizer = parameters['optimizer']
    weights = parameters['weights']

    # create the model
    model = Sequential()

    # add all layers to the model
    AddConvolutionalLayer(model, 16, (3, 3, 3), 'valid', activation, normalization, width)
    if double_conv:
        AddConvolutionalLayer(model, 16, (3, 3, 3), 'valid', activation, normalization)
    AddPoolingLayer(model, (1, 2, 2), 0.0, normalization)

    AddConvolutionalLayer(model, 32, (3, 3, 3), 'valid', activation, normalization)
    if double_conv:
        AddConvolutionalLayer(model, 32, (3, 3, 3), 'valid', activation, normalization)
    AddPoolingLayer(model, (1, 2, 2), 0.0, normalization)

    AddConvolutionalLayer(model, 64, (3, 3, 3), 'valid', activation, normalization)
    if double_conv:
        AddConvolutionalLayer(model, 64, (3, 3, 3), 'valid', activation, normalization)
    AddPoolingLayer(model, (2, 2, 2), 0.0, normalization)

    AddConvolutionalLayer(model, 128, (3, 3, 3), 'valid', activation, normalization)
    if double_conv:
        AddConvolutionalLayer(model, 128, (3, 3, 3), 'valid', activation, normalization)
    AddPoolingLayer(model, (2, 2, 2), 0.0, normalization)

    AddFlattenLayer(model)
    AddDenseLayer(model, 512, 0.0, activation, normalization)
    AddDenseLayer(model, 1, 0.0, 'sigmoid', False)

    # compile the model
    if optimizer == 'adam':
        opt = Adam(lr=initial_learning_rate, decay=decay_rate, beta_1=0.99, beta_2=0.999, epsilon=1e-08)
    elif optimizer == 'sgd':
        opt = SGD(lr=initial_learning_rate, decay=decay_rate, momentum=0.9, nesterov=True)
    model.compile(loss='mean_squared_error', optimizer=opt)

    # make sure the folder for the model prefix exists
    root_location = model_prefix.rfind('/')
    output_folder = model_prefix[:root_location]
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # write out the network parameters to a file
    WriteLogfiles(model, model_prefix, parameters)

    # read in all relevant information
    segmentations = (dataIO.ReadSegmentationData(prefix_one), dataIO.ReadSegmentationData(prefix_two))
    assert (segmentations[0].shape == segmentations[1].shape)
    images = (dataIO.ReadImageData(prefix_one), dataIO.ReadImageData(prefix_two))
    assert (images[0].shape == images[1].shape)
    bboxes = (dataIO.GetWorldBBox(prefix_one), dataIO.GetWorldBBox(prefix_two))
    grid_size = segmentations[0].shape
    world_res = dataIO.Resolution(prefix_one)
    assert (world_res == dataIO.Resolution(prefix_two))

    # get the radii for the relevant region
    radii = (maximum_distance / world_res[IB_Z],
             maximum_distance / world_res[IB_Y],
             maximum_distance / world_res[IB_X])

    # get the candidates between these two prefixes
    candidates = FindCandidates(prefix_one, prefix_two, threshold, maximum_distance, inference=False)
    ncandidates = len(candidates)

    # determine the total number of epochs
    if parameters['augment']:
        rotations = 16
    else:
        rotations = 1
    if rotations * ncandidates % batch_size:
        nepochs = (iterations * rotations * ncandidates / batch_size) + 1
    else:
        nepochs = (iterations * rotations * ncandidates / batch_size)

    # need to adjust learning rate and load in existing weights
    if starting_epoch == 1:
        index = 0
    else:
        nexamples = starting_epoch * batch_size
        current_learning_rate = initial_learning_rate / (1.0 + nexamples * decay_rate)
        backend.set_value(model.optimizer.lr, current_learning_rate)
        index = (starting_epoch * batch_size) % (ncandidates * rotations)
        model.load_weights('{}-{}.h5'.format(model_prefix, starting_epoch))

    # iterate for every epoch
    start_time = time.time()
    for epoch in range(starting_epoch, nepochs + 1):
        # print statistics
        if not epoch % 20:
            print '{}/{} in {:4f} seconds'.format(epoch, nepochs, time.time() - start_time)
            start_time = time.time()

        # create arrays for examples and labels
        examples = np.zeros((batch_size, width[IB_Z], width[IB_Y], width[IB_X], nchannels), dtype=np.uint8)
        labels = np.zeros((batch_size, 1), dtype=np.uint8)

        for iv in range(batch_size):
            # get the index and the rotation
            rotation = index / ncandidates
            candidate = candidates[index % ncandidates]

            # get the example and label
            examples[iv, :, :, :, :] = ExtractFeature(segmentations, images, bboxes, candidate, width, radii, rotation)
            labels[iv, :] = candidate.ground_truth

            # provide overflow relief
            index += 1
            if index >= ncandidates * rotations:
                index = 0

        # fit the model
        model.fit(examples, labels, epochs=1, verbose=0, class_weight=weights)

        # save for every 1000 examples
        # (max avoids a zero modulus when batch_size exceeds 1000)
        if not epoch % max(1, 1000 / batch_size):
            json_string = model.to_json()
            open('{}-{}.json'.format(model_prefix, epoch), 'w').write(json_string)
            model.save_weights('{}-{}.h5'.format(model_prefix, epoch))

        # update the learning rate
        nexamples = epoch * batch_size
        current_learning_rate = initial_learning_rate / (1.0 + nexamples * decay_rate)
        backend.set_value(model.optimizer.lr, current_learning_rate)

    # save the fully trained model
    json_string = model.to_json()
    open('{}.json'.format(model_prefix), 'w').write(json_string)
    model.save_weights('{}.h5'.format(model_prefix))
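# Illustrative parameters dictionary covering every key Train reads above.
# The values and paths are hypothetical placeholders, not the repository's
# defaults.
#
#   parameters = {
#       'starting_epoch': 1,             # 1 trains from scratch; >1 resumes from saved weights
#       'iterations': 1,                 # passes over the augmented candidate set
#       'batch_size': 20,
#       'initial_learning_rate': 1e-4,
#       'decay_rate': 5e-8,              # lr = lr0 / (1 + nexamples * decay_rate)
#       'activation': 'relu',
#       'double_conv': False,
#       'normalization': False,
#       'optimizer': 'adam',             # or 'sgd'
#       'weights': {0: 1.0, 1: 1.0},     # class weights forwarded to model.fit
#       'augment': True,                 # enables the 16 rotations
#   }
#   Train(prefix_one, prefix_two, 'architectures/ebro/model',
#         threshold, maximum_distance, width=(52, 52, 52, 3), parameters=parameters)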
def GenerateFeatures(prefix, threshold, network_distance):
    # read in the relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    assert (segmentation.shape == gold.shape)
    zres, yres, xres = segmentation.shape

    # get the mapping from the segmentation to gold
    seg2gold_mapping = seg2gold.Mapping(segmentation, gold, low_threshold=0.10, high_threshold=0.80)

    # remove small connected components
    segmentation = seg2seg.RemoveSmallConnectedComponents(segmentation, threshold=threshold).astype(np.int64)
    max_label = np.amax(segmentation) + 1

    # get the grid size and the world resolution
    grid_size = segmentation.shape
    world_res = dataIO.Resolution(prefix)

    # get the radius in grid coordinates
    network_radii = np.int64((network_distance / world_res[IB_Z],
                              network_distance / world_res[IB_Y],
                              network_distance / world_res[IB_X]))

    # get all of the skeletons
    skeletons, _, _ = dataIO.ReadSkeletons(prefix, segmentation)

    npositive_instances = [0 for _ in range(10)]
    nnegative_instances = [0 for _ in range(10)]

    positive_candidates = []
    negative_candidates = []

    # iterate over all skeletons
    for skeleton in skeletons:
        label = skeleton.label
        joints = skeleton.joints

        # iterate over all joints
        for joint in joints:
            # get the gold value at this location
            location = joint.GridPoint()
            gold_label = gold[location[IB_Z], location[IB_Y], location[IB_X]]

            # make sure the bounding box fits
            valid_location = True
            for dim in range(NDIMS):
                if location[dim] - network_radii[dim] < 0:
                    valid_location = False
                if location[dim] + network_radii[dim] > grid_size[dim]:
                    valid_location = False
            if not valid_location:
                continue
            if not gold_label:
                continue

            neighbors = joint.Neighbors()
            should_split = False
            if len(neighbors) <= 2:
                continue

            # get the gold for every neighbor
            for neighbor in neighbors:
                neighbor_location = neighbor.GridPoint()
                neighbor_gold_label = gold[neighbor_location[IB_Z], neighbor_location[IB_Y], neighbor_location[IB_X]]

                # get the gold value here
                if not gold_label == neighbor_gold_label and gold_label and neighbor_gold_label:
                    should_split = True

            if should_split:
                npositive_instances[len(neighbors)] += 1
            else:
                nnegative_instances[len(neighbors)] += 1

            candidate = NuclearCandidate(label, location, should_split)
            if should_split:
                positive_candidates.append(candidate)
            else:
                negative_candidates.append(candidate)

    train_filename = 'features/nuclear/{}-{}-{}nm-training.candidates'.format(prefix, threshold, network_distance)
    validation_filename = 'features/nuclear/{}-{}-{}nm-validation.candidates'.format(prefix, threshold, network_distance)
    forward_filename = 'features/nuclear/{}-{}-{}nm-inference.candidates'.format(prefix, threshold, network_distance)

    SaveCandidates(train_filename, positive_candidates, negative_candidates, inference=False, validation=False)
    SaveCandidates(validation_filename, positive_candidates, negative_candidates, inference=False, validation=True)
    SaveCandidates(forward_filename, positive_candidates, negative_candidates, inference=True)

    print 'Positive Candidates: {}'.format(len(positive_candidates))
    print 'Negative Candidates: {}'.format(len(negative_candidates))
    print 'Ratio: {}'.format(len(negative_candidates) / float(len(positive_candidates)))
def GenerateFeatures(prefix, threshold, maximum_distance, network_distance, endpoint_distance, topology):
    start_time = time.time()

    # read in the relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    assert (segmentation.shape == gold.shape)
    zres, yres, xres = segmentation.shape

    # remove small connected components
    thresholded_segmentation = seg2seg.RemoveSmallConnectedComponents(segmentation, threshold=threshold).astype(np.int64)
    max_label = np.amax(segmentation) + 1

    # get the grid size and the world resolution
    grid_size = segmentation.shape
    world_res = dataIO.Resolution(prefix)

    # get the radius in grid coordinates
    radii = np.int64((maximum_distance / world_res[IB_Z],
                      maximum_distance / world_res[IB_Y],
                      maximum_distance / world_res[IB_X]))
    network_radii = np.int64((network_distance / world_res[IB_Z],
                              network_distance / world_res[IB_Y],
                              network_distance / world_res[IB_X]))

    # get all of the skeletons
    if topology:
        skeletons, endpoints = dataIO.ReadTopologySkeletons(prefix, thresholded_segmentation)
    else:
        skeletons, _, endpoints = dataIO.ReadSWCSkeletons(prefix, thresholded_segmentation)

    # get the set of all considered pairs
    endpoint_candidates = [set() for _ in range(len(endpoints))]
    for ie, endpoint in enumerate(endpoints):
        # extract the region around this endpoint
        label = endpoint.label
        centroid = endpoint.GridPoint()

        # find the candidates near this endpoint
        # (seeded with a dummy zero entry that is skipped below)
        candidates = set()
        candidates.add(0)
        FindNeighboringCandidates(thresholded_segmentation, centroid, candidates, maximum_distance, world_res)

        for candidate in candidates:
            # skip extracellular
            if not candidate:
                continue
            endpoint_candidates[ie].add(candidate)

    # get a mapping from the labels to indices in skeletons and endpoints
    label_to_index = [-1 for _ in range(max_label)]
    for ie, skeleton in enumerate(skeletons):
        label_to_index[skeleton.label] = ie

    # begin pruning the candidates based on the endpoints
    endpoint_pairs = {}

    # find the smallest pair between endpoints
    smallest_distances = {}
    midpoints = {}

    for ie, endpoint in enumerate(endpoints):
        label = endpoint.label
        for neighbor_label in endpoint_candidates[ie]:
            smallest_distances[(label, neighbor_label)] = endpoint_distance + 1
            smallest_distances[(neighbor_label, label)] = endpoint_distance + 1

    for ie, endpoint in enumerate(endpoints):
        # get the endpoint location
        label = endpoint.label

        # go through all currently considered endpoints
        for neighbor_label in endpoint_candidates[ie]:
            for neighbor_endpoint in skeletons[label_to_index[neighbor_label]].endpoints:
                # get the distance
                deltas = endpoint.WorldPoint(world_res) - neighbor_endpoint.WorldPoint(world_res)
                distance = math.sqrt(deltas[IB_Z] * deltas[IB_Z] + deltas[IB_Y] * deltas[IB_Y] + deltas[IB_X] * deltas[IB_X])

                if distance < smallest_distances[(label, neighbor_label)]:
                    midpoint = (endpoint.GridPoint() + neighbor_endpoint.GridPoint()) / 2

                    # find the closest pair of endpoints
                    smallest_distances[(label, neighbor_label)] = distance
                    smallest_distances[(neighbor_label, label)] = distance

                    # add to the dictionary
                    endpoint_pairs[(label, neighbor_label)] = (endpoint, neighbor_endpoint)
                    endpoint_pairs[(neighbor_label, label)] = (neighbor_endpoint, endpoint)

                    midpoints[(label, neighbor_label)] = midpoint
                    midpoints[(neighbor_label, label)] = midpoint

    # create list of candidates
    positive_candidates = []
    negative_candidates = []
    undetermined_candidates = []

    for ie, match in enumerate(endpoint_pairs):
        print '{}/{}'.format(ie, len(endpoint_pairs))
        endpoint_one = endpoint_pairs[match][0]
        endpoint_two = endpoint_pairs[match][1]

        label_one = endpoint_one.label
        label_two = endpoint_two.label

        # each unordered pair appears twice in the dictionary; process it once
        if label_two > label_one:
            continue

        # extract a bounding box around this midpoint
        midz, midy, midx = midpoints[(label_one, label_two)]
        zmin = max(0, midz - network_radii[IB_Z])
        ymin = max(0, midy - network_radii[IB_Y])
        xmin = max(0, midx - network_radii[IB_X])
        zmax = min(zres - 1, midz + network_radii[IB_Z] + 1)
        ymax = min(yres - 1, midy + network_radii[IB_Y] + 1)
        xmax = min(xres - 1, midx + network_radii[IB_X] + 1)

        extracted_segmentation = segmentation[zmin:zmax, ymin:ymax, xmin:xmax]
        extracted_gold = gold[zmin:zmax, ymin:ymax, xmin:xmax]

        extracted_seg2gold_mapping = seg2gold.Mapping(extracted_segmentation, extracted_gold, match_threshold=0.70, nonzero_threshold=0.40)

        if label_one > extracted_seg2gold_mapping.size:
            continue
        if label_two > extracted_seg2gold_mapping.size:
            continue

        gold_one = extracted_seg2gold_mapping[label_one]
        gold_two = extracted_seg2gold_mapping[label_two]

        ground_truth = (gold_one == gold_two)
        candidate = SkeletonCandidate((label_one, label_two), midpoints[(label_one, label_two)], ground_truth)

        if not extracted_seg2gold_mapping[label_one] or not extracted_seg2gold_mapping[label_two]:
            undetermined_candidates.append(candidate)
        elif ground_truth:
            positive_candidates.append(candidate)
        else:
            negative_candidates.append(candidate)

    # save positive and negative candidates separately
    positive_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-positive.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance, network_distance)
    negative_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-negative.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance, network_distance)
    undetermined_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-undetermined.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance, network_distance)

    SaveCandidates(positive_filename, positive_candidates)
    SaveCandidates(negative_filename, negative_candidates)
    SaveCandidates(undetermined_filename, undetermined_candidates)

    print 'Positive candidates: {}'.format(len(positive_candidates))
    print 'Negative candidates: {}'.format(len(negative_candidates))
    print 'Undetermined candidates: {}'.format(len(undetermined_candidates))
def GenerateFeatures(prefix_one, prefix_two, threshold, maximum_distance):
    # read in all relevant information
    segmentation_one = dataIO.ReadSegmentationData(prefix_one)
    segmentation_two = dataIO.ReadSegmentationData(prefix_two)
    assert (segmentation_one.shape == segmentation_two.shape)
    bbox_one = dataIO.GetWorldBBox(prefix_one)
    bbox_two = dataIO.GetWorldBBox(prefix_two)
    world_res = dataIO.Resolution(prefix_one)
    assert (world_res == dataIO.Resolution(prefix_two))

    # get the radii for the relevant region
    radii = (int(maximum_distance / world_res[IB_Z] + 0.5),
             int(maximum_distance / world_res[IB_Y] + 0.5),
             int(maximum_distance / world_res[IB_X] + 0.5))

    # parse out small segments
    segmentation_one = seg2seg.RemoveSmallConnectedComponents(segmentation_one, min_size=threshold)
    segmentation_two = seg2seg.RemoveSmallConnectedComponents(segmentation_two, min_size=threshold)

    # get the bounding box for the intersection
    world_box = ib3shapes.IBBox(bbox_one.mins, bbox_one.maxs)
    world_box.Intersection(bbox_two)

    # get the mins and maxs of truncated box
    mins_one = WorldToGrid(world_box.mins, bbox_one)
    mins_two = WorldToGrid(world_box.mins, bbox_two)
    maxs_one = WorldToGrid(world_box.maxs, bbox_one)
    maxs_two = WorldToGrid(world_box.maxs, bbox_two)

    # get the relevant subsections
    segmentation_one = segmentation_one[mins_one[IB_Z]:maxs_one[IB_Z], mins_one[IB_Y]:maxs_one[IB_Y], mins_one[IB_X]:maxs_one[IB_X]]
    segmentation_two = segmentation_two[mins_two[IB_Z]:maxs_two[IB_Z], mins_two[IB_Y]:maxs_two[IB_Y], mins_two[IB_X]:maxs_two[IB_X]]

    # create an empty set and add a dummy variable for numba
    # this set represents tuples of labels from GRID_ONE and GRID_TWO
    candidates_set = set()
    candidates_set.add((np.uint64(0), np.uint64(0)))
    FindOverlapCandidates(segmentation_one, segmentation_two, candidates_set)

    # get the reverse mappings
    forward_mapping_one, reverse_mapping_one = seg2seg.ReduceLabels(segmentation_one)
    forward_mapping_two, reverse_mapping_two = seg2seg.ReduceLabels(segmentation_two)

    # get the number of unique labels
    nlabels_one = reverse_mapping_one.size
    nlabels_two = reverse_mapping_two.size

    # calculate the center of overlap regions
    sums = np.zeros((nlabels_one, nlabels_two, 3), dtype=np.uint64)
    counter = np.zeros((nlabels_one, nlabels_two), dtype=np.uint64)
    FindCenters(segmentation_one, segmentation_two, forward_mapping_one, forward_mapping_two, sums, counter)

    # get the number of occurrences of all labels
    _, counts_one = np.unique(segmentation_one, return_counts=True)
    _, counts_two = np.unique(segmentation_two, return_counts=True)

    # iterate through candidates and locate centers
    candidates = []
    centers = []
    counts = []
    for candidate in candidates_set:
        # skip extracellular space
        if not candidate[0] or not candidate[1]:
            continue

        # get forward mapping
        index_one = forward_mapping_one[candidate[0]]
        index_two = forward_mapping_two[candidate[1]]

        count = counter[index_one][index_two]
        center = (int(sums[index_one, index_two, IB_Z] / count + 0.5),
                  int(sums[index_one, index_two, IB_Y] / count + 0.5),
                  int(sums[index_one, index_two, IB_X] / count + 0.5))

        # append to the lists
        candidates.append(candidate)
        centers.append(center)
        counts.append((counts_one[index_one], counts_two[index_two], count))

    # find which dimension causes overlap
    if not bbox_one.mins[IB_X] == bbox_two.mins[IB_X]:
        overlap = IB_X
    if not bbox_one.mins[IB_Y] == bbox_two.mins[IB_Y]:
        overlap = IB_Y
    if not bbox_one.mins[IB_Z] == bbox_two.mins[IB_Z]:
        overlap = IB_Z

    # prune the candidates
    indices = PruneCandidates(segmentation_one, segmentation_two, candidates, centers, radii, overlap)
    pruned_candidates = []
    pruned_centers = []
    pruned_counts = []
    for index in indices:
        # add the candidates
        pruned_candidates.append(candidates[index])
        pruned_counts.append(counts[index])
        # convert the center back into world coordinates
        center = (centers[index][IB_Z] + world_box.mins[IB_Z],
                  centers[index][IB_Y] + world_box.mins[IB_Y],
                  centers[index][IB_X] + world_box.mins[IB_X])
        pruned_centers.append(center)

    # save all features
    SaveFeatures(prefix_one, prefix_two, pruned_candidates, pruned_centers, pruned_counts, threshold, maximum_distance)
def Agglomerate(prefix, model_prefix, threshold=0.5):
    # read the segmentation data
    segmentation = dataIO.ReadSegmentationData(prefix)

    # get the multicut filename (with graph weights)
    multicut_filename = 'multicut/{}-{}.graph'.format(model_prefix, prefix)

    # get the maximum segmentation value
    max_value = np.amax(segmentation) + 1

    # create union find data structure
    union_find = [UnionFind.UnionFindElement(iv) for iv in range(max_value)]

    # read in all of the labels and merge the result
    with open(multicut_filename, 'rb') as fd:
        # read the number of vertices and edges
        nvertices, nedges, = struct.unpack('QQ', fd.read(16))

        # read in all of the edges
        for ie in range(nedges):
            # read in both labels
            label_one, label_two, = struct.unpack('QQ', fd.read(16))

            # skip over the reduced labels
            fd.read(16)

            # read in the edge weight
            edge_weight, = struct.unpack('d', fd.read(8))

            # merge label one and label two in the union find data structure
            if (edge_weight > threshold):
                UnionFind.Union(union_find[label_one], union_find[label_two])

    # create a mapping
    mapping = np.zeros(max_value, dtype=np.int64)

    # update the segmentation
    for iv in range(max_value):
        label = UnionFind.Find(union_find[iv]).label
        mapping[iv] = label

    # update the labels
    agglomerated_segmentation = seg2seg.MapLabels(segmentation, mapping)

    gold_filename = 'gold/{}_gold.h5'.format(prefix)

    # TODO fix this code temporary filename
    agglomeration_filename = 'multicut/{}-agglomerate.h5'.format(prefix)

    # temporary - write h5 file
    dataIO.WriteH5File(agglomerated_segmentation, agglomeration_filename, 'stack')
    import time
    start_time = time.time()

    print 'Agglomeration - {}:'.format(threshold)

    # create the command line
    command = '~/software/PixelPred2Seg/comparestacks --stack1 {} --stackbase {} --dilate1 1 --dilatebase 1 --relabel1 --relabelbase --filtersize 100 --anisotropic'.format(
        agglomeration_filename, gold_filename)

    # execute the command
    os.system(command)

    print time.time() - start_time
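# Format sketch (illustrative, not part of the repository): writing a tiny
# graph file in the binary layout Agglomerate parses above - a
# (nvertices, nedges) header of two 64-bit integers, then per edge two
# 64-bit labels, two 64-bit reduced labels (skipped by the reader), and one
# double edge weight. The edge contents and filename are placeholders.
#
#   edges = [(1, 2, 0.9), (2, 3, 0.1)]   # (label_one, label_two, weight)
#   with open('multicut/example.graph', 'wb') as fd:
#       fd.write(struct.pack('QQ', 4, len(edges)))             # nvertices, nedges
#       for label_one, label_two, weight in edges:
#           fd.write(struct.pack('QQ', label_one, label_two))  # the two labels
#           fd.write(struct.pack('QQ', 0, 0))                  # reduced labels (ignored)
#           fd.write(struct.pack('d', weight))                 # edge weight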
def MergeGroundTruth(prefix, model_prefix):
    # read the segmentation data
    segmentation = dataIO.ReadSegmentationData(prefix)

    # get the multicut filename (with graph weights)
    multicut_filename = 'multicut/{}-{}.graph'.format(model_prefix, prefix)

    # read the gold data
    gold = dataIO.ReadGoldData(prefix)

    # read in the segmentation to gold mapping
    seg2gold_mapping = seg2gold.Mapping(segmentation, gold)

    # get the maximum segmentation value
    # (+ 1 so the maximum label has a union-find entry, as in Agglomerate)
    max_value = np.amax(segmentation) + 1

    # create union find data structure
    union_find = [UnionFind.UnionFindElement(iv) for iv in range(max_value)]

    # read in all of the labels
    with open(multicut_filename, 'rb') as fd:
        # read the number of vertices and edges
        nvertices, nedges, = struct.unpack('QQ', fd.read(16))

        # read in all of the edges
        for ie in range(nedges):
            # read in the two labels
            label_one, label_two, = struct.unpack('QQ', fd.read(16))

            # skip over the reduced labels and edge weight
            fd.read(24)

            # if the labels are the same and the mapping is non zero
            if seg2gold_mapping[label_one] == seg2gold_mapping[label_two] and seg2gold_mapping[label_one]:
                UnionFind.Union(union_find[label_one], union_find[label_two])

    # create a mapping
    mapping = np.zeros(max_value, dtype=np.int64)

    # update the segmentation
    for iv in range(max_value):
        label = UnionFind.Find(union_find[iv]).label
        mapping[iv] = label

    merged_segmentation = seg2seg.MapLabels(segmentation, mapping)

    gold_filename = 'gold/{}_gold.h5'.format(prefix)

    # TODO fix this code temporary filename
    truth_filename = 'multicut/{}-truth.h5'.format(prefix)

    # temporary - write h5 file
    dataIO.WriteH5File(merged_segmentation, truth_filename, 'stack')
    import time
    start_time = time.time()

    print 'Ground truth: '

    # create the command line
    command = '~/software/PixelPred2Seg/comparestacks --stack1 {} --stackbase {} --dilate1 1 --dilatebase 1 --relabel1 --relabelbase --filtersize 100 --anisotropic'.format(
        truth_filename, gold_filename)

    # execute the command
    os.system(command)

    print time.time() - start_time