Пример #1
0
def GenerateExamples(prefix, cutoff=500):
    """Write the labels of the most frequent gold segments to a benchmark file.

    The gold volume for ``prefix`` is read and its labels are ranked by how
    often they occur in the volume (most frequent first, ties broken by the
    larger label).  Label 0 is ignored and at most ``cutoff`` labels are kept.

    The output file starts with one ``struct`` ``'q'`` value holding the
    number of labels that follow, then one ``'q'`` value per label.  The
    header is the number of labels actually written, so it stays consistent
    with the payload when the volume has fewer than ``cutoff`` non-zero labels.

    Args:
        prefix: dataset prefix passed to ``dataIO.ReadGoldData`` and used to
            build the output filename.
        cutoff: maximum number of labels to write.
    """
    gold = dataIO.ReadGoldData(prefix)
    labels, counts = np.unique(gold, return_counts=True)

    # rank (count, label) pairs from largest to smallest and drop label 0
    # before applying the cutoff so that it never consumes one of the slots
    ranked = sorted(zip(counts, labels), reverse=True)
    selected = [label for _, label in ranked if label][:max(cutoff, 0)]

    filename = 'benchmarks/skeleton/{}-skeleton-benchmark-examples.bin'.format(prefix)
    with open(filename, 'wb') as fd:
        fd.write(struct.pack('q', len(selected)))
        for label in selected:
            fd.write(struct.pack('q', label))
Пример #2
0
def Oracle(prefix,
           threshold,
           maximum_distance,
           endpoint_distance,
           network_distance,
           filtersize=0):
    """Evaluate the segmentation obtained by merging every correct candidate.

    Positive and negative candidates are loaded through ``FindCandidates``.
    Two candidate labels are merged when both map to the same non-zero entry
    of the segmentation-to-gold mapping.  The merged segmentation is handed to
    ``comparestacks.CremiEvaluate`` together with the gold volume.

    Args:
        prefix: dataset prefix used to read the segmentation, the gold data
            and the candidates.
        threshold, maximum_distance, endpoint_distance, network_distance:
            forwarded unchanged to ``FindCandidates``.
        filtersize: forwarded to ``comparestacks.CremiEvaluate``.
    """
    # load the positive candidates first, then the negative ones
    positives, negatives = [
        FindCandidates(prefix, threshold, maximum_distance, endpoint_distance,
                       network_distance, category)
        for category in ('positive', 'negative')
    ]
    candidates = positives + negatives

    # volumes and the segment -> gold lookup
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    seg2gold_mapping = seg2gold.Mapping(segmentation, gold)

    # one union find element per possible label value
    max_value = np.amax(segmentation) + 1
    union_find = [unionfind.UnionFindElement(label) for label in range(max_value)]

    # collapse every candidate edge whose endpoints share a non-zero gold label
    for candidate in candidates:
        first = candidate.labels[0]
        second = candidate.labels[1]

        gold_first = seg2gold_mapping[first]
        if not gold_first:
            continue
        gold_second = seg2gold_mapping[second]
        if gold_second and gold_first == gold_second:
            unionfind.Union(union_find[first], union_find[second])

    # every label is relabeled to the representative of its set
    mapping = np.zeros(max_value, dtype=np.int64)
    for label, element in enumerate(union_find):
        mapping[label] = unionfind.Find(element).label

    segmentation = seg2seg.MapLabels(segmentation, mapping)
    comparestacks.CremiEvaluate(segmentation,
                                gold,
                                dilate_ground_truth=1,
                                mask_ground_truth=True,
                                mask_segmentation=False,
                                filtersize=filtersize)
Пример #3
0
def EvaluateEndpoints(prefix):
    """Sweep the endpoint-detection configurations and report the best one.

    For every resolution the 'thinning' and 'medial-axis' methods are scored
    for each A* expansion value, then the 'teaser' method is scored for each
    (scale, buffer) pair.  Scores come from ``FindEndpointMatches``.  Any
    configuration whose precision and recall both exceed the minimum values
    is printed, and the configuration with the strictly highest F1 score is
    printed at the end.

    Args:
        prefix: dataset prefix used to read the gold data, the ground truth
            endpoints and the per-configuration results.
    """
    gold = dataIO.ReadGoldData(prefix)
    max_label = np.amax(gold) + 1

    # all downsampled resolutions: 30, 40, ..., 200 along every axis
    resolutions = [(side, side, side) for side in range(30, 210, 10)]

    # get the human labeled ground truth
    gt_endpoints = ReadGroundTruth(prefix, max_label)

    # running best result, held in a dict so the helper below can update it
    best = {'fscore': 0.0, 'precision': 0.0, 'recall': 0.0, 'algorithm': ''}

    min_precision, min_recall = (0.80, 0.90)

    def Evaluate(method, parameters, resolution, heading, identifier, extras):
        # score one configuration, report it if good enough, track the best
        fscore, precision, recall = FindEndpointMatches(
            prefix, method, parameters, resolution, gt_endpoints)

        def Render(template):
            return template.format(resolution[IB_X], resolution[IB_Y],
                                   resolution[IB_Z], *extras)

        if precision > min_precision and recall > min_recall:
            print(Render(heading))
            print('  F1-Score: {}'.format(fscore))
            print('  Precision: {}'.format(precision))
            print('  Recall: {}'.format(recall))

        # strict comparison: the earliest configuration wins ties
        if fscore > best['fscore']:
            best['fscore'] = fscore
            best['precision'] = precision
            best['recall'] = recall
            best['algorithm'] = Render(identifier)

    # go through all possible configurations
    for resolution in resolutions:
        # thinning and medial axis share the same A* expansion parameters
        for astar_expansion in [0, 11, 13, 15, 17, 19, 21, 23, 25]:
            Evaluate('thinning', '{:02d}'.format(astar_expansion), resolution,
                     'Thinning {:03d}x{:03d}x{:03d} {:02d}',
                     'thinning-{:03d}x{:03d}x{:03d}-{:02d}',
                     (astar_expansion, ))
            Evaluate('medial-axis', '{:02d}'.format(astar_expansion),
                     resolution,
                     'Medial Axis {:03d}x{:03d}x{:03d} {:02d}',
                     'medial-axis-{:03d}x{:03d}x{:03d}-{:02d}',
                     (astar_expansion, ))

        # teaser is parameterized by a scale and a buffer
        for tscale in [7, 9, 11, 13, 15, 17]:
            for tbuffer in [1, 2, 3, 4, 5]:
                Evaluate('teaser', '{:02d}-{:02d}-00'.format(tscale, tbuffer),
                         resolution,
                         'TEASER {:03d}x{:03d}x{:03d} {:02d} {:02d}',
                         'teaser-{:03d}x{:03d}x{:03d}-{:02d}-{:02d}-00',
                         (tscale, tbuffer))

    print('Best method: {}'.format(best['algorithm']))
    print('F1-Score: {}'.format(best['fscore']))
    print('Precision: {}'.format(best['precision']))
    print('Recall: {}'.format(best['recall']))
Пример #4
0
def CollapseGraph(prefix, segmentation, vertex_ones, vertex_twos,
                  maintained_edges, algorithm):
    """Collapse the non-maintained edges, save the result and score it.

    Every edge ``(vertex_ones[i], vertex_twos[i])`` whose entry in
    ``maintained_edges`` is falsy is merged with a union find structure.  The
    resulting label mapping is applied through ``seg2seg.MapLabels``, the
    segmentation is written to ``rhoana/<prefix>-<algorithm>.h5`` with a meta
    file, and the rand error and variation of information against the gold
    data are printed and written to ``<folder>-results/<algorithm>-<prefix>.txt``.

    Args:
        prefix: dataset prefix used for the gold data and the file names.
        segmentation: label volume to relabel.
        vertex_ones, vertex_twos: per-edge labels of the two end vertices.
        maintained_edges: per-edge flags; truthy means the edge is not merged.
        algorithm: name containing 'lifted-multicut', 'multicut' or
            'graph-baseline'; it selects the results folder.
    """
    # get the number of edges
    nedges = maintained_edges.shape[0]

    # one union find element per possible label value
    max_label = np.amax(segmentation) + 1
    union_find = [unionfind.UnionFindElement(label) for label in range(max_label)]

    # merge the two vertices of every edge that should collapse
    for index in range(nedges):
        if not maintained_edges[index]:
            unionfind.Union(union_find[vertex_ones[index]],
                            union_find[vertex_twos[index]])

    # every label is relabeled to the representative of its set
    mapping = np.zeros(max_label, dtype=np.int64)
    for label, element in enumerate(union_find):
        mapping[label] = unionfind.Find(element).label

    # NOTE(review): the return value is discarded here, so this relies on
    # MapLabels relabeling `segmentation` in place - confirm against seg2seg
    seg2seg.MapLabels(segmentation, mapping)

    rhoana_filename = 'rhoana/{}-{}.h5'.format(prefix, algorithm)
    dataIO.WriteH5File(segmentation, rhoana_filename, 'main')

    # spawn a new meta file
    dataIO.SpawnMetaFile(prefix, rhoana_filename, 'main')

    # prefix of the new result: file name without directory and '.h5'
    new_prefix = rhoana_filename.split('/')[1][:-3]

    # read in the gold data
    gold = dataIO.ReadGoldData(prefix)

    rand_error, vi = comparestacks.VariationOfInformation(
        new_prefix, segmentation, gold)

    # the same six lines go to the console and to the results file
    report = [
        'Rand Error Full: {}'.format(rand_error[0] + rand_error[1]),
        'Rand Error Merge: {}'.format(rand_error[0]),
        'Rand Error Split: {}'.format(rand_error[1]),
        'Variation of Information Full: {}'.format(vi[0] + vi[1]),
        'Variation of Information Merge: {}'.format(vi[0]),
        'Variation of Information Split: {}'.format(vi[1]),
    ]
    for line in report:
        print(line)

    # pick the results folder; 'lifted-multicut' must be tested before
    # 'multicut' because the latter is a substring of the former
    folders = (('lifted-multicut', 'lifted-multicut'),
               ('multicut', 'multicut'),
               ('graph-baseline', 'graph-baselines'))
    for keyword, folder in folders:
        if keyword in algorithm:
            output_folder = folder
            break
    else:
        assert (False)

    with open('{}-results/{}-{}.txt'.format(output_folder, algorithm, prefix),
              'w') as fd:
        for line in report:
            fd.write(line + '\n')
Пример #5
0
def CollapseGraph(segmentation, candidates, maintain_edges, probabilities,
                  output_filename, gold_prefix='SNEMI3D_train'):
    """Greedily collapse candidate edges while refusing merges that close a cycle.

    Candidates are visited in order of decreasing probability (ties broken by
    the larger index).  An edge that is not already maintained is collapsed
    unless another neighbor of its first label already belongs to the same
    union find set as its second label; in that case the edge is marked as
    maintained instead.  The per-edge flags are written to
    ``output_filename`` and the merged segmentation is scored with
    ``comparestacks.adapted_rand`` against the gold data.

    Args:
        segmentation: label volume to relabel.
        candidates: sequence of objects exposing ``labels`` (a pair of
            labels) and ``ground_truth``.
        maintain_edges: per-candidate flags; truthy means "do not collapse".
            Updated in place.  ``1 - maintain_edges`` is evaluated, so this
            presumably is a NumPy array - confirm with callers.
        probabilities: per-candidate merge probabilities.
        output_filename: binary file receiving the candidate count
            (``struct`` ``'q'``) followed by one ``'?'`` flag per candidate.
        gold_prefix: dataset prefix passed to ``dataIO.ReadGoldData``.
    """
    ncandidates = len(candidates)

    # ground truth decision for every candidate
    labels = np.array([candidate.ground_truth for candidate in candidates],
                      dtype=bool)

    # create an empty union find data structure
    max_value = np.amax(segmentation) + 1
    union_find = [unionfind.UnionFindElement(iv) for iv in range(max_value)]

    # create adjacency sets for the elements in the segment
    adjacency_sets = [set() for _ in range(max_value)]
    for candidate in candidates:
        label_one = candidate.labels[0]
        label_two = candidate.labels[1]

        adjacency_sets[label_one].add(label_two)
        adjacency_sets[label_two].add(label_one)

    # iterate over the candidates in order of decreasing probability
    ranked = sorted(zip(probabilities, range(ncandidates)), reverse=True)

    for _, ie in ranked:
        # skip if the edge is not collapsed
        if maintain_edges[ie]: continue

        label_one, label_two = candidates[ie].labels

        # get the parent of the second label
        label_two_union_find = unionfind.Find(union_find[label_two]).label

        # the merge would close a cycle if any *other* neighbor of label_one
        # is already in label_two's set; every neighbor has to be tested
        for neighbor_label in adjacency_sets[label_one]:
            if neighbor_label == label_two: continue

            if unionfind.Find(
                    union_find[neighbor_label]).label == label_two_union_find:
                maintain_edges[ie] = True
                break

        # skip if the edge is no longer collapsed
        if maintain_edges[ie]: continue
        unionfind.Union(union_find[label_one], union_find[label_two])

    print('\nBorder Constraints\n')
    PrecisionAndRecall(labels, 1 - maintain_edges)

    # for every edge, save if the edge is maintained
    with open(output_filename, 'wb') as fd:
        fd.write(struct.pack('q', ncandidates))
        for ie in range(ncandidates):
            fd.write(struct.pack('?', maintain_edges[ie]))

    # every label is relabeled to the representative of its set
    mapping = np.zeros(max_value, dtype=np.int64)
    for iv in range(max_value):
        mapping[iv] = unionfind.Find(union_find[iv]).label

    segmentation = seg2seg.MapLabels(segmentation, mapping)
    gold = dataIO.ReadGoldData(gold_prefix)
    print(comparestacks.adapted_rand(segmentation,
                                     gold,
                                     all_stats=False,
                                     dilate_ground_truth=2,
                                     filtersize=0))
Пример #6
0
def Forward(prefix,
            model_prefix,
            segmentation,
            width,
            radius,
            subset,
            evaluate=False,
            threshold_volume=10368000):
    """Run the trained node model and merge small segments into large ones.

    The model stored at ``<model_prefix>.json`` / ``<model_prefix>-best-loss.h5``
    predicts a probability for every small/large segment pairing.  Each small
    segment is mapped to its most probable large neighbor when that
    probability is at least 0.5.  The relabeled and label-reduced
    segmentation is written to ``rhoana/``, the end-to-end label mapping to
    ``cache/``, and merge statistics to ``<model_prefix>-<prefix>-inference.txt``.

    Args:
        prefix: dataset prefix used for reading data and naming outputs.
        model_prefix: path prefix of the trained model; its second
            '/'-separated component is used as the model name.
        segmentation: label volume to relabel.
        width, radius, subset: forwarded to ``CollectExamples`` and
            ``CollectLargeSmallPairs``; ``width`` also goes to ``NodeGenerator``.
        evaluate: when true, also compute rand error and variation of
            information against the gold data and write them to
            ``node-results/``.
        threshold_volume: divided by the product of the three resolution
            components to obtain the voxel threshold handed to
            ``FindSmallSegments``.
    """
    # read in the trained model (close the architecture file once parsed)
    with open('{}.json'.format(model_prefix), 'r') as fd:
        model = model_from_json(fd.read())
    model.load_weights('{}-best-loss.h5'.format(model_prefix))

    # get all of the examples
    examples, npositives, nnegatives = CollectExamples(prefix, width, radius,
                                                       subset)

    # get all of the large-small pairings
    pairings = CollectLargeSmallPairs(prefix, width, radius, subset)
    #assert (len(pairings) == examples.shape[0])

    # get the threshold in terms of number of voxels
    resolution = dataIO.Resolution(prefix)
    threshold = int(threshold_volume /
                    (resolution[IB_Z] * resolution[IB_Y] * resolution[IB_X]))

    # get the list of nodes over and under the threshold
    small_segments, large_segments = FindSmallSegments(segmentation, threshold)

    # membership is tested for every pairing, so build the lookup once
    small_segment_lookup = set(small_segments)

    # get all of the probabilities
    probabilities = model.predict_generator(NodeGenerator(examples, width),
                                            examples.shape[0],
                                            max_q_size=1000)

    # save the probabilities to a file
    # NOTE(review): the header stores examples.shape[0] while one record is
    # written per pairing; the assertion tying the two together is disabled
    output_filename = '{}-{}.probabilities'.format(model_prefix, prefix)
    with open(output_filename, 'wb') as fd:
        fd.write(struct.pack('q', examples.shape[0]))
        for ie, (label_one, label_two) in enumerate(pairings):
            fd.write(
                struct.pack('qqd', label_one, label_two, probabilities[ie]))

    # the labeled examples are ordered positives first, then negatives
    ground_truth = np.zeros(npositives + nnegatives, dtype=bool)
    ground_truth[:npositives] = True

    # get the results with labeled data
    predictions = Prob2Pred(np.squeeze(probabilities[:npositives +
                                                     nnegatives]))

    # print the confusion matrix
    output_filename = '{}-{}-inference.txt'.format(model_prefix, prefix)
    PrecisionAndRecall(ground_truth, predictions, output_filename)

    # for every small segment, the (large segment, probability) predictions
    small_segment_predictions = dict()
    for small_segment in small_segments:
        small_segment_predictions[small_segment] = set()

    # go through each pairing
    for (label_one, label_two), probability in zip(pairings, probabilities):
        one_is_small = label_one in small_segment_lookup
        two_is_small = label_two in small_segment_lookup
        # make sure that either label one or two is small and the other is large
        assert (one_is_small ^ two_is_small)

        if one_is_small:
            small_segment, large_segment = label_one, label_two
        else:
            small_segment, large_segment = label_two, label_one

        small_segment_predictions[small_segment].add(
            (large_segment, probability[0]))

    # begin to map the small labels, starting from the identity mapping
    max_label = np.amax(segmentation) + 1
    mapping = list(range(max_label))

    # look at seg2gold to see how many correct segments are merged
    # NOTE(review): called with the prefix only here - confirm the signature
    seg2gold_mapping = seg2gold.Mapping(prefix)

    ncorrect_merges = 0
    nincorrect_merges = 0

    # go through all of the small segments
    for small_segment in small_segments:
        best_probability = -1
        best_large_segment = -1

        # go through all the neighboring large segments
        for large_segment, probability in small_segment_predictions[
                small_segment]:
            if probability > best_probability:
                best_probability = probability
                best_large_segment = large_segment

        # no neighbor or no confident prediction: the segment keeps its own
        # label, which the identity mapping already encodes
        if best_large_segment == -1 or best_probability < 0.5:
            continue

        mapping[small_segment] = best_large_segment

        # don't consider undetermined locations
        if seg2gold_mapping[small_segment] < 1 or seg2gold_mapping[
                best_large_segment] < 1:
            continue

        if seg2gold_mapping[small_segment] == seg2gold_mapping[
                best_large_segment]:
            ncorrect_merges += 1
        else:
            nincorrect_merges += 1

    print('\nResults:')
    print('  Correctly Merged: {}'.format(ncorrect_merges))
    print('  Incorrectly Merged: {}'.format(nincorrect_merges))

    with open(output_filename, 'a') as fd:
        fd.write('\nResults:\n')
        fd.write('  Correctly Merged: {}\n'.format(ncorrect_merges))
        fd.write('  Incorrectly Merged: {}\n'.format(nincorrect_merges))

    # keep a copy of the node mapping before `mapping` is reused below
    end2end_mapping = list(mapping)

    # initiate the mapping to eliminate small segments
    # NOTE(review): return values of MapLabels are discarded, so this relies
    # on in-place relabeling of `segmentation` - confirm against seg2seg
    seg2seg.MapLabels(segmentation, mapping)

    # reduce the labels and map again
    mapping, _ = seg2seg.ReduceLabels(segmentation)
    seg2seg.MapLabels(segmentation, mapping)

    # update the end to end mapping with the reduced labels
    for iv in range(max_label):
        end2end_mapping[iv] = mapping[end2end_mapping[iv]]

    # get the model name (second '/'-separated component of the model prefix)
    model_name = model_prefix.split('/')[1]
    output_filename = 'rhoana/{}-reduced-{}.h5'.format(prefix, model_name)
    dataIO.WriteH5File(segmentation, output_filename, 'main')

    # spawn a new meta file
    dataIO.SpawnMetaFile(prefix, output_filename, 'main')

    # save the end to end mapping in the cache
    mapping_filename = 'cache/{}-reduced-{}-end2end.map'.format(
        prefix, model_name)
    with open(mapping_filename, 'wb') as fd:
        fd.write(struct.pack('q', max_label))
        for label in range(max_label):
            fd.write(struct.pack('q', end2end_mapping[label]))

    if evaluate:
        gold = dataIO.ReadGoldData(prefix)

        # run the evaluation framework
        rand_error, vi = comparestacks.VariationOfInformation(
            segmentation, gold)

        # write the output file
        with open('node-results/{}-reduced-{}.txt'.format(prefix, model_name),
                  'w') as fd:
            fd.write('Rand Error Full: {}\n'.format(rand_error[0] +
                                                    rand_error[1]))
            fd.write('Rand Error Merge: {}\n'.format(rand_error[0]))
            fd.write('Rand Error Split: {}\n'.format(rand_error[1]))

            fd.write('Variation of Information Full: {}\n'.format(vi[0] +
                                                                  vi[1]))
            fd.write('Variation of Information Merge: {}\n'.format(vi[0]))
            fd.write('Variation of Information Split: {}\n'.format(vi[1]))
Пример #7
0
def GenerateFeatures(prefix, threshold, network_distance):
    """Generate split/no-split candidates at skeleton branch joints.

    Every skeleton joint with more than two neighbors, a non-zero gold label
    and a network window that fits in the volume becomes a candidate.  The
    candidate is positive ("should split") when at least one neighbor lies on
    a different non-zero gold label.  Candidates are saved to the training,
    validation and inference files under ``features/nuclear/``.

    Args:
        prefix: dataset prefix used for reading data and naming outputs.
        threshold: forwarded to ``seg2seg.RemoveSmallConnectedComponents``
            and used in the output file names.
        network_distance: window radius, divided by the per-axis resolution
            to obtain the radius in grid coordinates.
    """
    # read in the relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    assert (segmentation.shape == gold.shape)

    # NOTE(review): the mapping is never read below; the call is kept in
    # case seg2gold.Mapping has side effects - confirm and drop if it is pure
    seg2gold_mapping = seg2gold.Mapping(segmentation,
                                        gold,
                                        low_threshold=0.10,
                                        high_threshold=0.80)

    # remove small connected components
    segmentation = seg2seg.RemoveSmallConnectedComponents(
        segmentation, threshold=threshold).astype(np.int64)

    # get the grid size and the world resolution
    grid_size = segmentation.shape
    world_res = dataIO.Resolution(prefix)

    # get the radius in grid coordinates
    network_radii = np.int64((network_distance / world_res[IB_Z],
                              network_distance / world_res[IB_Y],
                              network_distance / world_res[IB_X]))

    # get all of the skeletons
    skeletons, _, _ = dataIO.ReadSkeletons(prefix, segmentation)

    positive_candidates = []
    negative_candidates = []

    # iterate over all joints of all skeletons
    for skeleton in skeletons:
        label = skeleton.label

        for joint in skeleton.joints:
            # get the gold value at this location
            location = joint.GridPoint()
            gold_label = gold[location[IB_Z], location[IB_Y], location[IB_X]]

            # make sure the bounding box fits
            valid_location = all(
                location[dim] - network_radii[dim] >= 0
                and location[dim] + network_radii[dim] <= grid_size[dim]
                for dim in range(NDIMS))
            if not valid_location: continue

            # joints on unlabeled gold cannot be judged
            if not gold_label: continue

            # only branch points (more than two neighbors) are candidates
            neighbors = joint.Neighbors()
            if len(neighbors) <= 2: continue

            # split when a neighbor sits on a different non-zero gold label
            # (gold_label itself is known to be non-zero at this point)
            should_split = False
            for neighbor in neighbors:
                neighbor_location = neighbor.GridPoint()
                neighbor_gold_label = gold[neighbor_location[IB_Z],
                                           neighbor_location[IB_Y],
                                           neighbor_location[IB_X]]

                if neighbor_gold_label and neighbor_gold_label != gold_label:
                    should_split = True

            candidate = NuclearCandidate(label, location, should_split)
            if should_split: positive_candidates.append(candidate)
            else: negative_candidates.append(candidate)

    train_filename = 'features/nuclear/{}-{}-{}nm-training.candidates'.format(
        prefix, threshold, network_distance)
    validation_filename = 'features/nuclear/{}-{}-{}nm-validation.candidates'.format(
        prefix, threshold, network_distance)
    forward_filename = 'features/nuclear/{}-{}-{}nm-inference.candidates'.format(
        prefix, threshold, network_distance)
    SaveCandidates(train_filename,
                   positive_candidates,
                   negative_candidates,
                   inference=False,
                   validation=False)
    SaveCandidates(validation_filename,
                   positive_candidates,
                   negative_candidates,
                   inference=False,
                   validation=True)
    SaveCandidates(forward_filename,
                   positive_candidates,
                   negative_candidates,
                   inference=True)

    print('  Positive Candidates: {}'.format(len(positive_candidates)))
    print('  Negative Candidates: {}'.format(len(negative_candidates)))
    # the ratio is undefined without positives; report that instead of
    # failing after all files have been written
    if positive_candidates:
        print('  Ratio: {}'.format(
            len(negative_candidates) / float(len(positive_candidates))))
    else:
        print('  Ratio: undefined (no positive candidates)')
Пример #8
0
def GenerateFeatures(prefix, threshold, maximum_distance, network_distance,
                     endpoint_distance, topology):
    """Generate merge candidates between nearby skeleton endpoints.

    For every skeleton endpoint the neighboring labels are collected with
    ``FindNeighboringCandidates``.  For every (label, neighbor) pair the
    closest pair of endpoints nearer than ``endpoint_distance + 1`` is kept,
    together with the midpoint between them.  A window around the midpoint is
    mapped to gold labels and the pair is saved as a positive, negative or
    undetermined candidate under ``features/skeleton/``.

    Args:
        prefix: dataset prefix used for reading data and naming outputs.
        threshold: forwarded to ``seg2seg.RemoveSmallConnectedComponents``
            and used in the output file names.
        maximum_distance: forwarded to ``FindNeighboringCandidates``.
        network_distance: window radius, divided by the per-axis resolution
            to obtain the radius in grid coordinates.
        endpoint_distance: endpoint pairs must be closer than this plus one.
        topology: when truthy skeletons come from
            ``dataIO.ReadTopologySkeletons``, otherwise from
            ``dataIO.ReadSWCSkeletons``.
    """
    # read in the relevant information
    segmentation = dataIO.ReadSegmentationData(prefix)
    gold = dataIO.ReadGoldData(prefix)
    assert (segmentation.shape == gold.shape)
    zres, yres, xres = segmentation.shape

    # remove small connected components
    thresholded_segmentation = seg2seg.RemoveSmallConnectedComponents(
        segmentation, threshold=threshold).astype(np.int64)
    max_label = np.amax(segmentation) + 1

    # get the world resolution
    world_res = dataIO.Resolution(prefix)

    # get the network radius in grid coordinates
    network_radii = np.int64((network_distance / world_res[IB_Z],
                              network_distance / world_res[IB_Y],
                              network_distance / world_res[IB_X]))

    # get all of the skeletons
    if topology:
        skeletons, endpoints = dataIO.ReadTopologySkeletons(
            prefix, thresholded_segmentation)
    else:
        skeletons, _, endpoints = dataIO.ReadSWCSkeletons(
            prefix, thresholded_segmentation)

    # get the set of all considered labels around every endpoint
    endpoint_candidates = [set() for _ in range(len(endpoints))]
    for ie, endpoint in enumerate(endpoints):
        centroid = endpoint.GridPoint()

        # find the candidates near this endpoint
        candidates = set()
        candidates.add(0)
        FindNeighboringCandidates(thresholded_segmentation, centroid,
                                  candidates, maximum_distance, world_res)

        for candidate in candidates:
            # skip extracellular
            if not candidate: continue
            endpoint_candidates[ie].add(candidate)

    # get a mapping from the labels to indices in skeletons (-1: no skeleton)
    label_to_index = [-1 for _ in range(max_label)]
    for ie, skeleton in enumerate(skeletons):
        label_to_index[skeleton.label] = ie

    # closest endpoint pair, its distance and its midpoint for every ordered
    # pair of labels; all three dictionaries are kept symmetric
    endpoint_pairs = {}
    smallest_distances = {}
    midpoints = {}

    # pairs start out just beyond the allowed endpoint distance
    distance_limit = endpoint_distance + 1

    for ie, endpoint in enumerate(endpoints):
        label = endpoint.label

        # go through all currently considered labels
        for neighbor_label in endpoint_candidates[ie]:
            skeleton_index = label_to_index[neighbor_label]
            # a label without a skeleton has no endpoints to pair with;
            # indexing with -1 would silently use the last skeleton
            if skeleton_index < 0: continue

            for neighbor_endpoint in skeletons[skeleton_index].endpoints:
                # get the distance
                deltas = endpoint.WorldPoint(
                    world_res) - neighbor_endpoint.WorldPoint(world_res)
                distance = math.sqrt(deltas[IB_Z] * deltas[IB_Z] +
                                     deltas[IB_Y] * deltas[IB_Y] +
                                     deltas[IB_X] * deltas[IB_X])

                current = smallest_distances.get((label, neighbor_label),
                                                 distance_limit)
                if distance < current:
                    midpoint = (endpoint.GridPoint() +
                                neighbor_endpoint.GridPoint()) / 2

                    # remember the closest pair of endpoints
                    smallest_distances[(label, neighbor_label)] = distance
                    smallest_distances[(neighbor_label, label)] = distance

                    endpoint_pairs[(label,
                                    neighbor_label)] = (endpoint,
                                                        neighbor_endpoint)
                    endpoint_pairs[(neighbor_label,
                                    label)] = (neighbor_endpoint, endpoint)

                    midpoints[(label, neighbor_label)] = midpoint
                    midpoints[(neighbor_label, label)] = midpoint

    # create list of candidates
    positive_candidates = []
    negative_candidates = []
    undetermined_candidates = []

    for ie, match in enumerate(endpoint_pairs):
        print('{}/{}'.format(ie, len(endpoint_pairs)))
        endpoint_one, endpoint_two = endpoint_pairs[match][0], endpoint_pairs[
            match][1]

        label_one = endpoint_one.label
        label_two = endpoint_two.label

        # every pair is stored in both orders; handle it only once
        if label_two > label_one: continue

        # extract a bounding box around this midpoint
        midz, midy, midx = midpoints[(label_one, label_two)]

        zmin = max(0, midz - network_radii[IB_Z])
        ymin = max(0, midy - network_radii[IB_Y])
        xmin = max(0, midx - network_radii[IB_X])
        # NOTE(review): these are exclusive slice ends, so clamping to
        # `res - 1` leaves out the last plane of the volume - confirm intent
        zmax = min(zres - 1, midz + network_radii[IB_Z] + 1)
        ymax = min(yres - 1, midy + network_radii[IB_Y] + 1)
        xmax = min(xres - 1, midx + network_radii[IB_X] + 1)

        extracted_segmentation = segmentation[zmin:zmax, ymin:ymax, xmin:xmax]
        extracted_gold = gold[zmin:zmax, ymin:ymax, xmin:xmax]

        extracted_seg2gold_mapping = seg2gold.Mapping(extracted_segmentation,
                                                      extracted_gold,
                                                      match_threshold=0.70,
                                                      nonzero_threshold=0.40)

        # valid indices are 0 .. size - 1, so a label equal to size is out
        # of range as well
        if label_one >= extracted_seg2gold_mapping.size: continue
        if label_two >= extracted_seg2gold_mapping.size: continue

        gold_one = extracted_seg2gold_mapping[label_one]
        gold_two = extracted_seg2gold_mapping[label_two]

        ground_truth = (gold_one == gold_two)

        candidate = SkeletonCandidate((label_one, label_two),
                                      midpoints[(label_one, label_two)],
                                      ground_truth)

        # a pair is undetermined when either label maps to gold label 0
        if not gold_one or not gold_two:
            undetermined_candidates.append(candidate)
        elif ground_truth:
            positive_candidates.append(candidate)
        else:
            negative_candidates.append(candidate)

    # save positive, negative and undetermined candidates separately
    positive_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-positive.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance,
        network_distance)
    negative_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-negative.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance,
        network_distance)
    undetermined_filename = 'features/skeleton/{}-{}-{}nm-{}nm-{}nm-undetermined.candidates'.format(
        prefix, threshold, maximum_distance, endpoint_distance,
        network_distance)

    SaveCandidates(positive_filename, positive_candidates)
    SaveCandidates(negative_filename, negative_candidates)
    SaveCandidates(undetermined_filename, undetermined_candidates)

    print('Positive candidates: {}'.format(len(positive_candidates)))
    print('Negative candidates: {}'.format(len(negative_candidates)))
    print('Undetermined candidates: {}'.format(len(undetermined_candidates)))
Пример #9
0
def MergeGroundTruth(prefix, model_prefix):
    """Merge all multicut graph edges that agree with the gold data and score the result.

    The edges stored in ``multicut/<model_prefix>-<prefix>.graph`` are read;
    two labels are merged when both map to the same non-zero gold label.  The
    merged segmentation is written to ``multicut/<prefix>-truth.h5`` and
    compared against ``gold/<prefix>_gold.h5`` by the external
    ``comparestacks`` program, whose running time is printed.

    Args:
        prefix: dataset prefix used for reading data and naming files.
        model_prefix: model prefix used in the graph file name.
    """
    import time

    # read the segmentation data
    segmentation = dataIO.ReadSegmentationData(prefix)

    # get the multicut filename (with graph weights)
    multicut_filename = 'multicut/{}-{}.graph'.format(model_prefix, prefix)

    # read the gold data
    gold = dataIO.ReadGoldData(prefix)

    # read in the segmentation to gold mapping
    seg2gold_mapping = seg2gold.Mapping(segmentation, gold)

    # number of label values; the + 1 makes the largest label addressable
    max_value = np.amax(segmentation) + 1

    # create union find data structure
    union_find = [UnionFind.UnionFindElement(iv) for iv in range(max_value)]

    # read in all of the edges
    with open(multicut_filename, 'rb') as fd:
        # read the number of vertices and edges
        nvertices, nedges, = struct.unpack('QQ', fd.read(16))

        for ie in range(nedges):
            # read in the two labels
            label_one, label_two, = struct.unpack('QQ', fd.read(16))

            # skip the remaining 24 bytes of the record (the reduced labels
            # and the edge weight, per the file's original comment)
            fd.read(24)

            # merge when both labels map to the same non-zero gold label
            if seg2gold_mapping[label_one] == seg2gold_mapping[
                    label_two] and seg2gold_mapping[label_one]:
                UnionFind.Union(union_find[label_one], union_find[label_two])

    # every label is relabeled to the representative of its set
    label_mapping = np.zeros(max_value, dtype=np.int64)
    for iv in range(max_value):
        label_mapping[iv] = UnionFind.Find(union_find[iv]).label

    merged_segmentation = seg2seg.MapLabels(segmentation, label_mapping)

    gold_filename = 'gold/{}_gold.h5'.format(prefix)

    # TODO fix this code temporary filename
    truth_filename = 'multicut/{}-truth.h5'.format(prefix)

    # temporary write h5 file
    dataIO.WriteH5File(merged_segmentation, truth_filename, 'stack')

    start_time = time.time()
    print('Ground truth: ')
    # create the command line
    command = '~/software/PixelPred2Seg/comparestacks --stack1 {} --stackbase {} --dilate1 1 --dilatebase 1 --relabel1 --relabelbase --filtersize 100 --anisotropic'.format(
        truth_filename, gold_filename)

    # execute the command
    os.system(command)
    print(time.time() - start_time)