def double_propagate(img, label, FIR_THRES, SEC_THRES, OPEN):
    """Grow seed labels in two successive propagation passes.

    Both pass masks are computed from the ORIGINAL intensities, before the
    background subtraction below. Pixels above ``FIR_THRES`` bound the first
    pass. Pixels above ``SEC_THRES`` bound the second pass, which starts from
    the labels produced by the first pass.

    Parameters
    ----------
    img : ndarray
        Intensity image.
    label : ndarray
        Seed label image handed to ``propagate`` for the first pass.
    FIR_THRES, SEC_THRES : scalar
        Intensity thresholds defining the first and second pass masks.
    OPEN : int
        Side length of the square footprint passed to ``opening``; the
        opened image is subtracted from ``img`` before propagating.

    Returns
    -------
    ndarray
        Labels after the second pass, converted with ``np.uint16``.
    """
    first_mask = img > FIR_THRES
    second_mask = img > SEC_THRES

    # Remove the opened image (a white top-hat, assuming ``opening`` is a
    # grayscale morphological opening) so propagation runs on local structure.
    footprint = np.ones((OPEN, OPEN))
    tophat = img - opening(img, footprint)

    # Each pass re-seeds from the labels produced by the previous one.
    seeds = label
    for pass_mask in (first_mask, second_mask):
        seeds, _ = propagate(tophat, seeds, pass_mask, 1)
    return np.uint16(seeds)
    def segment(self, img, include_intermediate_results=False, **kwargs):
        """Segment objects in a 3D stack using DoG blob seeds and propagation.

        The stack is median filtered with size (1, 3, 3), converted to float,
        inverted and max-projected over axis 0. Blob centres found by
        ``blob_dog`` on the projection become labelled seeds for
        ``propagate.propagate``. Propagation is restricted to the mask
        returned by ``self.get_primary_object_mask``.

        Args:
            img: 3D image (``img.ndim == 3`` is asserted); axis 0 is the axis
                that gets max-projected.
            include_intermediate_results: when True, extra images are appended
                to the output channel list (the ``img_seg += [...]`` branch).
            **kwargs: not used in the visible body.

        Returns:
            A 4D array (asserted). Its leading axis has length 1 (added with
            ``np.newaxis``) and is followed by the stacked channels
            ``[objects, objects, boundaries, boundaries]`` plus any
            intermediates. Its dtype is asserted to be uint16.
        """
        assert img.ndim == 3, 'Expecting 3D image, got shape {}'.format(
        img = ndi.median_filter(img, size=(1, 3, 3))
        img = img_as_float(img)
        # NOTE(review): inversion presumably turns dark objects into bright
        # blobs for blob_dog below — confirm against the imaging modality.
        img = util.invert(img)

        # Max-project over axis 0 and rescale to [0, 1] before blob detection
        img_mz = img.max(axis=0)
        img_mz = exposure.rescale_intensity(img_mz, out_range=(0, 1))
        peaks, img_dog, sigmas = blob_dog(img_mz,

        # Rasterise the blob centres (first two columns of peaks, used as
        # row/column indices) and give each centre its own seed label
        img_pk = np.zeros(img_mz.shape, dtype=bool)
        img_pk[peaks[:, 0].astype(int), peaks[:, 1].astype(int)] = True
        img_pk = morphology.label(img_pk)

        # Get mask to conduct segmentation over, restricted to the region
        # around the seeds (seeds dilated by morphology.disk(32))
        img_pm = self.get_primary_object_mask(
            img, morphology.binary_dilation(img_pk > 0, morphology.disk(32)))

        # NOTE(review): img_dt is not used anywhere in the visible body
        img_dt = ndi.distance_transform_edt(img_pm)

        # Use propagation rather than watershed as it often captures a much more accurate boundary
        img_obj = propagate.propagate(img_mz, img_pk, img_pm,
        # Keep each object's label only on its inner boundary pixels
        img_bnd = img_obj * segmentation.find_boundaries(
            img_obj, mode='inner', background=0)

        img_seg = [img_obj, img_obj, img_bnd, img_bnd]
        if include_intermediate_results:
            to_uint16 = lambda im: exposure.rescale_intensity(
                im, out_range='uint16').astype(np.uint16)
            img_seg += [

        # Stack channels and add a new leading axis to give (z, ch, h, w) with z = 1
        img_seg = np.stack(img_seg)[np.newaxis]
        assert img_seg.dtype == np.uint16, 'Expecting 16bit result, got type {}'.format(
        assert img_seg.ndim == 4, 'Expecting 4D result, got shape {}'.format(
        return img_seg
# Propagate seeds over ``img`` and then clear unwanted regions from the result.
#
# Inputs used by the visible body (the signature continues past ``img,``):
#   img         - image passed to ``propagate``
#   all_label   - label image of candidate regions (copied, not modified)
#   nuclabel    - label image whose labels become the seeds (presumably
#                 nuclei — confirm with callers)
#   mask        - mask passed to ``propagate``; the name is re-bound below
#   prop_param  - regularization value passed to ``propagate``
#   nuc_overlap - value passed to ``calc_mask_exclude_overlap`` with all_label
# Returns the label image produced by ``propagate`` after clean-up.
def propagate_and_cleanup(img,
    # Seed image: every all_label pixel first gets the placeholder label
    # nuclabel.max() + 1; pixels with nuclabel > 0 are then replaced by
    # their nuclabel value, so the seeds and the placeholder region compete
    # during propagation.
    label = all_label.copy()
    label[all_label > 0] = nuclabel.max() + 1
    label[nuclabel > 0] = 0
    label += nuclabel

    labels_out, distances = propagate(img, label, mask, prop_param)
    # Zero the pixels selected by calc_mask_exclude_overlap on the inputs
    mask = calc_mask_exclude_overlap(all_label, nuc_overlap)
    labels_out[mask] = 0
    # NOTE(review): the placeholder label assigned above is
    # nuclabel.max() + 1, but this line clears nuclabel.max(), i.e. the
    # highest-numbered seed. Looks like an off-by-one — confirm intent.
    labels_out[labels_out == nuclabel.max()] = 0
    mask = calc_mask_exclude_overlap(labels_out, 1)
    labels_out[mask] = 0
    return labels_out
 def get_skeleton_points(self, obj):
     '''Get points by skeletonizing the objects and decimating.

     Each label plane from ``obj.get_labels()`` is split into colour classes
     with ``morph.color_labels``. Every class is skeletonized on its own, and
     the results are OR-ed into a single skeleton mask.

     Returns a pair ``(i, j)`` of row / column index arrays of skeleton
     pixels:
     - when there are no skeleton pixels, two empty int32 arrays;
     - when there are more than ``self.max_points.value`` pixels, the pixels
       are ordered by (propagation label, propagation distance, i, j) and a
       subset selected with ``np.linspace`` along that ordering is returned.
     '''
     # NOTE(review): ii and jj are never used below
     ii = []
     jj = []
     total_skel = np.zeros(obj.shape, bool)
     for labels, indexes in obj.get_labels():
         # Presumably color_labels gives touching objects different colours,
         # so skeletonizing one colour at a time keeps their skeletons apart
         colors = morph.color_labels(labels)
         for color in range(1, np.max(colors) + 1):
             labels_mask = colors == color
             skel = morph.skeletonize(
                 ordering=distance_transform_edt(labels_mask) *
             total_skel = total_skel | skel
     n_pts = np.sum(total_skel)
     if n_pts == 0:
         return np.zeros(0, np.int32), np.zeros(0, np.int32)
     i, j = np.where(total_skel)
     if n_pts > self.max_points.value:
         # Decimate the skeleton by finding the branchpoints in the
         # skeleton and propagating from those.
         markers = np.zeros(total_skel.shape, np.int32)
         branchpoints = \
             morph.branchpoints(total_skel) | morph.endpoints(total_skel)
         markers[branchpoints] = np.arange(np.sum(branchpoints)) + 1
         # We compute the propagation distance to that point, then impose
         # a slightly arbitrary order to get an unambiguous ordering
         # which should number the pixels in a skeleton branch monotonically
         ts_labels, distances = propagate(np.zeros(markers.shape), markers,
                                          total_skel, 1)
         # np.lexsort uses the LAST key as the primary key: sort by label,
         # then distance, then row, then column
         order = np.lexsort((j, i, distances[i, j], ts_labels[i, j]))
         # Get a linear space of self.max_points elements with bounds at
         # 0 and len(order)-1 and use that to select the points.
         order = order[np.linspace(0,
                                   len(order) - 1,
         return i[order], j[order]
     return i, j
    def do_measurements(self, workspace, image_name, object_name,
                        center_object_name, center_choice,
                        bin_count_settings, dd):
        '''Perform the radial measurements on the image set

        workspace - workspace that holds images / objects
        image_name - make measurements on this image
        object_name - make measurements on these objects
        center_object_name - use the centers of these related objects as
                      the centers for radial measurements. None to use the
                      objects themselves.
        center_choice - the user's center choice for this object (compared
                      against C_CENTERS_OF_OTHER and C_EDGES_OF_OTHER)
        bin_count_settings - the bin count settings group; bin_count,
                      wants_scaled and maximum_radius are read from it
        dd - a dictionary for saving reusable partial results. Distance
                      maps are cached under the object name (joined with
                      center_object_name when one is given). Per-heatmap
                      entries are stored under id(heatmap).

        returns one statistics tuple per ring (bin), or a single
        "no objects" tuple when there are no objects.
        assert isinstance(workspace, cpw.Workspace)
        assert isinstance(workspace.object_set, cpo.ObjectSet)
        bin_count = bin_count_settings.bin_count.value
        wants_scaled = bin_count_settings.wants_scaled.value
        maximum_radius = bin_count_settings.maximum_radius.value

        image = workspace.image_set.get_image(image_name,
        objects = workspace.object_set.get_objects(object_name)
        labels, pixel_data = cpo.crop_labels_and_image(objects.segmented,
        nobjects = np.max(objects.segmented)
        measurements = workspace.measurements
        assert isinstance(measurements, cpmeas.Measurements)
        # Collect the heatmaps configured for this image, object and bin count
        heatmaps = {}
        for heatmap in self.heatmaps:
            if heatmap.object_name.get_objects_name() == object_name and \
                            image_name == heatmap.image_name.get_image_name() and \
                            heatmap.get_number_of_bins() == bin_count:
                dd[id(heatmap)] = \
                    heatmaps[MEASUREMENT_ALIASES[heatmap.measurement.value]] = \
        if nobjects == 0:
            # No objects: record empty measurements for every feature and bin
            for bin in range(1, bin_count + 1):
                for feature in (F_FRAC_AT_D, F_MEAN_FRAC, F_RADIAL_CV):
                    feature_name = (
                        (feature + FF_GENERIC) % (image_name, bin, bin_count))
                            object_name, "_".join([M_CATEGORY, feature_name]),
                    if not wants_scaled:
                        measurement_name = "_".join([M_CATEGORY, feature,
                                                     image_name, FF_OVERFLOW])
                                object_name, measurement_name, np.zeros(0))
            return [(image_name, object_name, "no objects", "-", "-", "-", "-")]
        # Cache key for the distance maps in dd
        name = (object_name if center_object_name is None
                else "%s_%s" % (object_name, center_object_name))
        # NOTE(review): dict.has_key exists only in Python 2; on Python 3
        # this needs to be `name in dd`.
        if dd.has_key(name):
            normalized_distance, i_center, j_center, good_mask = dd[name]
            d_to_edge = distance_to_edge(labels)
            if center_object_name is not None:
                # Use the center of the centering objects to assign a center
                # to each labeled pixel using propagation
                center_objects = workspace.object_set.get_objects(center_object_name)
                center_labels, cmask = cpo.size_similarly(
                        labels, center_objects.segmented)
                pixel_counts = fix(scind.sum(
                        np.arange(1, np.max(center_labels) + 1, dtype=np.int32)))
                # Only centering objects that have pixels are usable centers
                good = pixel_counts > 0
                # Centers of the centering objects, rounded to integer pixels
                i, j = (centers_of_labels(center_labels) + .5).astype(int)
                ig = i[good]
                jg = j[good]
                lg = np.arange(1, len(i) + 1)[good]
                if center_choice == C_CENTERS_OF_OTHER:
                    # Reduce the propagation labels to the centers of
                    # the centering objects
                    center_labels = np.zeros(center_labels.shape, int)
                    center_labels[ig, jg] = lg
                cl, d_from_center = propagate(np.zeros(center_labels.shape),
                                              labels != 0, 1)
                # Erase the centers that fall outside of labels
                cl[labels == 0] = 0
                # If objects are hollow or crescent-shaped, there may be
                # objects without center labels. As a backup, find the
                # center that is the closest to the center of mass.
                missing_mask = (labels != 0) & (cl == 0)
                missing_labels = np.unique(labels[missing_mask])
                if len(missing_labels):
                    all_centers = centers_of_labels(labels)
                    missing_i_centers, missing_j_centers = \
                        all_centers[:, missing_labels - 1]
                    # Squared distance from each missing object's center
                    # (rows) to each usable center (columns)
                    di = missing_i_centers[:, np.newaxis] - ig[np.newaxis, :]
                    dj = missing_j_centers[:, np.newaxis] - jg[np.newaxis, :]
                    # NOTE(review): wrapping the distances in a 1-tuple adds a
                    # leading axis, so `[:, 0]` indexes the missing-object axis
                    # instead of taking the nearest center per row. Verify
                    # against np.argsort(di * di + dj * dj)[:, 0].
                    missing_best = lg[np.argsort((di * di + dj * dj,))[:, 0]]
                    best = np.zeros(np.max(labels) + 1, int)
                    best[missing_labels] = missing_best
                    cl[missing_mask] = best[labels[missing_mask]]
                    # Now compute the crow-flies distance to the centers
                    # of these pixels from whatever center was assigned to
                    # the object.
                    iii, jjj = np.mgrid[0:labels.shape[0], 0:labels.shape[1]]
                    di = iii[missing_mask] - i[cl[missing_mask] - 1]
                    dj = jjj[missing_mask] - j[cl[missing_mask] - 1]
                    d_from_center[missing_mask] = np.sqrt(di * di + dj * dj)
                # Find the point in each object farthest away from the edge.
                # This does better than the centroid:
                # * The center is within the object
                # * The center tends to be an interesting point, like the
                #   center of the nucleus or the center of one or the other
                #   of two touching cells.
                i, j = maximum_position_of_labels(d_to_edge, labels, objects.indices)
                center_labels = np.zeros(labels.shape, int)
                center_labels[i, j] = labels[i, j]
                # Use the coloring trick here to process touching objects
                # in separate operations
                colors = color_labels(labels)
                ncolors = np.max(colors)
                d_from_center = np.zeros(labels.shape)
                cl = np.zeros(labels.shape, int)
                for color in range(1, ncolors + 1):
                    mask = colors == color
                    l, d = propagate(np.zeros(center_labels.shape),
                                     mask, 1)
                    d_from_center[mask] = d[mask]
                    cl[mask] = l[mask]
            # Only pixels that received a center label are measured
            good_mask = cl > 0
            if center_choice == C_EDGES_OF_OTHER:
                # Exclude pixels within the centering objects
                # when performing calculations from the centers
                good_mask = good_mask & (center_labels == 0)
            # Coordinates of each pixel's assigned center (label n -> index n - 1)
            i_center = np.zeros(cl.shape)
            i_center[good_mask] = i[cl[good_mask] - 1]
            j_center = np.zeros(cl.shape)
            j_center[good_mask] = j[cl[good_mask] - 1]

            normalized_distance = np.zeros(labels.shape)
            if wants_scaled:
                # Scaled distance: d_from_center / (d_from_center + d_to_edge);
                # the .001 guards against division by zero
                total_distance = d_from_center + d_to_edge
                normalized_distance[good_mask] = (d_from_center[good_mask] /
                                                  (total_distance[good_mask] + .001))
                normalized_distance[good_mask] = \
                    d_from_center[good_mask] / maximum_radius
            dd[name] = [normalized_distance, i_center, j_center, good_mask]
        ngood_pixels = np.sum(good_mask)
        good_labels = labels[good_mask]
        # Bin index per pixel; values above bin_count are clamped into the
        # overflow bin, index bin_count
        bin_indexes = (normalized_distance * bin_count).astype(int)
        bin_indexes[bin_indexes > bin_count] = bin_count
        labels_and_bins = (good_labels - 1, bin_indexes[good_mask])
        # Intensity summed per (object, bin); coo_matrix sums duplicate
        # entries when densified. fraction_at_distance is each bin's share
        # of the object's total intensity.
        histogram = coo_matrix((pixel_data[good_mask], labels_and_bins),
                               (nobjects, bin_count + 1)).toarray()
        sum_by_object = np.sum(histogram, 1)
        sum_by_object_per_bin = np.dstack([sum_by_object] * (bin_count + 1))[0]
        fraction_at_distance = histogram / sum_by_object_per_bin
        # Pixel count per (object, bin) and each bin's share of the object's pixels
        number_at_distance = coo_matrix((np.ones(ngood_pixels), labels_and_bins),
                                        (nobjects, bin_count + 1)).toarray()
        object_mask = number_at_distance > 0
        sum_by_object = np.sum(number_at_distance, 1)
        sum_by_object_per_bin = np.dstack([sum_by_object] * (bin_count + 1))[0]
        fraction_at_bin = number_at_distance / sum_by_object_per_bin
        # Intensity fraction relative to pixel fraction for each bin
        mean_pixel_fraction = fraction_at_distance / (fraction_at_bin +
        masked_fraction_at_distance = masked_array(fraction_at_distance,
        masked_mean_pixel_fraction = masked_array(mean_pixel_fraction,
        # Anisotropy calculation.  Split each cell into eight wedges, then
        # compute coefficient of variation of the wedges' mean intensities
        # in each ring.
        # Compute each pixel's delta from the center object's centroid
        i, j = np.mgrid[0:labels.shape[0], 0:labels.shape[1]]
        imask = i[good_mask] > i_center[good_mask]
        jmask = j[good_mask] > j_center[good_mask]
        absmask = (abs(i[good_mask] - i_center[good_mask]) >
                   abs(j[good_mask] - j_center[good_mask]))
        # Combine the three booleans into a wedge index in 0..7
        radial_index = (imask.astype(int) + jmask.astype(int) * 2 +
                        absmask.astype(int) * 4)
        statistics = []

        # When unscaled, the loop also covers the overflow bin (index bin_count)
        for bin in range(bin_count + (0 if wants_scaled else 1)):
            bin_mask = (good_mask & (bin_indexes == bin))
            bin_pixels = np.sum(bin_mask)
            bin_labels = labels[bin_mask]
            bin_radial_index = radial_index[bin_indexes[good_mask] == bin]
            labels_and_radii = (bin_labels - 1, bin_radial_index)
            radial_values = coo_matrix((pixel_data[bin_mask],
                                       (nobjects, 8)).toarray()
            pixel_count = coo_matrix((np.ones(bin_pixels), labels_and_radii),
                                     (nobjects, 8)).toarray()
            # Mean intensity per wedge, masking wedges without pixels; the CV
            # is std / mean across wedges and 0 for objects absent from the bin
            mask = pixel_count == 0
            radial_means = masked_array(radial_values / pixel_count, mask)
            radial_cv = np.std(radial_means, 1) / np.mean(radial_means, 1)
            radial_cv[np.sum(~mask, 1) == 0] = 0
            for measurement, feature, overflow_feature in (
                    (fraction_at_distance[:, bin], MF_FRAC_AT_D, OF_FRAC_AT_D),
                    (mean_pixel_fraction[:, bin], MF_MEAN_FRAC, OF_MEAN_FRAC),
                    (np.array(radial_cv), MF_RADIAL_CV, OF_RADIAL_CV)):

                if bin == bin_count:
                    measurement_name = overflow_feature % image_name
                    measurement_name = feature % (image_name, bin + 1, bin_count)
                # Paint each pixel of this bin with its object's value
                if feature in heatmaps:
                    heatmaps[feature][bin_mask] = measurement[bin_labels - 1]
            radial_cv.mask = np.sum(~mask, 1) == 0
            bin_name = str(bin + 1) if bin < bin_count else "Overflow"
            statistics += [(image_name, object_name, bin_name, str(bin_count),
                            round(np.mean(masked_fraction_at_distance[:, bin]), 4),
                            round(np.mean(masked_mean_pixel_fraction[:, bin]), 4),
                            round(np.mean(radial_cv), 4))]
        return statistics
    def segment(self, img_nuc, img_memb=None,
                marker_dilation=1, marker_min_size=16,
                memb_min_dist=5, memb_max_dist=10, memb_hole_size=16,
                memb_sigma=1, memb_gamma=None, memb_tresh_method='li', memb_propagation_regularization=.05,
                batch_size=DEFAULT_BATCH_SIZE, return_masks=False):
        """Segment cells and nuclei from a nucleus image and optional membrane image.

        Args:
            img_nuc: nucleus image, HW or ZHW; converted with ``_to_uint8``.
            img_memb: optional membrane image, also converted with
                ``_to_uint8``; slice ``img_memb[i]`` is used for slice ``i``.
            marker_dilation: radius of the ``morphology.disk`` used to dilate
                the interior markers into nucleus masks (0 disables it).
            marker_min_size: markers smaller than this are removed with
                ``morphology.remove_small_objects`` (0 disables it).
            memb_min_dist, memb_max_dist, memb_hole_size, memb_sigma,
            memb_gamma, memb_tresh_method: forwarded to
                ``self.get_segmentation_mask`` as min_dist, max_dist,
                hole_size, sigma, gamma and method.
            memb_propagation_regularization: regularization passed to
                ``propagate.propagate``; when None, or when no membrane image
                is given, the watershed path is taken.
            batch_size: passed to ``self.predict``.
            return_masks: when True, the per-slice masks are also returned.

        Returns:
            Tuple ``(img_seg, img_pred, img_bin)``:
              img_seg  - (z, 2, h, w) labels; channel 0 = cells, 1 = nuclei
              img_pred - output of ``self.predict`` (NHWC with 3 channels:
                         background, interior, border)
              img_bin  - (z, 3, h, w) stack of [nucleus, marker, segmentation]
                         masks, or None when ``return_masks`` is False

        Raises:
            ValueError: if ``memb_min_dist >= memb_max_dist``. A falsy min is
                treated as 0 and a falsy max as unbounded. Also raised if
                ``img_nuc`` is not 2D or 3D.
        """
        if not self.initialized:

        # NOTE(review): this also raises when the two distances are equal,
        # although the message below says "<=".
        if (memb_min_dist or 0) >= (memb_max_dist or np.inf):
            raise ValueError(
                'Membrane min distance argument (memb_min_dist = {}) used to set minimum cell boundary '
                'must be <= maximum cell boundary distance from nucleus (memb_max_dist = {})'
                .format(memb_min_dist, memb_max_dist)

        # Convert images to segment or otherwise analyze to 8-bit
        img_nuc = _to_uint8(img_nuc, 'nucleus')
        if img_memb is not None:
            img_memb = _to_uint8(img_memb, 'membrane')

        # Add z dimension (equivalent to batch dim in this case) if not present
        if img_nuc.ndim == 2:
            img_nuc = np.expand_dims(img_nuc, 0)
        if img_nuc.ndim != 3:
            raise ValueError('Must provide image as ZHW or HW (image shape given = {})'.format(img_nuc.shape))

        # Make predictions on image converted to 0-1 and with trailing channel dimension to give NHWC;
        # Result has shape NHWC where C=3 and C1 = bg, C2 = interior, C3 = border
        img_pred = self.predict(np.expand_dims(img_nuc / 255., -1), batch_size)
        assert img_pred.shape[-1] == 3, \
            'Expecting 3 outputs in predictions (shape = {})'.format(img_pred.shape)

        img_seg_list, img_bin_list = [], []
        nz = img_nuc.shape[0]
        for i in range(nz):

            # Use nuclei interior mask as watershed markers
            img_bin_nucm = np.argmax(img_pred[i], axis=-1) == 1

            # Remove markers (which determine number of cells) below the given size
            if marker_min_size > 0:
                img_bin_nucm = morphology.remove_small_objects(img_bin_nucm, min_size=marker_min_size)

            # Define the entire nucleus as a slight dilation of the markers noting that this
            # works better than using the union of predicted interiors and predicted boundaries
            # (which are too thick)
            img_bin_nuci = img_bin_nucm
            if marker_dilation > 0:
                # NOTE(review): the np.bool alias was removed in NumPy 1.24;
                # plain `bool` is the drop-in replacement.
                img_bin_nuci = cv2.dilate(
                    img_bin_nucm.astype(np.uint8), morphology.disk(marker_dilation)).astype(np.bool)

            # Determine the overall mask to segment across by dilating nuclei by some fixed amount
            # or if possible, using the given cell membrane image
            img_bin_mask = self.get_segmentation_mask(
                img_bin_nuci, img_memb=img_memb[i] if img_memb is not None else None,
                min_dist=memb_min_dist, max_dist=memb_max_dist, hole_size=memb_hole_size,
                method=memb_tresh_method, sigma=memb_sigma, gamma=memb_gamma)

            # Label the nuclei markers (which determines number of cells to output)
            # *Note: It is important to keep this separate from nuclei interior as single or double pixel
            # gaps between nuclei are common when densely packed
            img_bin_nucm_label = morphology.label(img_bin_nucm)

            # Create labeled cell image
            if img_memb is None or memb_propagation_regularization is None:
                # Run watershed using markers and expanded nuclei / cell mask
                img_basin = -1 * ndimage.distance_transform_edt(img_bin_nucm)
                img_cell_seg = segmentation.watershed(img_basin, img_bin_nucm_label, mask=img_bin_mask)
                # Before running propagation segmentation, make sure that the input image is 0-1 float
                # as the regularization threshold is calibrated to work only with data in that range
                img_cell_seg, _ = propagate.propagate(
                    img_as_float(img_memb[i]), img_bin_nucm_label,
                    img_bin_mask, memb_propagation_regularization

            # Generate nucleus segmentation based on cell segmentation and nucleus mask
            # and relabel nuclei objections using corresponding cell labels
            img_nuc_seg = (img_cell_seg > 0) & img_bin_nuci
            img_nuc_seg = img_nuc_seg * img_cell_seg

            # Add labeled images to results
            assert img_cell_seg.dtype == img_nuc_seg.dtype, \
                'Cell segmentation dtype {} != nucleus segmentation dtype {}'\
                .format(img_cell_seg.dtype, img_nuc_seg.dtype)
            img_seg_list.append(np.stack([img_cell_seg, img_nuc_seg], axis=0))

            # Add mask images to results, if requested
            if return_masks:
                img_bin_list.append(np.stack([img_bin_nuci, img_bin_nucm, img_bin_mask], axis=0))

        assert nz == len(img_seg_list)
        if return_masks:
            assert nz == len(img_bin_list)

        # Stack final segmentation image as (z, c, h, w)
        img_seg = np.stack(img_seg_list, axis=0)
        img_bin = np.stack(img_bin_list, axis=0) if return_masks else None
        assert img_seg.ndim == 4, 'Expecting 4D segmentation image but shape is {}'.format(img_seg.shape)

        # Return (in this order) labeled volumes, prediction volumes, mask volumes
        return img_seg, img_pred, img_bin
    def run(self, workspace):
        '''Run the module on the image set

        For every seed object, measure the binary skeleton image named by
        self.image_name and record four values:
        - the number of trunks;
        - the number of non-trunk branches;
        - the number of branch ends;
        - the total neurite length, i.e. the skeleton length outside the
          dilated seeds.

        Optionally builds the neuron graph (self.make_neuron_graph). When a
        window is shown or a branchpoint image is wanted, it also builds a
        3-channel branchpoint image.
        '''
        seed_objects_name = self.seed_objects_name.value
        skeleton_name = self.image_name.value
        seed_objects = workspace.object_set.get_objects(seed_objects_name)
        labels = seed_objects.segmented
        labels_count = np.max(labels)
        label_range = np.arange(labels_count, dtype=np.int32) + 1

        skeleton_image = workspace.image_set.get_image(
                skeleton_name, must_be_binary=True)
        skeleton = skeleton_image.pixel_data
        if skeleton_image.has_mask:
            skeleton = skeleton & skeleton_image.mask
            labels = skeleton_image.crop_image_similarly(labels)
            labels, m1 = cpo.size_similarly(skeleton, labels)
            labels[~m1] = 0
        # The following code makes a ring around the seed objects with
        # the skeleton trunks sticking out of it.
        # Create a new skeleton with holes at the seed objects
        # First combine the seed objects with the skeleton so
        # that the skeleton trunks come out of the seed objects.
        # Erode the labels once so that all of the trunk branchpoints
        # will be within the labels
        # Dilate the objects, then subtract them to make a ring
        my_disk = morph.strel_disk(1.5).astype(int)
        dilated_labels = grey_dilation(labels, footprint=my_disk)
        seed_mask = dilated_labels > 0
        combined_skel = skeleton | seed_mask

        closed_labels = grey_erosion(dilated_labels,
        seed_center = closed_labels > 0
        combined_skel = combined_skel & (~seed_center)
        # Fill in single holes (but not a one-pixel hole made by
        # a one-pixel image)
        if self.wants_to_fill_holes:
            # Fill only background holes no larger than maximum_hole_size
            def size_fn(area, is_object):
                return (~ is_object) and (area <= self.maximum_hole_size.value)

            combined_skel = morph.fill_labeled_holes(
                    combined_skel, ~seed_center, size_fn)
        # Reskeletonize to make true branchpoints at the ring boundaries
        combined_skel = morph.skeletonize(combined_skel)
        # The skeleton outside of the labels
        outside_skel = combined_skel & (dilated_labels == 0)
        # Associate all skeleton points with seed objects
        dlabels, distance_map = propagate.propagate(np.zeros(labels.shape),
                                                    combined_skel, 1)
        # Get rid of any branchpoints not connected to seeds
        combined_skel[dlabels == 0] = False
        # Find the branchpoints
        branch_points = morph.branchpoints(combined_skel)
        # Odd case: when four branches meet like this, branchpoints are not
        # assigned because they are arbitrary. So assign them.
        # .  .
        #  B.
        #  .B
        # .  .
        # NOTE(review): combined_skel[1, 1] is a single pixel (a scalar),
        # not the shifted view combined_skel[1:, 1:] that the other three
        # terms suggest. As written, the 2x2 test only looks at three corners
        # and is gated on pixel (1, 1).
        odd_case = (combined_skel[:-1, :-1] & combined_skel[1:, :-1] &
                    combined_skel[:-1, 1:] & combined_skel[1, 1])
        branch_points[:-1, :-1][odd_case] = True
        branch_points[1:, 1:][odd_case] = True
        # Find the branching counts for the trunks (# of extra branches
        # eminating from a point other than the line it might be on).
        branching_counts = morph.branchings(combined_skel)
        # Map branchings values 0..4 to extra branches: 3 -> 1, 4 -> 2, else 0
        branching_counts = np.array([0, 0, 0, 1, 2])[branching_counts]
        # Only take branches within 1 of the outside skeleton
        dilated_skel = scind.binary_dilation(outside_skel, morph.eight_connect)
        branching_counts[~dilated_skel] = 0
        # Find the endpoints
        end_points = morph.endpoints(combined_skel)
        # We use two ranges for classification here:
        # * anything within one pixel of the dilated image is a trunk
        # * anything outside of that range is a branch
        nearby_labels = dlabels.copy()
        nearby_labels[distance_map > 1.5] = 0

        outside_labels = dlabels.copy()
        outside_labels[nearby_labels > 0] = 0
        # The trunks are the branchpoints that lie within one pixel of
        # the dilated image.
        if labels_count > 0:
            trunk_counts = fix(scind.sum(branching_counts, nearby_labels,
            trunk_counts = np.zeros((0,), int)
        # The branches are the branchpoints that lie outside the seed objects
        if labels_count > 0:
            branch_counts = fix(scind.sum(branch_points, outside_labels,
            branch_counts = np.zeros((0,), int)
        # Save the endpoints
        if labels_count > 0:
            end_counts = fix(scind.sum(end_points, outside_labels, label_range))
            end_counts = np.zeros((0,), int)
        # Calculate the distances
        total_distance = morph.skeleton_length(
                dlabels * outside_skel, label_range)
        # Save measurements
        m = workspace.measurements
        assert isinstance(m, cpmeas.Measurements)
        feature = "_".join((C_NEURON, F_NUMBER_TRUNKS, skeleton_name))
        m.add_measurement(seed_objects_name, feature, trunk_counts)
        feature = "_".join((C_NEURON, F_NUMBER_NON_TRUNK_BRANCHES,
        m.add_measurement(seed_objects_name, feature, branch_counts)
        feature = "_".join((C_NEURON, F_NUMBER_BRANCH_ENDS, skeleton_name))
        m.add_measurement(seed_objects_name, feature, end_counts)
        feature = "_".join((C_NEURON, F_TOTAL_NEURITE_LENGTH, skeleton_name))
        m[seed_objects_name, feature] = total_distance
        # Collect the graph information
        if self.wants_neuron_graph:
            trunk_mask = (branching_counts > 0) & (nearby_labels != 0)
            intensity_image = workspace.image_set.get_image(
            edge_graph, vertex_graph = self.make_neuron_graph(
                    combined_skel, dlabels,
                    branch_points & ~trunk_mask,

            # NOTE(review): image_number is not used below (m.image_number is)
            image_number = workspace.measurements.image_set_number

            edge_path, vertex_path = self.get_graph_file_paths(m, m.image_number)
                    self, m.image_number, edge_path, edge_graph,
                    vertex_path, vertex_graph, headless_ok=True)

            if self.show_window:
                workspace.display_data.edge_graph = edge_graph
                workspace.display_data.vertex_graph = vertex_graph
                workspace.display_data.intensity_image = intensity_image.pixel_data
        # Make the display image
        if self.show_window or self.wants_branchpoint_image:
            branchpoint_image = np.zeros((skeleton.shape[0],
            trunk_mask = (branching_counts > 0) & (nearby_labels != 0)
            branch_mask = branch_points & (outside_labels != 0)
            end_mask = end_points & (outside_labels != 0)
            # Outside skeleton is set in all channels; trunk, branch and end
            # points are then drawn only in channels 0, 1 and 2 respectively
            branchpoint_image[outside_skel, :] = 1
            branchpoint_image[trunk_mask | branch_mask | end_mask, :] = 0
            branchpoint_image[trunk_mask, 0] = 1
            branchpoint_image[branch_mask, 1] = 1
            branchpoint_image[end_mask, 2] = 1
            # Rescale pixels inside the dilated seeds (x .875 then + .1)
            branchpoint_image[dilated_labels != 0, :] *= .875
            branchpoint_image[dilated_labels != 0, :] += .1
            if self.show_window:
                workspace.display_data.branchpoint_image = branchpoint_image
            if self.wants_branchpoint_image:
                bi = cpi.Image(branchpoint_image,
                workspace.image_set.add(self.branchpoint_image_name.value, bi)
    def run(self, workspace):
        assert isinstance(workspace, cpw.Workspace)
        image_name = self.image_name.value
        image = workspace.image_set.get_image(image_name,
        workspace.display_data.statistics = []
        img = image.pixel_data
        mask = image.mask
        objects = workspace.object_set.get_objects(self.primary_objects.value)
        global_threshold = None
        if self.method == M_DISTANCE_N:
            has_threshold = False
            thresholded_image = self.threshold_image(image_name, workspace)
            has_threshold = True

        # Get the following labels:
        # * all edited labels
        # * labels touching the edge, including small removed
        labels_in = objects.unedited_segmented.copy()
        labels_touching_edge = np.hstack(
            (labels_in[0, :], labels_in[-1, :], labels_in[:,
                                                          0], labels_in[:,
        labels_touching_edge = np.unique(labels_touching_edge)
        is_touching = np.zeros(np.max(labels_in) + 1, bool)
        is_touching[labels_touching_edge] = True
        is_touching = is_touching[labels_in]

        labels_in[(~is_touching) & (objects.segmented == 0)] = 0
        # Stretch the input labels to match the image size. If there's no
        # label matrix, then there's no label in that area.
        if tuple(labels_in.shape) != tuple(img.shape):
            tmp = np.zeros(img.shape, labels_in.dtype)
            i_max = min(img.shape[0], labels_in.shape[0])
            j_max = min(img.shape[1], labels_in.shape[1])
            tmp[:i_max, :j_max] = labels_in[:i_max, :j_max]
            labels_in = tmp

        if self.method in (M_DISTANCE_B, M_DISTANCE_N):
            if self.method == M_DISTANCE_N:
                distances, (i, j) = scind.distance_transform_edt(
                    labels_in == 0, return_indices=True)
                labels_out = np.zeros(labels_in.shape, int)
                dilate_mask = distances <= self.distance_to_dilate.value
                labels_out[dilate_mask] =\
                labels_out, distances = propagate(img, labels_in,
                                                  thresholded_image, 1.0)
                labels_out[distances > self.distance_to_dilate.value] = 0
                labels_out[labels_in > 0] = labels_in[labels_in > 0]
            if self.fill_holes:
                small_removed_segmented_out = fill_labeled_holes(labels_out)
                small_removed_segmented_out = labels_out
            # Create the final output labels by removing labels in the
            # output matrix that are missing from the segmented image
            segmented_labels = objects.segmented
            segmented_out = self.filter_labels(small_removed_segmented_out,
                                               objects, workspace)
        elif self.method == M_PROPAGATION:
            labels_out, distance = propagate(img, labels_in, thresholded_image,
            if self.fill_holes:
                small_removed_segmented_out = fill_labeled_holes(labels_out)
                small_removed_segmented_out = labels_out.copy()
            segmented_out = self.filter_labels(small_removed_segmented_out,
                                               objects, workspace)
        elif self.method == M_WATERSHED_G:
            # First, apply the sobel filter to the image (both horizontal
            # and vertical). The filter measures gradient.
            sobel_image = np.abs(scind.sobel(img))
            # Combine the image mask and threshold to mask the watershed
            watershed_mask = np.logical_or(thresholded_image, labels_in > 0)
            watershed_mask = np.logical_and(watershed_mask, mask)
            # Perform the first watershed
            labels_out = watershed(sobel_image,
                                   np.ones((3, 3), bool),
            if self.fill_holes:
                small_removed_segmented_out = fill_labeled_holes(labels_out)
                small_removed_segmented_out = labels_out.copy()
            segmented_out = self.filter_labels(small_removed_segmented_out,
                                               objects, workspace)
        elif self.method == M_WATERSHED_I:
            # invert the image so that the maxima are filled first
            # and the cells compete over what's close to the threshold
            inverted_img = 1 - img
            # Same as above, but perform the watershed on the original image
            watershed_mask = np.logical_or(thresholded_image, labels_in > 0)
            watershed_mask = np.logical_and(watershed_mask, mask)
            # Perform the watershed
            labels_out = watershed(inverted_img,
                                   np.ones((3, 3), bool),
            if self.fill_holes:
                small_removed_segmented_out = fill_labeled_holes(labels_out)
                small_removed_segmented_out = labels_out
            segmented_out = self.filter_labels(small_removed_segmented_out,
                                               objects, workspace)

        if self.wants_discard_edge and self.wants_discard_primary:
            # Make a new primary object
            lookup = scind.maximum(segmented_out, objects.segmented,
                                   range(np.max(objects.segmented) + 1))
            lookup = fix(lookup)
            lookup[0] = 0
            lookup[lookup != 0] = np.arange(np.sum(lookup != 0)) + 1
            segmented_labels = lookup[objects.segmented]
            segmented_out = lookup[segmented_out]
            new_objects = cpo.Objects()
            new_objects.segmented = segmented_labels
            if objects.has_unedited_segmented:
                new_objects.unedited_segmented = objects.unedited_segmented
            if objects.has_small_removed_segmented:
                new_objects.small_removed_segmented = objects.small_removed_segmented
            new_objects.parent_image = objects.parent_image
            primary_outline = outline(segmented_labels)
            if self.wants_primary_outlines:
                out_img = cpi.Image(primary_outline.astype(bool),
            primary_outline = outline(objects.segmented)
        secondary_outline = outline(segmented_out)

        # Add the objects to the object set
        objects_out = cpo.Objects()
        objects_out.unedited_segmented = small_removed_segmented_out
        objects_out.small_removed_segmented = small_removed_segmented_out
        objects_out.segmented = segmented_out
        objects_out.parent_image = image
        objname = self.objects_name.value
        workspace.object_set.add_objects(objects_out, objname)
        if self.use_outlines.value:
            out_img = cpi.Image(secondary_outline.astype(bool),
            workspace.image_set.add(self.outlines_name.value, out_img)
        object_count = np.max(segmented_out)
        # Add measurements
        measurements = workspace.measurements
        cpmi.add_object_count_measurements(measurements, objname, object_count)
        cpmi.add_object_location_measurements(measurements, objname,
        # Relate the secondary objects to the primary ones and record
        # the relationship.
        children_per_parent, parents_of_children = \
                                     cpmi.FF_CHILDREN_COUNT % objname,
            objname, cpmi.FF_PARENT % self.primary_objects.value,
        image_numbers = np.ones(len(parents_of_children), int) *\
        mask = parents_of_children > 0
            self.module_num, R_PARENT, self.primary_objects.value,
            self.objects_name.value, image_numbers[mask],
            parents_of_children[mask], image_numbers[mask],
                      len(parents_of_children) + 1)[mask])
        # If primary objects were created, add them
        if self.wants_discard_edge and self.wants_discard_primary:
                new_objects, self.new_primary_objects_name.value)
                measurements, self.new_primary_objects_name.value,
                measurements, self.new_primary_objects_name.value,
            for parent_objects, parent_name, child_objects, child_name in (
                (objects, self.primary_objects.value, new_objects,
                (new_objects, self.new_primary_objects_name.value, objects_out,
                children_per_parent, parents_of_children = \
                    parent_name, cpmi.FF_CHILDREN_COUNT % child_name,
                                             cpmi.FF_PARENT % parent_name,
        if self.show_window:
            object_area = np.sum(segmented_out > 0)
            workspace.display_data.object_pct = \
                100 * object_area / np.product(segmented_out.shape)
            workspace.display_data.img = img
            workspace.display_data.segmented_out = segmented_out
            workspace.display_data.primary_labels = objects.segmented
            workspace.display_data.global_threshold = global_threshold
            workspace.display_data.object_count = object_count
