Code Example #1
    def _process(self, im, image_scale: Param(0.5, (0.05, 1.0)),
                 filter_size: Param(2, (0, 15)), color_invert: Param(True),
                 clip: Param(140, (0, 255)), **extraparams):
        """Optionally resizes, smooths and inverts the image

        :param im:
        :param state:
        :param filter_size:
        :param image_scale:
        :param color_invert:
        :return:
        """
        if image_scale != 1:
            im = cv2.resize(im,
                            None,
                            fx=image_scale,
                            fy=image_scale,
                            interpolation=cv2.INTER_AREA)
        if filter_size > 0:
            im = cv2.boxFilter(im, -1, (filter_size, filter_size))
        if color_invert:
            im = 255 - im
        if clip > 0:
            im = np.maximum(im, clip) - clip

        if self.set_diagnostic == "filtered":
            self.diagnostic_image = im

        return NodeOutput([], im)
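The Param annotations above only matter inside stytra's pipeline machinery; the image operations themselves are plain OpenCV/NumPy. As a minimal sketch (the function name preprocess and its defaults are illustrative, not part of stytra), the same chain can be run standalone on a uint8 image:

import cv2
import numpy as np

def preprocess(im, image_scale=0.5, filter_size=2, color_invert=True, clip=140):
    if image_scale != 1:
        # INTER_AREA gives the best quality when shrinking an image
        im = cv2.resize(im, None, fx=image_scale, fy=image_scale,
                        interpolation=cv2.INTER_AREA)
    if filter_size > 0:
        # box filter with ddepth=-1 keeps the input dtype
        im = cv2.boxFilter(im, -1, (filter_size, filter_size))
    if color_invert:
        im = 255 - im
    if clip > 0:
        # everything below clip becomes 0; the rest is shifted down
        im = np.maximum(im, clip) - clip
    return im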
Code Example #2
File: test_pipeline.py  Project: tbenst/stytra
    def _process(self, input, a: Param(1), set_diagnostic=None):
        if self._output_type is None:
            self._output_type = namedtuple("o", "inp par")
        else:
            self._output_type_changed = False
        if self.set_diagnostic:
            self.diagnostic_image = "img"
        return NodeOutput([], self._output_type(par=a, inp=input))
Code Example #3
    def _process(
            self,
            im,
            learning_rate: Param(0.04, (0.0, 1.0)),
            learn_every: Param(400, (1, 10000)),
            only_darker: Param(True),
    ):
        messages = []
        if self.background_image is None:
            self.background_image = im.astype(np.float32)
            messages.append("I:New backgorund image set")
        elif self.i == 0:
            self.background_image[:, :] = im.astype(np.float32) * np.float32(
                learning_rate) + self.background_image * np.float32(
                    1 - learning_rate)

        self.i = (self.i + 1) % learn_every

        if only_darker:
            return NodeOutput(messages, negdif(self.background_image, im))
        else:
            return NodeOutput(messages, absdif(self.background_image, im))
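negdif and absdif are stytra helpers that are not shown in this listing. Reading the code, a plausible interpretation is that absdif returns the absolute background difference, while negdif keeps only the "darker than background" part (what a dark fish on a bright background produces). A hedged sketch of such helpers, under that assumption:

import numpy as np

# Hypothetical stand-ins for stytra's absdif/negdif (assumption: uint8 output)
def absdif(bg, im):
    # absolute difference between background model and current frame
    return np.abs(bg.astype(np.int16) - im.astype(np.int16)).astype(np.uint8)

def negdif(bg, im):
    # one-sided difference: nonzero only where the frame is darker than the background
    return np.clip(bg.astype(np.int16) - im.astype(np.int16), 0, 255).astype(np.uint8)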
Code Example #4
File: test_pipeline.py  Project: tbenst/stytra
def test_a_pipeline():
    p = TestPipeline()
    p.setup()
    tt = namedtuple("o", "inp par")
    assert p.run(None) == ([], tt(None, 1))
    assert p.diagnostic_image is None
    ser = p.serialize_params()
    print(ser)
    ser["/source/testnode"]["a"] = 2
    ser["diagnostics"]["image"] = "/source/testnode/processed"
    p.deserialize_params(ser)
    assert p.run(None) == NodeOutput([], tt(None, 2))
    assert p.diagnostic_image == "img"
Code Example #5
    def _process(self, image):
        # We put model instantiation here so that it happens in only one
        # process; if it were in __init__, two TensorFlow sessions would be
        # instantiated, causing no end of problems
        if self.model is None:
            self.model = DLCmodel(self.dlc_cfg_path, self.model_path)
            self._output_type = namedtuple(
                "o",
                chain.from_iterable(
                    (p + "_x", p + "_y") for p in self.model.tracked_parts
                ),
            )

        pose = self.model.predict_im(image)
        return NodeOutput([], self._output_type(*(pose[:, :2].flatten())))
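The lazy-instantiation pattern in this node (build the heavy model on the first _process call instead of in __init__) is worth isolating, since the comment gives the motivation: the node may be constructed in one process and run in another, and building a TensorFlow session in __init__ would leave two sessions alive. A minimal, generic sketch of the pattern (class and method names are illustrative, not stytra's):

class LazyModelNode:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None  # nothing heavy is created (or pickled) here

    def _load(self):
        # stand-in for an expensive constructor such as a TensorFlow session
        return lambda image: image

    def process(self, image):
        if self.model is None:
            # first call, already inside the worker process: a safe place
            # for GPU contexts and framework sessions
            self.model = self._load()
        return self.model(image)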
Code Example #6
File: eyes.py  Project: maoyeh/stytra
    def _process(self, im, wnd_pos: Param((129, 20), gui=False),
                 threshold: Param(56, limits=(1, 254)),
                 wnd_dim: Param((14, 22), gui=False), **extraparams):
        """

        Parameters
        ----------
        im :
            image (numpy array);
        win_pos :
            position of the window on the eyes (x, y);
        win_dim :
            dimension of the window on the eyes (w, h);
        threshold :
            threshold for ellipse fitting (int).

        Returns
        -------

        """
        message = ""
        PAD = 0

        cropped = _pad(
            (
                im[
                    wnd_pos[1]:wnd_pos[1] + wnd_dim[1],
                    wnd_pos[0]:wnd_pos[0] + wnd_dim[0],
                ] < threshold
            ).view(dtype=np.uint8).copy(),
            padding=PAD,
            val=255,
        )

        e = _fit_ellipse(cropped)

        if self.set_diagnostic == "thresholded":
            self.diagnostic_image = (im < threshold).view(dtype=np.uint8)

        if e is False:
            e = (np.nan, ) * 10
            message = "E: eyes not detected!"
        else:
            e = (e[0][0][::-1] + e[0][1][::-1] + (-e[0][2], ) + e[1][0][::-1] +
                 e[1][1][::-1] + (-e[1][2], ))
        return NodeOutput([message], self._output_type(*e))
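_pad and _fit_ellipse are module-private helpers that this listing does not include. Judging by the call site, _pad frames the binary crop with a constant border; under that assumption, an equivalent is a one-liner around np.pad:

import numpy as np

# Assumed equivalent of the private _pad helper: frame the binary crop with
# `padding` pixels of constant value `val` on every side.
def _pad(im, padding=0, val=255):
    return np.pad(im, padding, mode="constant", constant_values=val)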
Code Example #7
    def _process(
        self,
        bg,
        n_fish_max: Param(1, (1, 50)),
        n_segments: Param(10, (2, 30)),
        bg_downsample: Param(1, (1, 8)),
        bg_dif_threshold: Param(25, (0, 255)),
        threshold_eyes: Param(35, (0, 255)),
        pos_uncertainty: Param(
            1.0,
            (0, 10.0),
            desc="Uncertainty in pixels about the location of the head center of mass",
        ),
        persist_fish_for: Param(
            2,
            (1, 50),
            desc="How many frames does the fish persist for if it is not detected",
        ),
        prediction_uncertainty: Param(0.1, (0.0, 10.0, 0.0001)),
        fish_area: Param((200, 1200), (1, 4000)),
        border_margin: Param(5, (0, 100)),
        tail_length: Param(60.0, (1.0, 200.0)),
        tail_track_window: Param(3, (3, 70)),
    ):

        # update the previously-detected fish using the Kalman filter
        if self.fishes is None:
            self.reset()
        else:
            self.fishes.predict()

        area_scale = bg_downsample * bg_downsample
        border_margin = border_margin // bg_downsample

        # downsample background
        if bg_downsample > 1:
            bg_small = cv2.resize(bg,
                                  None,
                                  fx=1 / bg_downsample,
                                  fy=1 / bg_downsample)
        else:
            bg_small = bg

        bg_thresh = cv2.dilate(
            (bg_small > bg_dif_threshold).view(dtype=np.uint8),
            self.dilation_kernel)

        # find regions where there is a difference with the background
        n_comps, labels, stats, centroids = cv2.connectedComponentsWithStats(
            bg_thresh)

        try:
            max_area = np.max(stats[1:, cv2.CC_STAT_AREA]) * area_scale
        except ValueError:
            max_area = 0

        # iterate through all the regions different from the background and try
        # to find fish

        messages = []

        nofish = True
        for row, centroid in zip(stats, centroids):
            # check if the contour is fish-sized and central enough
            if not (fish_area[0] < row[cv2.CC_STAT_AREA] * area_scale
                    < fish_area[1]):
                continue

            # find the bounding box of the fish in the original image coordinates
            ftop, fleft, fheight, fwidth = (int(round(row[x] * bg_downsample))
                                            for x in [
                                                cv2.CC_STAT_TOP,
                                                cv2.CC_STAT_LEFT,
                                                cv2.CC_STAT_HEIGHT,
                                                cv2.CC_STAT_WIDTH,
                                            ])

            if not ((fleft - border_margin >= 0) and
                    (fleft + fwidth + border_margin < bg.shape[1]) and
                    (ftop - border_margin >= 0) and
                    (ftop + fheight + border_margin < bg.shape[0])):
                messages.append(
                    "W:An object of the right area found outside margins")
                continue

            # how much is this region shifted from the upper left corner of the image
            cent_shift = np.array(
                [fleft - border_margin, ftop - border_margin])

            slices = (
                slice(ftop - border_margin, ftop + fheight + border_margin),
                slice(fleft - border_margin, fleft + fwidth + border_margin),
            )

            # extract the candidate region from the background-difference image
            fishdet = bg[slices].copy()

            # estimate the position of the head
            fish_coords = fish_start(fishdet, threshold_eyes)

            # if no actual fish was found here, continue on to the next connected component
            if fish_coords[0] == -1:
                messages.append("W:No appropriate tail start position found")
                continue

            head_coords_up = fish_coords + cent_shift

            theta = _fish_direction_n(bg, head_coords_up,
                                      int(round(tail_length / 2)))

            # find the points of the tail
            points = find_fish_midline(
                bg,
                *head_coords_up,
                theta,
                tail_track_window,
                tail_length / n_segments,
                n_segments + 1,
            )

            # convert to angles
            angles = np.mod(points_to_angles(points) + np.pi,
                            np.pi * 2) - np.pi
            if len(angles) == 0:
                messages.append("W:Tail not completely detectable")
                continue

            # also, make the angles continuous
            angles[1:] = np.unwrap(angles[1:] - angles[0])

            # put the data together for one fish
            fish_coords = np.concatenate([np.array(points[0][:2]), angles])

            nofish = False
            # check whether this is a new fish or an update of
            # a fish detected previously
            if self.fishes.update(fish_coords):
                messages.append("I:Updated previous fish")
            elif self.fishes.add_fish(fish_coords):
                messages.append("I:Added new fish")
            else:
                messages.append("E:More fish than n_fish max")

        if nofish:
            messages.append(
                "W:No object of right area, between {:.0f} and {:.0f}".format(
                    *fish_area))

        # if a debugging image is to be shown, set it
        if self.set_diagnostic == "background difference":
            self.diagnostic_image = bg
        elif self.set_diagnostic == "thresholded background difference":
            self.diagnostic_image = bg_thresh
        elif self.set_diagnostic == "fish detection":
            fishdet = bg_small.copy()
            fishdet[bg_thresh == 0] = 0
            self.diagnostic_image = fishdet
        elif self.set_diagnostic == "thresholded for eye and swim bladder":
            self.diagnostic_image = np.maximum(bg,
                                               threshold_eyes) - threshold_eyes

        if self._output_type is None:
            self.reset_state()
        return NodeOutput(
            messages,
            self._output_type(*self.fishes.coords.flatten(), max_area * 1.0))
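The detection core of this node (threshold the background difference, dilate, then keep connected components of plausible area) is standard OpenCV and can be exercised on its own. A minimal sketch with the Kalman-filter bookkeeping stripped out; the 3x3 dilation kernel is an assumption, since self.dilation_kernel is defined elsewhere:

import cv2
import numpy as np

def candidate_blobs(bg_dif, threshold=25, min_area=200, max_area=1200):
    # binarize and dilate the background-difference image
    mask = cv2.dilate((bg_dif > threshold).astype(np.uint8),
                      np.ones((3, 3), dtype=np.uint8))
    n, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
    # component 0 is the background; keep only fish-sized regions
    return [tuple(centroids[i]) for i in range(1, n)
            if min_area < stats[i, cv2.CC_STAT_AREA] < max_area]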
Code Example #8
    def _process(
        self,
        im,
        tail_start: Param((0.47, 1.7), gui=False),
        tail_length: Param((0.07, -1.36), gui=False),
        n_segments: Param(12, (1, 50)),
        tail_filter_width: Param(0.0, (0.0, 10.0)),
        time_filter_weight: Param(0.0, (0.0, 1.0)),
        n_output_segments: Param(9, (1, 30)),
        reset_zero: Param(False),
        window_size: Param(7, (1, 15)),
        **extraparams
    ):
        """Finds the tail for an embedded fish, given the starting point and
        the direction of the tail. Alternative to the sequential circular arches.

        Parameters
        ----------
        im :
            image to process
        tail_start :
            starting point (x, y) (Default value = 0)
        tail_length :
            tail length (x, y) (Default value = 1)
        n_segments :
            number of desired segments (Default value = 12)
        window_size :
            window size in pixel for center-of-mass calculation (Default value = 7)
        color_invert :
            True for inverting luminosity of the image (Default value = False)
        filter_size :
            Size of the box filter to low-pass filter the image (Default value = 0)
        image_scale :
            the amount of downscaling of the image (Default value = 0.5)
        0) :

        1) :


        Returns
        -------
        type
            list of cumulative sum + list of angles

        """
        messages = []
        start_y, start_x = tail_start
        tail_length_y, tail_length_x = tail_length

        scale = im.shape[0]

        # Calculate tail length:
        length_tail = np.sqrt(tail_length_x**2 + tail_length_y**2) * scale

        # Segment length from tail length and n of segments:
        seg_length = length_tail / n_segments

        n_segments += 1

        # Initial displacements in x and y:
        disp_x = tail_length_x * scale / n_segments
        disp_y = tail_length_y * scale / n_segments

        angles = np.full(n_segments - 1, np.nan)
        start_x *= scale
        start_y *= scale

        halfwin = window_size / 2
        for i in range(1, n_segments):
            # Use the next-segment function to find the next point
            # via center-of-mass displacement:
            start_x, start_y, disp_x, disp_y, acc = _next_segment(
                im, start_x, start_y, disp_x, disp_y, halfwin, seg_length
            )
            if start_x < 0:
                messages.append("W:segment {} not detected".format(i))
                break

            abs_angle = np.arctan2(disp_x, disp_y)
            angles[i - 1] = abs_angle

        # we want the angles to be continuous; this removes potential 2*pi discontinuities
        angles = np.unwrap(angles)

        # smooth the angles along the tail, if requested
        if tail_filter_width > 0:
            angles = gaussian_filter1d(angles, tail_filter_width, mode="nearest")

        # we do not need to record a large number of angles:
        # interpolate down to the desired number of output segments
        angles = np.interp(
            np.linspace(0, 1, n_output_segments),
            np.linspace(0, 1, n_segments - 1),
            angles,
        )

        if reset_zero:
            if self.resting_angles is None or len(self.resting_angles) != len(angles):
                self.resting_angles = angles
            else:
                self.resting_angles = self.resting_angles * 0.5 + angles * 0.5
        else:
            if self.resting_angles is not None:
                angles = angles - self.resting_angles + self.resting_angles[0]

        if time_filter_weight > 0 and self.previous_angles is not None:
            angles = (
                time_filter_weight * self.previous_angles
                + (1 - time_filter_weight) * angles
            )

        self.previous_angles = angles

        if self._output_type is None:
            self.reset()

        # Total curvature as sum of the last 2 angles - sum of the first 2
        return NodeOutput(
            messages,
            self._output_type(angles[-1] + angles[-2] - angles[0] - angles[1], *angles),
        )
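Two NumPy idioms do most of the work at the end of this node: np.unwrap removes 2*pi jumps so the angles vary smoothly along the tail, and np.interp resamples the per-segment angles to a fixed output count. A tiny self-contained illustration:

import numpy as np

angles = np.array([3.0, 3.1, -3.1, -3.0])  # wraps past +pi between segments
smooth = np.unwrap(angles)                 # -> [3.0, 3.1, ~3.18, ~3.28]
# resample 4 segment angles onto 9 output segments
resampled = np.interp(np.linspace(0, 1, 9), np.linspace(0, 1, 4), smooth)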
Code Example #9
File: custom_tracking_exp.py  Project: tbenst/stytra
    def _process(
        self,
        im,
        threshold: Param(56, limits=(1, 254)),
        fly_area: Param((5, 1000), (1, 4000)),
        **extraparams
    ):
        """
        :param im: input image
        :param threshold: threshold for binarization
        :param fly_area: tuple with minimum and maximum size for the blob
        :return: NodeOutput with the tracking results
        """
        # Diagnostic messages with info on what went wrong can be returned:
        message = ""

        # Binarize the image with the specified threshold:
        thresholded = (im[:, :] < threshold).view(dtype=np.uint8).copy()

        # Find contours with OpenCV:
        cont_ret = cv2.findContours(
            thresholded.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )

        # Small compatibility fix across OpenCV versions.
        # API change: OpenCV 4 returns 2 values, while OpenCV 3 returned 3
        if len(cont_ret) == 3:
            _, contours, hierarchy = cont_ret
        else:
            contours, hierarchy = cont_ret

        ell = False
        if len(contours) >= 2:

            # Get the longest contour:
            contour = sorted(contours, key=lambda c: c.shape[0], reverse=True)[0]

            # Fit the ellipse for the fly, if the contour has a plausible length:
            if fly_area[0] < len(contour) < fly_area[1]:
                # ell will be a tuple ((y, x), (dim_y, dim_x), theta)
                ell = cv2.fitEllipse(contour)

                max_approx_radius = np.sqrt(fly_area[1] / np.pi) * 10
                if ell[1][0] > max_approx_radius or ell[1][1] > max_approx_radius:
                    # If an ellipse axis is much larger than the maximum
                    # area allows, discard the fit:
                    ell = False
                    message = "W:Wrong fit - fly close to borders?"
            else:
                # Otherwise, set a diagnostic message:
                message = "W:Blob area ouside the area range!"
        else:
            # No blobs found:
            message = "W:No contours found!"

        # Here we have the option to specify a diagnostic image if the
        # set_diagnostic attribute (set somewhere else in Stytra) is matching
        #  one of our options:
        if self.set_diagnostic == "input":
            # show the preprocessed, background-subtracted image
            self.diagnostic_image = im
        if self.set_diagnostic == "thresholded":
            # show the thresholded image:
            self.diagnostic_image = thresholded

        if ell is False:
            # If the ellipse is not valid, return a tuple of NaNs
            ell = (np.nan,) * 5
        else:
            # If valid, reshape it to a plain tuple:
            ell = ell[0][::-1] + ell[1][::-1] + (-ell[2],)

        # Return a NodeOutput object which combines the message and the
        # output named tuple created from the output type defined in the init
        #  and the tuple with our tracked values
        return NodeOutput([message], self._output_type(*ell))
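The OpenCV 3/4 return-value handling above recurs in many tracking functions and is easy to factor into a small helper; in both versions the contours and hierarchy are the last two returned values. A sketch (the helper name is ours):

import cv2

def find_contours_compat(binary):
    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3
    # and (contours, hierarchy) in OpenCV 4; normalize to the latter
    ret = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    return ret[-2:]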