示例#1
0
class Tracker:
    """Wraps the tracker for evaluation and running purposes.
    args:
        name: Name of tracking method.
        parameter_name: Name of parameter file.
        run_id: The run id.
        display_name: Name to be displayed in the result plots.
    """

    def __init__(self, name: str, parameter_name: str, run_id: int = None, display_name: str = None):
        """Store the tracker identity, derive its output folders and look up the tracker class.

        args:
            name: Name of tracking method.
            parameter_name: Name of parameter file.
            run_id: The run id (an int, or None for no run id).
            display_name: Name to be displayed in the result plots.
        """
        assert run_id is None or isinstance(run_id, int)

        self.name = name
        self.parameter_name = parameter_name
        self.run_id = run_id
        self.display_name = display_name

        settings = env_settings()
        # Runs with an explicit id get a zero-padded "_NNN" suffix on their folder name.
        if self.run_id is not None:
            run_folder = '{}_{:03d}'.format(self.parameter_name, self.run_id)
        else:
            run_folder = '{}'.format(self.parameter_name)
        self.results_dir = '{}/{}/{}'.format(settings.results_path, self.name, run_folder)
        self.segmentation_dir = '{}/{}/{}'.format(settings.segmentation_path, self.name, run_folder)

        # The tracker class is only resolved when a matching folder exists next to this package.
        tracker_dir = os.path.join(os.path.dirname(__file__), '..', 'tracker', self.name)
        if not os.path.isdir(os.path.abspath(tracker_dir)):
            self.tracker_class = None
        else:
            module = importlib.import_module('pytracking.tracker.{}'.format(self.name))
            self.tracker_class = module.get_tracker_class()

        self.visdom = None


    def _init_visdom(self, visdom_info, debug):
        """Reset the pause/step state and, when debugging, try to connect to Visdom.

        args:
            visdom_info: Dict with Visdom settings (None is treated as an empty dict). Its 'use_visdom' entry
                (default True) decides whether Visdom is used at all.
            debug: Debug level. Visdom is only started for levels above zero.
        """
        visdom_info = {} if visdom_info is None else visdom_info
        self.pause_mode = False
        self.step = False
        if debug > 0 and visdom_info.get('use_visdom', True):
            try:
                self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
                                     visdom_info=visdom_info)

                # Show help
                help_text = "You can pause/unpause the tracker by pressing 'space' with the 'Tracking' window " \
                            "selected. During paused mode, you can track for one frame by pressing the right " \
                            "arrow key. To enable/disable plotting of a data block, tick/untick the " \
                            "corresponding entry in block list."
                self.visdom.register(help_text, 'text', 1, 'Help')
            except Exception:
                # Visdom is optional: warn and carry on without it. Catching Exception (not a bare except)
                # keeps Ctrl-C and SystemExit working while the connection is being attempted.
                time.sleep(0.5)
                print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
                      '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')

    def _visdom_ui_handler(self, data):
        """Handle a Visdom UI event: space toggles pause mode, right arrow requests one step while paused."""
        if data['event_type'] != 'KeyPress':
            return

        pressed = data['key']
        if pressed == ' ':
            self.pause_mode = not self.pause_mode
        elif pressed == 'ArrowRight' and self.pause_mode:
            self.step = True


    def create_tracker(self, params):
        """Build a tracker from the tracker class and give it this wrapper's Visdom handle.

        args:
            params: Parameters passed to the tracker class constructor.
        """
        instance = self.tracker_class(params)
        setattr(instance, 'visdom', self.visdom)
        return instance

    def run_sequence(self, seq, visualization=None, debug=None, visdom_info=None, multiobj_mode=None):
        """Run tracker on sequence.

        args:
            seq: Sequence to run the tracker on.
            visualization: Set visualization flag (None means default value specified in the parameters).
            debug: Set debug level (None means default value specified in the parameters).
            visdom_info: Visdom info.
            multiobj_mode: Which mode to use for multiple objects.
        """
        params = self.get_parameters()

        # Explicit arguments win over the parameter file. Without an explicit visualization flag,
        # an explicit debug level decides whether to visualize.
        debug_ = getattr(params, 'debug', 0) if debug is None else debug
        if visualization is not None:
            visualization_ = visualization
        elif debug is None:
            visualization_ = getattr(params, 'visualization', False)
        else:
            visualization_ = bool(debug)

        params.visualization = visualization_
        params.debug = debug_

        self._init_visdom(visdom_info, debug_)
        # The matplotlib figure is only needed when Visdom is not in use.
        if visualization_ and self.visdom is None:
            self.init_visualization()

        # Get init information
        init_info = seq.init_info()
        single_object = not seq.multiobj_mode

        if multiobj_mode is None:
            class_default = getattr(self.tracker_class, 'multiobj_mode', 'default')
            multiobj_mode = getattr(params, 'multiobj_mode', class_default)

        if single_object or multiobj_mode == 'default':
            tracker = self.create_tracker(params)
        elif multiobj_mode == 'parallel':
            tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom)
        else:
            raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

        return self._track_sequence(tracker, seq, init_info)

    def _track_sequence(self, tracker, seq, init_info):
        """Run an already constructed tracker over all frames of seq and collect its outputs.

        Each field in the returned dict is a list containing the tracker prediction for each frame.

        In case of single object tracking mode:
            target_bbox[i] is the predicted bounding box for frame i
            time[i] is the processing time for frame i
            segmentation[i] is the segmentation mask for frame i (numpy array)

        In case of multi object tracking mode:
            target_bbox[i] is an OrderedDict, where target_bbox[i][obj_id] is the predicted box for target
            obj_id in frame i
            time[i] is either the processing time for frame i, or an OrderedDict containing processing times
            for each object in frame i
            segmentation[i] is the multi-label segmentation mask for frame i (numpy array)

        The 'target_bbox' and 'segmentation' fields are dropped when they hold at most one entry.
        """
        output = {'target_bbox': [],
                  'time': [],
                  'segmentation': []}

        def _store_outputs(tracker_out: dict, defaults=None):
            # Append the tracker's value for every known field, falling back to the supplied default.
            # A None value is only kept when the tracker returned that key explicitly.
            fallback = {} if defaults is None else defaults
            for key, collected in output.items():
                val = tracker_out.get(key, fallback.get(key, None))
                if val is not None or key in tracker_out:
                    collected.append(val)

        # Initialize
        image = self._read_image(seq.frames[0])

        if tracker.params.visualization and self.visdom is None:
            self.visualize(image, init_info.get('init_bbox'))

        start_time = time.time()
        out = tracker.initialize(image, init_info)
        if out is None:
            out = {}

        prev_output = OrderedDict(out)

        _store_outputs(out, {'target_bbox': init_info.get('init_bbox'),
                             'time': time.time() - start_time,
                             'segmentation': init_info.get('init_mask')})

        for frame_num, frame_path in enumerate(seq.frames[1:], start=1):
            # While paused, wait until either the pause is lifted or a single step is requested.
            while self.pause_mode:
                if self.step:
                    self.step = False
                    break
                time.sleep(0.1)

            image = self._read_image(frame_path)

            start_time = time.time()

            info = seq.frame_info(frame_num)
            info['previous_output'] = prev_output

            out = tracker.track(image, info)
            prev_output = OrderedDict(out)
            _store_outputs(out, {'time': time.time() - start_time})

            segmentation = out.get('segmentation')
            if self.visdom is not None:
                tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation)
            elif tracker.params.visualization:
                self.visualize(image, out['target_bbox'], segmentation)

        for key in ('target_bbox', 'segmentation'):
            if key in output and len(output[key]) <= 1:
                del output[key]

        return output

    def run_video(self, videofilepath, optional_box=None, debug=None, visdom_info=None, save_results=False):
        """Run the tracker with the videofile, showing the result in an OpenCV window.

        Pressing 'r' re-selects the target on the next frame, 'q' quits.

        args:
            videofilepath: Path to the video file.
            optional_box: Initial box as a list or tuple [x, y, w, h]. If None, the box is selected with the
                mouse on the first frame.
            debug: Debug level (None means default value specified in the parameters).
            visdom_info: Visdom info.
            save_results: If True, write the boxes of all processed frames to a tab-separated text file in
                the results directory.
        """

        params = self.get_parameters()

        debug_ = debug
        if debug is None:
            debug_ = getattr(params, 'debug', 0)
        params.debug = debug_

        params.tracker_name = self.name
        params.param_name = self.parameter_name
        self._init_visdom(visdom_info, debug_)

        multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))

        if multiobj_mode == 'default':
            tracker = self.create_tracker(params)
            if hasattr(tracker, 'initialize_features'):
                tracker.initialize_features()

        elif multiobj_mode == 'parallel':
            tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
        else:
            raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

        assert os.path.isfile(videofilepath), \
            "Invalid param {}, videofilepath must be a valid videofile".format(videofilepath)

        output_boxes = []
        display_name = 'Display: ' + tracker.params.tracker_name

        def _build_init_info(box):
            return {'init_bbox': OrderedDict({1: box}), 'init_object_ids': [1, ], 'object_ids': [1, ],
                    'sequence_object_ids': [1, ]}

        def _select_target(image):
            # Let the user draw a ROI on a copy of the frame, then (re)initialize the tracker with it.
            prompt = image.copy()
            cv.putText(prompt, 'Select target ROI and press ENTER', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL,
                       1.5, (0, 0, 0), 1)
            cv.imshow(display_name, prompt)
            x, y, w, h = cv.selectROI(display_name, prompt, fromCenter=False)
            init_state = [x, y, w, h]
            tracker.initialize(image, _build_init_info(init_state))
            output_boxes.append(init_state)

        cap = cv.VideoCapture(videofilepath)
        try:
            cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            success, frame = cap.read()
            # Check the read before touching the frame, so a failed read is reported instead of crashing.
            if success is not True:
                print("Read frame from {} failed.".format(videofilepath))
                exit(-1)
            cv.imshow(display_name, frame)

            if optional_box is not None:
                assert isinstance(optional_box, (list, tuple))
                assert len(optional_box) == 4, "valid box's foramt is [x,y,w,h]"
                tracker.initialize(frame, _build_init_info(optional_box))
                output_boxes.append(optional_box)
            else:
                _select_target(frame)

            while True:
                ret, frame = cap.read()

                if frame is None:
                    break

                frame_disp = frame.copy()

                # Draw box
                out = tracker.track(frame)
                state = [int(s) for s in out['target_bbox'][1]]
                output_boxes.append(state)

                cv.rectangle(frame_disp, (state[0], state[1]), (state[2] + state[0], state[3] + state[1]),
                             (0, 255, 0), 5)

                font_color = (0, 0, 0)
                cv.putText(frame_disp, 'Tracking!', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)
                cv.putText(frame_disp, 'Press r to reset', (20, 55), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)
                cv.putText(frame_disp, 'Press q to quit', (20, 80), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    ret, frame = cap.read()
                    if frame is None:
                        # The video ended exactly at the reset request; there is nothing to select on.
                        break
                    _select_target(frame)
        finally:
            # When everything done (or on any error), release the capture
            cap.release()
            cv.destroyAllWindows()

        if save_results:
            os.makedirs(self.results_dir, exist_ok=True)
            video_name = Path(videofilepath).stem
            base_results_path = os.path.join(self.results_dir, 'video_{}'.format(video_name))

            tracked_bb = np.array(output_boxes).astype(int)
            bbox_file = '{}.txt'.format(base_results_path)
            np.savetxt(bbox_file, tracked_bb, delimiter='\t', fmt='%d')

    def run_webcam(self, debug=None, visdom_info=None):
        """Run the tracker with the webcam.

        A target is added by clicking two opposite corners of its box in the display window. Pressing 'r'
        resets the tracker and 'q' quits.

        args:
            debug: Debug level (None means default value specified in the parameters).
            visdom_info: Visdom info.
        """

        params = self.get_parameters()

        debug_ = debug
        if debug is None:
            debug_ = getattr(params, 'debug', 0)
        params.debug = debug_

        params.tracker_name = self.name
        params.param_name = self.parameter_name

        self._init_visdom(visdom_info, debug_)

        multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))

        if multiobj_mode == 'default':
            tracker = self.create_tracker(params)
        elif multiobj_mode == 'parallel':
            tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
        else:
            raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

        class UIControl:
            """Mouse state for box selection: first click starts a box, second click completes it."""

            def __init__(self):
                self.mode = 'init'  # init, select, track
                self.target_tl = (-1, -1)
                self.target_br = (-1, -1)
                self.new_init = False

            def mouse_callback(self, event, x, y, flags, param):
                if event == cv.EVENT_LBUTTONDOWN and self.mode == 'init':
                    self.target_tl = (x, y)
                    self.target_br = (x, y)
                    self.mode = 'select'
                elif event == cv.EVENT_MOUSEMOVE and self.mode == 'select':
                    self.target_br = (x, y)
                elif event == cv.EVENT_LBUTTONDOWN and self.mode == 'select':
                    self.target_br = (x, y)
                    self.mode = 'init'
                    self.new_init = True

            def get_tl(self):
                # Of the two clicked corners, the one with the smaller x coordinate.
                return self.target_tl if self.target_tl[0] < self.target_br[0] else self.target_br

            def get_br(self):
                # Of the two clicked corners, the one with the larger (or equal) x coordinate.
                return self.target_br if self.target_tl[0] < self.target_br[0] else self.target_tl

            def get_bb(self):
                # Box as [x, y, w, h], independent of the order in which the corners were clicked.
                tl = self.get_tl()
                br = self.get_br()

                bb = [min(tl[0], br[0]), min(tl[1], br[1]), abs(br[0] - tl[0]), abs(br[1] - tl[1])]
                return bb

        ui_control = UIControl()
        cap = cv.VideoCapture(0)
        try:
            display_name = 'Display: ' + self.name
            cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            cv.setMouseCallback(display_name, ui_control.mouse_callback)

            next_object_id = 1
            sequence_object_ids = []
            prev_output = OrderedDict()
            while True:
                # Capture frame-by-frame
                ret, frame = cap.read()
                if not ret or frame is None:
                    # No camera or the stream ended: stop cleanly instead of failing on frame.copy().
                    print('Read frame from webcam failed.')
                    break
                frame_disp = frame.copy()

                info = OrderedDict()
                info['previous_output'] = prev_output

                if ui_control.new_init:
                    ui_control.new_init = False
                    init_state = ui_control.get_bb()

                    info['init_object_ids'] = [next_object_id, ]
                    info['init_bbox'] = OrderedDict({next_object_id: init_state})
                    sequence_object_ids.append(next_object_id)

                    next_object_id += 1

                # Draw box
                if ui_control.mode == 'select':
                    cv.rectangle(frame_disp, ui_control.get_tl(), ui_control.get_br(), (255, 0, 0), 2)

                if len(sequence_object_ids) > 0:
                    info['sequence_object_ids'] = sequence_object_ids
                    out = tracker.track(frame, info)
                    prev_output = OrderedDict(out)

                    if 'segmentation' in out:
                        frame_disp = overlay_mask(frame_disp, out['segmentation'])

                    if 'target_bbox' in out:
                        for obj_id, state in out['target_bbox'].items():
                            state = [int(s) for s in state]
                            cv.rectangle(frame_disp, (state[0], state[1]),
                                         (state[2] + state[0], state[3] + state[1]),
                                         _tracker_disp_colors[obj_id], 5)

                # Put text
                font_color = (0, 0, 0)
                cv.putText(frame_disp, 'Select target', (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                cv.putText(frame_disp, 'Press r to reset', (20, 55), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)
                cv.putText(frame_disp, 'Press q to quit', (20, 85), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    # Forget all targets and re-initialize the tracker with an empty object list.
                    next_object_id = 1
                    sequence_object_ids = []
                    prev_output = OrderedDict()

                    info = OrderedDict()

                    info['object_ids'] = []
                    info['init_object_ids'] = []
                    info['init_bbox'] = OrderedDict()
                    tracker.initialize(frame, info)
                    ui_control.mode = 'init'
        finally:
            # When everything done (or on any error), release the capture
            cap.release()
            cv.destroyAllWindows()

    def run_vot2020(self, debug=None, visdom_info=None):
        """Run tracker on VOT.

        Reports a mask per frame when the tracker predicts segmentation masks, a rectangle otherwise.

        args:
            debug: Debug level (None means default value specified in the parameters). When given, it also
                decides the visualization flag: on for a non-zero level, off for zero.
            visdom_info: Visdom info.
        """
        params = self.get_parameters()
        params.tracker_name = self.name
        params.param_name = self.parameter_name
        params.run_id = self.run_id

        if debug is None:
            debug_ = getattr(params, 'debug', 0)
            visualization_ = getattr(params, 'visualization', False)
        else:
            debug_ = debug
            visualization_ = bool(debug)

        params.visualization = visualization_
        params.debug = debug_

        self._init_visdom(visdom_info, debug_)
        # visualize() draws on self.ax, which only exists once init_visualization() has run.
        if visualization_ and self.visdom is None:
            self.init_visualization()

        tracker = self.create_tracker(params)
        tracker.initialize_features()

        output_segmentation = tracker.predicts_segmentation_mask()

        import pytracking.evaluation.vot2020 as vot

        def _convert_anno_to_list(vot_anno):
            vot_anno = [vot_anno[0], vot_anno[1], vot_anno[2], vot_anno[3]]
            return vot_anno

        def _convert_image_path(image_path):
            return image_path

        if output_segmentation:
            handle = vot.VOT("mask")
        else:
            handle = vot.VOT("rectangle")

        vot_anno = handle.region()

        image_path = handle.frame()
        if not image_path:
            return
        image_path = _convert_image_path(image_path)

        image = self._read_image(image_path)

        if output_segmentation:
            # The initial mask is expanded to the image size given as (width, height), then reduced to a box.
            vot_anno_mask = vot.make_full_size(vot_anno, (image.shape[1], image.shape[0]))
            bbox = masks_to_bboxes(torch.from_numpy(vot_anno_mask), fmt='t').squeeze().tolist()
        else:
            bbox = _convert_anno_to_list(vot_anno)
            vot_anno_mask = None

        out = tracker.initialize(image, {'init_mask': vot_anno_mask, 'init_bbox': bbox})

        if out is None:
            out = {}
        prev_output = OrderedDict(out)

        # Track
        while True:
            image_path = handle.frame()
            if not image_path:
                break
            image_path = _convert_image_path(image_path)

            image = self._read_image(image_path)

            info = OrderedDict()
            info['previous_output'] = prev_output

            out = tracker.track(image, info)
            prev_output = OrderedDict(out)

            if output_segmentation:
                pred = out['segmentation'].astype(np.uint8)
            else:
                state = out['target_bbox']
                pred = vot.Rectangle(*state)
            handle.report(pred, 1.0)

            segmentation = out['segmentation'] if 'segmentation' in out else None
            if self.visdom is not None:
                tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation)
            elif tracker.params.visualization:
                self.visualize(image, out['target_bbox'], segmentation)


    def run_vot(self, debug=None, visdom_info=None):
        """Run tracker on VOT, using polygon annotations converted to an initial rectangle.

        args:
            debug: Debug level (None means default value specified in the parameters). When given, it also
                decides the visualization flag: on for a non-zero level, off for zero.
            visdom_info: Visdom info.
        """
        params = self.get_parameters()
        params.tracker_name = self.name
        params.param_name = self.parameter_name
        params.run_id = self.run_id

        if debug is None:
            debug_ = getattr(params, 'debug', 0)
            visualization_ = getattr(params, 'visualization', False)
        else:
            debug_ = debug
            visualization_ = bool(debug)

        params.visualization = visualization_
        params.debug = debug_

        self._init_visdom(visdom_info, debug_)
        # visualize() draws on self.ax, which only exists once init_visualization() has run.
        if visualization_ and self.visdom is None:
            self.init_visualization()

        tracker = self.create_tracker(params)
        tracker.initialize_features()

        import pytracking.evaluation.vot as vot

        def _convert_anno_to_list(vot_anno):
            # Flatten the four (x, y) corner points of the first region into [x1, y1, ..., x4, y4].
            vot_anno = [vot_anno[0][0][0], vot_anno[0][0][1], vot_anno[0][1][0], vot_anno[0][1][1],
                        vot_anno[0][2][0], vot_anno[0][2][1], vot_anno[0][3][0], vot_anno[0][3][1]]
            return vot_anno

        def _convert_image_path(image_path):
            # NOTE(review): drops a fixed 20-character prefix and 2-character suffix from the reported
            # path - presumably wrapper text added by the toolkit; confirm against the VOT setup in use.
            image_path_new = image_path[20:- 2]
            return "".join(image_path_new)

        handle = vot.VOT("polygon")

        vot_anno_polygon = handle.region()
        vot_anno_polygon = _convert_anno_to_list(vot_anno_polygon)

        init_state = convert_vot_anno_to_rect(vot_anno_polygon, tracker.params.vot_anno_conversion_type)

        image_path = handle.frame()
        if not image_path:
            return
        image_path = _convert_image_path(image_path)

        image = self._read_image(image_path)
        tracker.initialize(image, {'init_bbox': init_state})

        # Track
        while True:
            image_path = handle.frame()
            if not image_path:
                break
            image_path = _convert_image_path(image_path)

            image = self._read_image(image_path)
            out = tracker.track(image)
            state = out['target_bbox']

            handle.report(vot.Rectangle(state[0], state[1], state[2], state[3]))

            segmentation = out['segmentation'] if 'segmentation' in out else None
            if self.visdom is not None:
                tracker.visdom_draw_tracking(image, out['target_bbox'], segmentation)
            elif tracker.params.visualization:
                self.visualize(image, out['target_bbox'], segmentation)

    def get_parameters(self):
        """Import this tracker's parameter module and return the result of its parameters() function."""
        module_path = 'pytracking.parameter.{}.{}'.format(self.name, self.parameter_name)
        return importlib.import_module(module_path).parameters()


    def init_visualization(self):
        """Create the matplotlib figure used by visualize() and hook up the keyboard handler."""
        self.pause_mode = False
        figure, axis = plt.subplots(1)
        self.fig = figure
        self.ax = axis
        # Key presses in the figure window are routed to press().
        self.fig.canvas.mpl_connect('key_press_event', self.press)
        plt.tight_layout()


    def visualize(self, image, state, segmentation=None):
        """Draw the image, the predicted box(es) and optionally the segmentation on the matplotlib axes.

        args:
            image: Image to show.
            state: A single box indexed as [x, y, w, h], or a dict of such boxes.
            segmentation: Optional segmentation, overlaid half transparent.
        """
        axis = self.ax
        axis.cla()
        axis.imshow(image)
        if segmentation is not None:
            axis.imshow(segmentation, alpha=0.5)

        if isinstance(state, (OrderedDict, dict)):
            boxes = list(state.values())
        else:
            boxes = (state,)

        # Colors are looked up by the 1-based position of the box, and rescaled from 0-255 to 0-1.
        for i, box in enumerate(boxes, start=1):
            rgb = [float(c) / 255.0 for c in _tracker_disp_colors[i]]
            axis.add_patch(patches.Rectangle((box[0], box[1]), box[2], box[3], linewidth=1, edgecolor=rgb,
                                             facecolor='none'))

        # An optional gt_state attribute is drawn in green.
        gt_state = getattr(self, 'gt_state', None)
        if gt_state is not None:
            axis.add_patch(patches.Rectangle((gt_state[0], gt_state[1]), gt_state[2], gt_state[3], linewidth=1,
                                             edgecolor='g', facecolor='none'))
        axis.set_axis_off()
        axis.axis('equal')
        draw_figure(self.fig)

        # In pause mode, block until a key (not a mouse button) is pressed in the figure.
        if self.pause_mode:
            while not plt.waitforbuttonpress():
                pass

    def reset_tracker(self):
        """Hook called when 'r' is pressed in the matplotlib window; does nothing here."""
        return None

    def press(self, event):
        """Matplotlib key handler: 'p' toggles pause mode, 'r' calls reset_tracker()."""
        pressed = event.key
        if pressed == 'p':
            self.pause_mode = not self.pause_mode
            print("Switching pause mode!")
            return
        if pressed == 'r':
            self.reset_tracker()
            print("Resetting target pos to gt!")

    def _read_image(self, image_file: str):
        """Read an image file with OpenCV and return it converted from BGR to RGB.

        args:
            image_file: Path to the image file.

        raises:
            IOError: If OpenCV could not read the file.
        """
        im = cv.imread(image_file)
        if im is None:
            # cv.imread reports failure by returning None; fail here with the offending path instead of
            # letting cvtColor raise an unrelated-looking error.
            raise IOError('Could not read image file: {}'.format(image_file))
        return cv.cvtColor(im, cv.COLOR_BGR2RGB)
示例#2
0
class BaseTracker:
    """Base class for all trackers."""
    def visdom_ui_handler(self, data):
        """Handle a Visdom UI event: space toggles pause mode, right arrow requests one step while paused."""
        if data['event_type'] != 'KeyPress':
            return

        pressed = data['key']
        if pressed == ' ':
            self.pause_mode = not self.pause_mode
        elif pressed == 'ArrowRight' and self.pause_mode:
            self.step = True

    def __init__(self, params):
        """Store the parameters and, when debugging, try to connect to Visdom.

        args:
            params: Tracker parameters. Its debug attribute is always read; visdom_info is only read when
                debug is above zero.
        """
        self.params = params

        self.pause_mode = False
        self.step = False

        self.visdom = None
        if self.params.debug > 0 and self.params.visdom_info.get(
                'use_visdom', True):
            try:
                self.visdom = Visdom(self.params.debug, {
                    'handler': self.visdom_ui_handler,
                    'win_id': 'Tracking'
                },
                                     visdom_info=self.params.visdom_info)

                # Show help
                help_text = "You can pause/unpause the tracker by pressing 'space' with the 'Tracking' window " \
                            "selected. During paused mode, you can track for one frame by pressing the right " \
                            "arrow key. To enable/disable plotting of a data block, tick/untick the " \
                            "corresponding entry in block list."
                self.visdom.register(help_text, 'text', 1, 'Help')
            except Exception:
                # Visdom is optional: warn and carry on without it. Catching Exception (not a bare except)
                # keeps Ctrl-C and SystemExit working while the connection is being attempted.
                time.sleep(0.5)
                print(
                    '!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
                    '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!'
                )

    def initialize(self, image, info: dict) -> dict:
        """Initialize the model on the first frame. Must be overridden by the concrete tracker.

        track_sequence() and track_videofile() call this with the first image and an info dict.
        """
        raise NotImplementedError()

    def track(self, image) -> dict:
        """Track the target in the given frame and update the model. Must be overridden by the concrete tracker.

        The callers in this class read the 'target_bbox' entry of the returned dict.
        """
        raise NotImplementedError()

    def track_sequence(self, sequence):
        """Run tracker on a sequence.

        Returns a dict with the per-frame lists 'target_bbox' and 'time'. A RuntimeError is raised when the
        tracker returns any other key.
        """

        output = {'target_bbox': [], 'time': []}

        def _store_outputs(tracker_out: dict, defaults=None):
            # Reject unknown keys first, then append each known field (tracker value, else default).
            fallback = {} if defaults is None else defaults
            if any(key not in output for key in tracker_out.keys()):
                raise RuntimeError('Unknown output from tracker.')
            for key, collected in output.items():
                val = tracker_out.get(key, fallback.get(key, None))
                if val is not None:
                    collected.append(val)

        # Initialize
        image = self._read_image(sequence.frames[0])

        if self.params.visualization and self.visdom is None:
            self.init_visualization()
            self.visualize(image, sequence.get('init_bbox'), sequence.name, 0)

        start_time = time.time()
        out = self.initialize(image, sequence.init_info())
        if out is None:
            out = {}
        init_defaults = {
            'target_bbox': sequence.get('init_bbox'),
            'time': time.time() - start_time
        }
        _store_outputs(out, init_defaults)

        if self.visdom is not None:
            self.visdom.register((image, sequence.get('init_bbox')),
                                 'Tracking', 1, 'Tracking')

        # Track
        for frame_idx, frame in enumerate(sequence.frames[1:], start=1):
            # While paused, wait until either the pause is lifted or a single step is requested.
            while self.pause_mode:
                if self.step:
                    self.step = False
                    break
                time.sleep(0.1)

            image = self._read_image(frame)

            start_time = time.time()
            out = self.track(image)
            _store_outputs(out, {'time': time.time() - start_time})

            if self.visdom is not None:
                self.visdom.register((image, out['target_bbox']), 'Tracking',
                                     1, 'Tracking')
            elif self.params.visualization:
                self.visualize(image, out['target_bbox'], sequence.name, frame_idx)

        return output

    def track_videofile(self, videofilepath, optional_box=None):
        """Run track with a video file input, showing the result in an OpenCV window.

        Pressing 'r' re-selects the target on the next frame, 'q' quits.

        args:
            videofilepath: Path to the video file.
            optional_box: Initial box as a list or tuple [x, y, w, h]. If None, the box is selected with the
                mouse on the first frame.
        """

        assert os.path.isfile(videofilepath), \
            "Invalid param {}, videofilepath must be a valid videofile".format(videofilepath)

        if hasattr(self, 'initialize_features'):
            self.initialize_features()

        display_name = 'Display: ' + self.params.tracker_name

        def _select_target(image):
            # Let the user draw a ROI on a copy of the frame, then (re)initialize the tracker with it.
            prompt = image.copy()
            cv.putText(prompt, 'Select target ROI and press ENTER',
                       (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.5,
                       (0, 0, 0), 1)
            cv.imshow(display_name, prompt)
            x, y, w, h = cv.selectROI(display_name,
                                      prompt,
                                      fromCenter=False)
            init_state = [x, y, w, h]
            self.initialize(image, {'init_bbox': init_state})

        cap = cv.VideoCapture(videofilepath)
        try:
            cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            success, frame = cap.read()
            # Check the read before touching the frame, so a failed read is reported instead of crashing.
            if success is not True:
                print("Read frame from {} failed.".format(videofilepath))
                exit(-1)
            cv.imshow(display_name, frame)

            if optional_box is not None:
                # isinstance takes the accepted types as a single tuple.
                assert isinstance(optional_box, (list, tuple))
                assert len(optional_box) == 4, "valid box's foramt is [x,y,w,h]"
                self.initialize(frame, {'init_bbox': optional_box})
            else:
                _select_target(frame)

            while True:
                ret, frame = cap.read()

                if frame is None:
                    # End of video: leave the loop so the capture and windows are cleaned up.
                    break

                frame_disp = frame.copy()

                # Draw box
                out = self.track(frame)
                state = [int(s) for s in out['target_bbox']]
                cv.rectangle(frame_disp, (state[0], state[1]),
                             (state[2] + state[0], state[3] + state[1]),
                             (0, 255, 0), 5)

                font_color = (0, 0, 0)
                cv.putText(frame_disp, 'Tracking!', (20, 30),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                cv.putText(frame_disp, 'Press r to reset', (20, 55),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                cv.putText(frame_disp, 'Press q to quit', (20, 80),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    ret, frame = cap.read()
                    if frame is None:
                        # The video ended exactly at the reset request; there is nothing to select on.
                        break
                    _select_target(frame)
        finally:
            # When everything done (or on any error), release the capture
            cap.release()
            cv.destroyAllWindows()

    def track_webcam(self):
        """Run tracker with webcam.

        Opens capture device 0 and shows a live preview window. The target is
        selected with two left clicks: the first sets one corner, the second
        sets the opposite corner and starts tracking. Pressing 'r' goes back
        to target selection and 'q' quits. The loop also ends when the camera
        stops delivering frames. The capture device and the windows are
        released even if tracking raises.
        """
        class UIControl:
            """Mouse-driven state machine used to pick the target box."""

            def __init__(self):
                self.mode = 'init'  # init, select, track
                self.target_tl = (-1, -1)
                self.target_br = (-1, -1)
                self.mode_switch = False

            def mouse_callback(self, event, x, y, flags, param):
                if event == cv.EVENT_LBUTTONDOWN and self.mode == 'init':
                    # First click: anchor one corner and start rubber-banding.
                    self.target_tl = (x, y)
                    self.target_br = (x, y)
                    self.mode = 'select'
                    self.mode_switch = True
                elif event == cv.EVENT_MOUSEMOVE and self.mode == 'select':
                    self.target_br = (x, y)
                elif event == cv.EVENT_LBUTTONDOWN and self.mode == 'select':
                    # Second click: fix the opposite corner and start tracking.
                    self.target_br = (x, y)
                    self.mode = 'track'
                    self.mode_switch = True

            def get_tl(self):
                # Only the x coordinate decides which click counts as "left".
                return self.target_tl if self.target_tl[0] < self.target_br[
                    0] else self.target_br

            def get_br(self):
                return self.target_br if self.target_tl[0] < self.target_br[
                    0] else self.target_tl

            def get_bb(self):
                tl = self.get_tl()
                br = self.get_br()

                # min/abs make the box valid whichever way the user dragged.
                bb = [
                    min(tl[0], br[0]),
                    min(tl[1], br[1]),
                    abs(br[0] - tl[0]),
                    abs(br[1] - tl[1])
                ]
                return bb

        ui_control = UIControl()
        cap = cv.VideoCapture(0)
        try:
            display_name = 'Display: ' + self.params.tracker_name
            cv.namedWindow(display_name,
                           cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            cv.setMouseCallback(display_name, ui_control.mouse_callback)

            if hasattr(self, 'initialize_features'):
                self.initialize_features()

            while True:
                # Capture frame-by-frame
                ret, frame = cap.read()
                if not ret or frame is None:
                    # Camera missing, busy or unplugged: stop instead of
                    # crashing on frame.copy().
                    print("Read frame from webcam failed.")
                    break
                frame_disp = frame.copy()

                if ui_control.mode == 'track' and ui_control.mode_switch:
                    ui_control.mode_switch = False
                    init_state = ui_control.get_bb()
                    self.initialize(frame, {'init_bbox': init_state})

                # Draw box
                if ui_control.mode == 'select':
                    cv.rectangle(frame_disp, ui_control.get_tl(),
                                 ui_control.get_br(), (255, 0, 0), 2)
                elif ui_control.mode == 'track':
                    out = self.track(frame)
                    state = [int(s) for s in out['target_bbox']]
                    cv.rectangle(frame_disp, (state[0], state[1]),
                                 (state[2] + state[0], state[3] + state[1]),
                                 (0, 255, 0), 5)

                # Put text
                font_color = (0, 0, 0)
                if ui_control.mode == 'init' or ui_control.mode == 'select':
                    cv.putText(frame_disp, 'Select target', (20, 30),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press q to quit', (20, 55),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                elif ui_control.mode == 'track':
                    cv.putText(frame_disp, 'Tracking!', (20, 30),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press r to reset', (20, 55),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press q to quit', (20, 80),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    ui_control.mode = 'init'
        finally:
            # When everything done (or on error), release the capture
            cap.release()
            cv.destroyAllWindows()

    def track_vot(self):
        """Run tracker on VOT.

        Initializes from the polygon region supplied by the VOT handle, then
        tracks every frame the handle provides and reports a rectangle for
        each one. Returns without tracking if the handle has no first frame.
        """
        def _convert_anno_to_list(vot_anno):
            # Flatten the four (x, y) points of the first polygon into
            # [x1, y1, x2, y2, x3, y3, x4, y4].
            corners = vot_anno[0]
            return [corners[pt][axis] for pt in range(4) for axis in range(2)]

        def _convert_image_path(image_path):
            # Drop the first 20 and the last 2 items of the value reported by
            # the VOT handle and glue the remainder back into one string.
            return "".join(image_path[20:-2])

        handle = vot.VOT("polygon")

        polygon_coords = _convert_anno_to_list(handle.region())
        init_state = convert_vot_anno_to_rect(
            polygon_coords, self.params.vot_anno_conversion_type)

        first_path = handle.frame()
        if not first_path:
            return

        image = self._read_image(_convert_image_path(first_path))
        self.initialize(image, {'init_bbox': init_state})

        if self.visdom is not None:
            self.visdom.register((image, init_state), 'Tracking', 1,
                                 'Tracking')

        # Track
        while True:
            # While paused, wait until the pause is lifted or a single step
            # has been requested.
            while self.pause_mode:
                if self.step:
                    self.step = False
                    break
                time.sleep(0.1)

            next_path = handle.frame()
            if not next_path:
                break

            image = self._read_image(_convert_image_path(next_path))
            state = self.track(image)['target_bbox']

            if self.visdom is not None:
                self.visdom.register((image, state), 'Tracking', 1, 'Tracking')
            report_box = vot.Rectangle(state[0], state[1], state[2], state[3])
            handle.report(report_box)

    def reset_tracker(self):
        """Hook called when the user presses 'r'; does nothing by default."""
        return None

    def press(self, event):
        """Handle a key-press event: 'p' toggles pause mode, 'r' resets the tracker."""
        pressed = event.key
        if pressed == 'p':
            self.pause_mode = not self.pause_mode
            print("Switching pause mode!")
            return
        if pressed == 'r':
            self.reset_tracker()
            print("Resetting target pos to gt!")

    def init_visualization(self):
        """Create the figure/axes used by ``visualize`` and hook up the key handler."""
        # plt.ion()
        self.pause_mode = False
        figure, axes = plt.subplots(1)
        self.fig, self.ax = figure, axes
        # Key presses inside the figure are routed to self.press.
        figure.canvas.mpl_connect('key_press_event', self.press)
        plt.tight_layout()

    def visualize(self, image, state, seq_name, frame_no):
        """Draw the tracked box on the image and save the figure as a PNG.

        args:
            image: Image to display (handed to ``self.ax.imshow``).
            state: Tracked box; indices 0-3 are used as x, y, width, height.
            seq_name: Output directory; created if it does not exist yet.
            frame_no: Frame number; zero-padded to 5 digits in the file name.
        """
        self.ax.cla()
        self.ax.imshow(image)
        rect = patches.Rectangle((state[0], state[1]),
                                 state[2],
                                 state[3],
                                 linewidth=1,
                                 edgecolor='r',
                                 facecolor='none')
        self.ax.add_patch(rect)

        self.ax.set_axis_off()
        self.ax.axis('equal')

        # exist_ok avoids the check-then-create race of a separate isdir test.
        os.makedirs(seq_name, exist_ok=True)
        self.fig.savefig(
            join(seq_name, 'frame' + str(frame_no).zfill(5) + '.png'))

        if self.pause_mode:
            # waitforbuttonpress returns a falsy value for mouse clicks and
            # timeouts; keep waiting until an actual key press arrives.
            keypress = False
            while not keypress:
                keypress = plt.waitforbuttonpress()

    def show_image(self, im, plot_name=None, ax=None):
        """Show an image in a debug figure.

        args:
            im: Image to show; torch tensors are converted with
                ``torch_to_numpy`` first.
            plot_name: Plot title. When ``ax`` is not given it also selects
                the figure/axes pair cached on ``self`` and reused between
                calls.
            ax: Optional existing axes to draw into instead of the cached one.
        """
        if isinstance(im, torch.Tensor):
            im = torch_to_numpy(im)

        if ax is None:
            # str() keeps the attribute names valid when plot_name is left at
            # its None default.
            plot_fig_name = 'debug_fig_' + str(plot_name)
            plot_ax_name = 'debug_ax_' + str(plot_name)
            if not hasattr(self, plot_fig_name):
                fig, ax = plt.subplots(1)
                setattr(self, plot_fig_name, fig)
                setattr(self, plot_ax_name, ax)
                plt.tight_layout()
            else:
                fig = getattr(self, plot_fig_name)
                ax = getattr(self, plot_ax_name)
        else:
            # A caller-supplied axes already belongs to a figure; redraw that
            # one (previously ``fig`` was left undefined on this path).
            fig = ax.figure

        ax.cla()
        ax.imshow(im)

        ax.set_axis_off()
        ax.axis('equal')
        if plot_name is not None:
            ax.set_title(plot_name)
        draw_figure(fig)

    def _read_image(self, image_file: str):
        """Read an image file and return it converted from BGR to RGB.

        raises:
            IOError: If OpenCV could not read the file (``cv.imread`` returns
                None instead of raising).
        """
        bgr = cv.imread(image_file)
        if bgr is None:
            raise IOError('Could not read image file: {}'.format(image_file))
        return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
示例#3
0
class BaseTracker:
    """Base class for all trackers."""
    def visdom_ui_handler(self, data):
        """React to Visdom key presses: space toggles pause, right arrow steps while paused."""
        if data['event_type'] != 'KeyPress':
            return

        key = data['key']
        if key == ' ':
            self.pause_mode = not self.pause_mode
        elif key == 'ArrowRight' and self.pause_mode:
            self.step = True

    def __init__(self, params):
        self.params = params

        self.pause_mode = False
        self.step = False

        self.visdom = None
        if self.params.debug > 0 and self.params.visdom_info.get(
                'use_visdom', True):
            try:
                self.visdom = Visdom(self.params.debug, {
                    'handler': self.visdom_ui_handler,
                    'win_id': 'Tracking'
                },
                                     visdom_info=self.params.visdom_info)

                # Show help
                help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
                            'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
                            'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
                            'block list.'
                self.visdom.register(help_text, 'text', 1, 'Help')
            except:
                time.sleep(0.5)
                print(
                    '!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
                    '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!'
                )

    def initialize(self, image, info: dict) -> dict:
        """Initialize the model on the first frame; concrete trackers must override this."""
        raise NotImplementedError()

    def track(self, image) -> dict:
        """Track the target in the given frame and update the model; concrete trackers must override this."""
        raise NotImplementedError()

    def track_sequence(self, sequence):
        """Run tracker on a sequence.

        args:
            sequence: Sequence object providing ``frames``, ``init_info()``,
                ``get('init_bbox')`` and ``ground_truth_rect``. When
                ``params.use_depth_channel`` is truthy, ``depth_frames`` is
                read as well and the depth map is passed to ``initialize`` and
                ``track`` as a second argument.

        returns:
            dict with the lists 'target_bbox', 'time' and 'scores'; a value is
            appended per frame whenever the tracker output or the defaults
            provide one.
        """
        output = {'target_bbox': [], 'time': [], 'scores': []}

        def _store_outputs(tracker_out: dict, defaults=None):
            defaults = {} if defaults is None else defaults
            for key in tracker_out.keys():
                if key not in output:
                    raise RuntimeError('Unknown output from tracker.')
            for key in output.keys():
                val = tracker_out.get(key, defaults.get(key, None))
                if val is not None:
                    output[key].append(val)

        # Resolve the depth switch once. Mixing "attribute exists" with
        # "attribute is truthy" left `depth` undefined when the flag was
        # present but False.
        use_depth = bool(getattr(self.params, 'use_depth_channel', False))

        # Initialize
        image = self._read_image(sequence.frames[0])
        depth = None
        if use_depth:
            print('have %d depth frames' % (len(sequence.depth_frames)))
            depth = self._read_depth(sequence.depth_frames[0])
            # Replicate the single depth channel three times.
            depth = np.repeat(np.expand_dims(depth, axis=2), 3, axis=2)
            if image.shape[0] != depth.shape[0]:
                # Row counts differ: crop the depth map to the image height,
                # skipping its first row.
                depth = depth[1:image.shape[0] + 1, :, :]

        if self.params.visualization and self.visdom is None:
            self.init_visualization()
            self.visualize(image[:, :, 0:3], sequence.get('init_bbox'))

        start_time = time.time()

        if use_depth:
            out = self.initialize(image, depth, sequence.init_info())
        else:
            out = self.initialize(image, sequence.init_info())

        if out is None:
            out = {}
        _store_outputs(
            out, {
                'target_bbox': sequence.get('init_bbox'),
                'time': time.time() - start_time,
                'scores': 1.0
            })

        if self.visdom is not None:
            self.visdom.register((image, sequence.get('init_bbox')),
                                 'Tracking', 1, 'Tracking')

        # Track
        for ind_frame, frame in enumerate(sequence.frames[1:], start=1):
            # Exposed on self so the visualize helpers can name saved images.
            self.ind_frame = ind_frame

            # While paused, wait until the pause is lifted or a single step
            # has been requested.
            while True:
                if not self.pause_mode:
                    break
                elif self.step:
                    self.step = False
                    break
                else:
                    time.sleep(0.1)

            image = self._read_image(frame)

            if use_depth:
                depth = self._read_depth(sequence.depth_frames[ind_frame])
                depth = np.repeat(np.expand_dims(depth, axis=2), 3, axis=2)

            start_time = time.time()
            if use_depth:
                out = self.track(image, depth)
            else:
                out = self.track(image)
            _store_outputs(
                out, {
                    'time': time.time() - start_time,
                    'scores': self.debug_info['max_score']
                })

            # Get gt_state if the gt_state for the whole sequence is provided
            if sequence.ground_truth_rect.shape[0] > 1:
                self.gt_state = sequence.ground_truth_rect[ind_frame]

            if self.visdom is not None:
                self.visdom.register((image, out['target_bbox']), 'Tracking',
                                     1, 'Tracking')
            elif self.params.visualization:
                self.visualize(image, out['target_bbox'])

                # Visualize the depth, rescaled to the 0-255 range
                if use_depth and os.path.exists(
                        sequence.depth_frames[ind_frame]):
                    self.visualize_depth(
                        np.uint8(255 * depth / np.max(depth)),
                        out['target_bbox'])

        return output

    def track_videofile(self, videofilepath, optional_box=None):
        """Run track with a video file input.

        args:
            videofilepath: Path to an existing video file.
            optional_box: Optional initial box as a list or tuple
                [x, y, w, h]. If omitted, the user draws the target ROI on
                the first frame.

        While tracking, 'r' lets the user re-select the target on the next
        frame and 'q' quits. The capture and the windows are released when
        the video ends, on quit and on errors.
        """
        assert os.path.isfile(videofilepath), \
            "Invalid param {}, videofilepath must be a valid videofile".format(
                videofilepath)

        if hasattr(self, 'initialize_features'):
            self.initialize_features()

        cap = cv.VideoCapture(videofilepath)
        display_name = 'Display: ' + self.params.tracker_name

        def _select_roi_and_initialize(frame):
            # Draw the prompt on a copy so the overlay text never reaches the
            # tracker, then initialize on the box the user selects.
            frame_disp = frame.copy()
            cv.putText(frame_disp, 'Select target ROI and press ENTER',
                       (20, 30), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.5,
                       (0, 0, 0), 1)
            cv.imshow(display_name, frame_disp)
            x, y, w, h = cv.selectROI(display_name,
                                      frame_disp,
                                      fromCenter=False)
            self.initialize(frame, {'init_bbox': [x, y, w, h]})

        try:
            cv.namedWindow(display_name,
                           cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            success, frame = cap.read()
            # Check the read before showing the frame: imshow fails on None.
            if success is not True:
                print("Read frame from {} failed.".format(videofilepath))
                exit(-1)
            cv.imshow(display_name, frame)

            if optional_box is not None:
                # isinstance takes the candidate types as ONE tuple argument.
                assert isinstance(optional_box, (list, tuple))
                assert len(optional_box) == 4, "valid box's foramt is [x,y,w,h]"
                self.initialize(frame, {'init_bbox': optional_box})
            else:
                _select_roi_and_initialize(frame)

            while True:
                ret, frame = cap.read()

                if frame is None:
                    # End of video: leave the loop so cleanup still runs.
                    break

                frame_disp = frame.copy()

                # Draw box
                out = self.track(frame)
                state = [int(s) for s in out['target_bbox']]
                cv.rectangle(frame_disp, (state[0], state[1]),
                             (state[2] + state[0], state[3] + state[1]),
                             (0, 255, 0), 5)

                font_color = (0, 0, 0)
                cv.putText(frame_disp, 'Tracking!', (20, 30),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                cv.putText(frame_disp, 'Press r to reset', (20, 55),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                cv.putText(frame_disp, 'Press q to quit', (20, 80),
                           cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    ret, frame = cap.read()
                    if frame is None:
                        # Video ended exactly when the reset was requested.
                        break
                    _select_roi_and_initialize(frame)
        finally:
            # When everything done (or on error), release the capture
            cap.release()
            cv.destroyAllWindows()

    def track_webcam(self):
        """Run tracker with webcam.

        Opens capture device 0 and shows a live preview window. The target is
        selected with two left clicks: the first sets one corner, the second
        sets the opposite corner and starts tracking. Pressing 'r' goes back
        to target selection and 'q' quits. The loop also ends when the camera
        stops delivering frames. The capture device and the windows are
        released even if tracking raises.
        """
        class UIControl:
            """Mouse-driven state machine used to pick the target box."""

            def __init__(self):
                self.mode = 'init'  # init, select, track
                self.target_tl = (-1, -1)
                self.target_br = (-1, -1)
                self.mode_switch = False

            def mouse_callback(self, event, x, y, flags, param):
                if event == cv.EVENT_LBUTTONDOWN and self.mode == 'init':
                    # First click: anchor one corner and start rubber-banding.
                    self.target_tl = (x, y)
                    self.target_br = (x, y)
                    self.mode = 'select'
                    self.mode_switch = True
                elif event == cv.EVENT_MOUSEMOVE and self.mode == 'select':
                    self.target_br = (x, y)
                elif event == cv.EVENT_LBUTTONDOWN and self.mode == 'select':
                    # Second click: fix the opposite corner and start tracking.
                    self.target_br = (x, y)
                    self.mode = 'track'
                    self.mode_switch = True

            def get_tl(self):
                # Only the x coordinate decides which click counts as "left".
                return self.target_tl if self.target_tl[0] < self.target_br[
                    0] else self.target_br

            def get_br(self):
                return self.target_br if self.target_tl[0] < self.target_br[
                    0] else self.target_tl

            def get_bb(self):
                tl = self.get_tl()
                br = self.get_br()

                # min/abs make the box valid whichever way the user dragged.
                bb = [
                    min(tl[0], br[0]),
                    min(tl[1], br[1]),
                    abs(br[0] - tl[0]),
                    abs(br[1] - tl[1])
                ]
                return bb

        ui_control = UIControl()
        cap = cv.VideoCapture(0)
        try:
            display_name = 'Display: ' + self.params.tracker_name
            cv.namedWindow(display_name,
                           cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            cv.setMouseCallback(display_name, ui_control.mouse_callback)

            if hasattr(self, 'initialize_features'):
                self.initialize_features()

            while True:
                # Capture frame-by-frame
                ret, frame = cap.read()
                if not ret or frame is None:
                    # Camera missing, busy or unplugged: stop instead of
                    # crashing on frame.copy().
                    print("Read frame from webcam failed.")
                    break
                frame_disp = frame.copy()

                if ui_control.mode == 'track' and ui_control.mode_switch:
                    ui_control.mode_switch = False
                    init_state = ui_control.get_bb()
                    self.initialize(frame, {'init_bbox': init_state})

                # Draw box
                if ui_control.mode == 'select':
                    cv.rectangle(frame_disp, ui_control.get_tl(),
                                 ui_control.get_br(), (255, 0, 0), 2)
                elif ui_control.mode == 'track':
                    out = self.track(frame)
                    state = [int(s) for s in out['target_bbox']]
                    cv.rectangle(frame_disp, (state[0], state[1]),
                                 (state[2] + state[0], state[3] + state[1]),
                                 (0, 255, 0), 5)

                # Put text
                font_color = (0, 0, 0)
                if ui_control.mode == 'init' or ui_control.mode == 'select':
                    cv.putText(frame_disp, 'Select target', (20, 30),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press q to quit', (20, 55),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                elif ui_control.mode == 'track':
                    cv.putText(frame_disp, 'Tracking!', (20, 30),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press r to reset', (20, 55),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                    cv.putText(frame_disp, 'Press q to quit', (20, 80),
                               cv.FONT_HERSHEY_COMPLEX_SMALL, 1, font_color, 1)
                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    ui_control.mode = 'init'
        finally:
            # When everything done (or on error), release the capture
            cap.release()
            cv.destroyAllWindows()

    def track_vot(self):
        """Run tracker on VOT.

        Initializes from the polygon region supplied by the VOT handle, then
        tracks every frame the handle provides and reports a rectangle for
        each one. Returns without tracking if the handle has no first frame.
        """
        def _convert_anno_to_list(vot_anno):
            # Flatten the four (x, y) points of the first polygon into
            # [x1, y1, x2, y2, x3, y3, x4, y4].
            corners = vot_anno[0]
            return [corners[pt][axis] for pt in range(4) for axis in range(2)]

        def _convert_image_path(image_path):
            # Drop the first 20 and the last 2 items of the value reported by
            # the VOT handle and glue the remainder back into one string.
            return "".join(image_path[20:-2])

        handle = vot.VOT("polygon")

        polygon_coords = _convert_anno_to_list(handle.region())
        init_state = convert_vot_anno_to_rect(
            polygon_coords, self.params.vot_anno_conversion_type)

        first_path = handle.frame()
        if not first_path:
            return

        image = self._read_image(_convert_image_path(first_path))
        self.initialize(image, {'init_bbox': init_state})

        if self.visdom is not None:
            self.visdom.register((image, init_state), 'Tracking', 1,
                                 'Tracking')

        # Track
        while True:
            # While paused, wait until the pause is lifted or a single step
            # has been requested.
            while self.pause_mode:
                if self.step:
                    self.step = False
                    break
                time.sleep(0.1)

            next_path = handle.frame()
            if not next_path:
                break

            image = self._read_image(_convert_image_path(next_path))
            state = self.track(image)['target_bbox']

            if self.visdom is not None:
                self.visdom.register((image, state), 'Tracking', 1, 'Tracking')
            report_box = vot.Rectangle(state[0], state[1], state[2], state[3])
            handle.report(report_box)

    def reset_tracker(self):
        """Hook called when the user presses 'r'; does nothing by default."""
        return None

    def press(self, event):
        """Handle a key-press event: 'p' toggles pause mode, 'r' resets the tracker."""
        pressed = event.key
        if pressed == 'p':
            self.pause_mode = not self.pause_mode
            print("Switching pause mode!")
            return
        if pressed == 'r':
            self.reset_tracker()
            print("Resetting target pos to gt!")

    def init_visualization(self):
        """Create the two figures used by ``visualize`` / ``visualize_depth`` and hook up the key handler."""
        # plt.ion()
        self.pause_mode = False
        main_fig, main_ax = plt.subplots(1)
        self.fig, self.ax = main_fig, main_ax
        depth_fig, depth_ax = plt.subplots(1)
        self.fig2, self.ax2 = depth_fig, depth_ax
        # Only the first figure forwards key presses to self.press.
        main_fig.canvas.mpl_connect('key_press_event', self.press)
        plt.tight_layout()

    def visualize(self, image, state, *var):
        """Draw the current frame with the tracker output on ``self.ax``.

        args:
            image: Image to display.
            state: Tracked box; indices 0-3 are used as x, y, width, height.
                A zero width or height is labelled 'NOT FOUND'.
            *var: Optional extra boxes. When given, ``var[1]`` is drawn in
                blue and ``state`` in white instead of red.

        The ground-truth box is added in green when ``self.gt_state`` exists.
        The figure is saved under ./tracking_results/imgs when that directory
        exists and ``self.ind_frame`` is set.
        """
        axes = self.ax
        axes.cla()
        axes.imshow(image)

        found = state[2] != 0 and state[3] != 0
        if found:
            label, label_colour = 'FOUND', 'green'
        else:
            label, label_colour = 'NOT FOUND', 'red'
        axes.text(10,
                  30,
                  label,
                  fontsize=14,
                  bbox=dict(facecolor=label_colour, alpha=0.2))

        if len(var) > 0:
            # state_rgb, state_depth provided: only the depth box is drawn.
            state_depth = var[1]
            axes.add_patch(
                patches.Rectangle((state_depth[0], state_depth[1]),
                                  state_depth[2],
                                  state_depth[3],
                                  linewidth=2,
                                  edgecolor='b',
                                  facecolor='none'))
            state_colour = 'w'
        else:
            state_colour = 'r'
        axes.add_patch(
            patches.Rectangle((state[0], state[1]),
                              state[2],
                              state[3],
                              linewidth=1,
                              edgecolor=state_colour,
                              facecolor='none'))

        if hasattr(self, 'gt_state'):
            gt = self.gt_state
            # Green dot on the box centre plus the box outline.
            axes.plot(gt[0] + gt[2] / 2, gt[1] + gt[3] / 2, 'go')
            axes.add_patch(
                patches.Rectangle((gt[0], gt[1]),
                                  gt[2],
                                  gt[3],
                                  linewidth=2,
                                  edgecolor='g',
                                  facecolor='none'))

        axes.set_axis_off()
        axes.axis('equal')
        draw_figure(self.fig)

        if hasattr(self, 'ind_frame') and os.path.exists(
                './tracking_results/imgs'):
            self.fig.savefig('./tracking_results/imgs/img_%d.png' %
                             self.ind_frame)

        if self.pause_mode:
            # Keep waiting until a key (not a mouse click or timeout) arrives.
            while not plt.waitforbuttonpress():
                pass

    def visualize_depth(self, image, state):
        """Show a depth image on ``self.ax2`` and optionally save it.

        args:
            image: Depth image to display.
            state: Accepted for symmetry with ``visualize``; not used here.

        The figure is saved under ./tracking_results/imgs when that directory
        exists and ``self.ind_frame`` is set.
        """
        depth_axes = self.ax2
        depth_axes.cla()
        depth_axes.imshow(image)

        depth_axes.set_axis_off()
        depth_axes.axis('equal')
        plt.draw()
        plt.pause(0.001)

        if hasattr(self, 'ind_frame') and os.path.exists(
                './tracking_results/imgs'):
            self.fig2.savefig('./tracking_results/imgs/depth_%d.png' %
                              self.ind_frame)

        if self.pause_mode:
            plt.waitforbuttonpress()

    def show_image(self, im, plot_name=None, ax=None):
        """Show an image in a debug figure.

        args:
            im: Image to show; torch tensors are converted with
                ``torch_to_numpy`` first.
            plot_name: Plot title. When ``ax`` is not given it also selects
                the figure/axes pair cached on ``self`` and reused between
                calls.
            ax: Optional existing axes to draw into instead of the cached one.
        """
        if isinstance(im, torch.Tensor):
            im = torch_to_numpy(im)

        if ax is None:
            # str() keeps the attribute names valid when plot_name is left at
            # its None default.
            plot_fig_name = 'debug_fig_' + str(plot_name)
            plot_ax_name = 'debug_ax_' + str(plot_name)
            if not hasattr(self, plot_fig_name):
                fig, ax = plt.subplots(1)
                setattr(self, plot_fig_name, fig)
                setattr(self, plot_ax_name, ax)
                plt.tight_layout()
            else:
                fig = getattr(self, plot_fig_name)
                ax = getattr(self, plot_ax_name)
        else:
            # A caller-supplied axes already belongs to a figure; redraw that
            # one (previously ``fig`` was left undefined on this path).
            fig = ax.figure

        ax.cla()
        ax.imshow(im)

        ax.set_axis_off()
        ax.axis('equal')
        if plot_name is not None:
            ax.set_title(plot_name)
        draw_figure(fig)

    def _read_depth(self, image_file: str):
        """Read a depth map, scale and clamp it, and close small holes.

        args:
            image_file: Path of the depth image. Paths containing 'Princeton'
                get their bits rotated first (right shift by 3 OR-ed with
                left shift by 13).

        returns:
            Float array of depth / 1000 (presumably millimetres to metres --
            confirm against the dataset), with values >= 8.0 and <= 0.0 set
            to 8.0, after a 7x7 morphological closing.

        raises:
            IOError: If OpenCV could not read the file (``cv.imread`` returns
                None instead of raising).
        """
        # Proper imread flags: keep the stored bit depth and channel layout.
        # Their combined value is 6, the number the previously used colour
        # conversion code cv.COLOR_BGR2GRAY happened to have.
        depth = cv.imread(image_file,
                          cv.IMREAD_ANYDEPTH | cv.IMREAD_ANYCOLOR)
        if depth is None:
            raise IOError('Could not read depth file: {}'.format(image_file))

        if 'Princeton' in image_file:  # these files need their bits rotated
            depth = np.bitwise_or(np.right_shift(depth, 3),
                                  np.left_shift(depth, 13))
        depth = depth / 1000.0
        # Out-of-range and missing (zero) readings are both mapped to 8.0.
        depth[depth >= 8.0] = 8.0
        depth[depth <= 0.0] = 8.0

        # Hole closing
        hole_closing_kernel = np.ones((7, 7), np.uint8)
        depth = cv.morphologyEx(depth, cv.MORPH_CLOSE, hole_closing_kernel)
        return depth

    def _read_image(self, image_file: str):
        """Read an image file and return it converted from BGR to RGB.

        raises:
            IOError: If OpenCV could not read the file (``cv.imread`` returns
                None instead of raising).
        """
        bgr = cv.imread(image_file)
        if bgr is None:
            raise IOError('Could not read image file: {}'.format(image_file))
        return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
示例#4
0
class TrackerYolo:
    """Wraps the tracker for evaluation and running purposes.
    args:
        name: Name of tracking method.
        parameter_name: Name of parameter file.
        run_id: The run id.
        display_name: Name to be displayed in the result plots.
    """

    def __init__(self, name: str, parameter_name: str, run_id: int = None, display_name: str = None):
        """Store the tracker identity, derive its output directories and load its tracker class.

        args:
            name: Name of tracking method.
            parameter_name: Name of parameter file.
            run_id: The run id, or None for a run without an id.
            display_name: Name to be displayed in the result plots.
        """
        assert run_id is None or isinstance(run_id, int)

        self.name, self.parameter_name = name, parameter_name
        self.run_id, self.display_name = run_id, display_name

        # Runs with an id get a zero-padded '_NNN' suffix on both output directories.
        if run_id is None:
            dir_template, run_args = '{}/{}/{}', ()
        else:
            dir_template, run_args = '{}/{}/{}_{:03d}', (run_id,)

        env = env_settings()
        self.results_dir = dir_template.format(env.results_path, name, parameter_name, *run_args)
        self.segmentation_dir = dir_template.format(env.segmentation_path, name, parameter_name, *run_args)

        # The tracker class is only loaded when a '../tracker/<name>' directory exists next to this file.
        this_dir = os.path.dirname(__file__)
        tracker_module_abspath = os.path.abspath(os.path.join(this_dir, '..', 'tracker', name))
        if not os.path.isdir(tracker_module_abspath):
            self.tracker_class = None
        else:
            module_name = 'pytracking.tracker.{}'.format(name)
            self.tracker_class = importlib.import_module(module_name).get_tracker_class()

        self.visdom = None

    def _init_visdom(self, visdom_info, debug):
        visdom_info = {} if visdom_info is None else visdom_info
        self.pause_mode = False
        self.step = False
        if debug > 0 and visdom_info.get('use_visdom', True):
            try:
                self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
                                     visdom_info=visdom_info)

                # Show help
                help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
                            'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
                            'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
                            'block list.'
                self.visdom.register(help_text, 'text', 1, 'Help')
            except:
                time.sleep(0.5)
                print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
                      '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')

    def _visdom_ui_handler(self, data):
        if data['event_type'] == 'KeyPress':
            if data['key'] == ' ':
                self.pause_mode = not self.pause_mode

            elif data['key'] == 'ArrowRight' and self.pause_mode:
                self.step = True

    def create_tracker(self, params):
        tracker = self.tracker_class(params)
        tracker.visdom = self.visdom
        return tracker

    def run_video(self, videofilepath, optional_box=None, debug=None, visdom_info=None, save_results=False):
        """Run the tracker on a video file, seeding it from a YOLO detection.

        Each frame is searched with a YOLOv3 network until a detection of class id 45 (the bowl
        class, per the inline note) is found. That box initializes the tracker; from then on YOLO
        is no longer run and the tracker follows the target. Movement hints are drawn on the
        displayed frame from the target's position and size. Press 'q' to stop early.

        args:
            videofilepath: Path to the video file; must be an existing file.
            optional_box: Not used by this method.
            debug: Debug level. If None, the 'debug' attribute of the parameters is used (default 0).
            visdom_info: Visdom settings passed on to _init_visdom.
            save_results: If True, write the recorded boxes as tab-separated integers to
                'video_<video name>.txt' inside self.results_dir.
        """

        def yolo_search(W, H, frame_yolo):
            """Run YOLO on frame_yolo, draw the kept detections on it, and look for a bowl.

            returns:
                (top-left, bottom-right, detection_flag, frame_yolo). The corners are (0, 0) and
                the flag is 0 when no bowl was kept after non-maxima suppression.
            """
            # if the frame dimensions are empty, grab them
            if W is None or H is None:
                (H, W) = frame_yolo.shape[:2]

            # construct a blob from the input frame and then perform a forward
            # pass of the YOLO object detector, giving us our bounding boxes
            # and associated probabilities
            blob = cv.dnn.blobFromImage(frame_yolo, 1 / 255.0, (416, 416), swapRB=True, crop=False)
            net.setInput(blob)
            layerOutputs = net.forward(ln)

            # detected bounding boxes, confidences, and class IDs, respectively
            boxes = []
            confidences = []
            classIDs = []

            for output in layerOutputs:
                for detection in output:
                    # extract the class ID and confidence (i.e., probability)
                    # of the current object detection
                    scores = detection[5:]
                    classID = np.argmax(scores)
                    confidence = scores[classID]

                    # filter weak prediction and unrelated classes
                    if classID not in outdoor_classes and confidence > 0.5:
                        # scale the bounding box coordinates back relative to
                        # the size of the image, keeping in mind that YOLO
                        # actually returns the center (x, y)-coordinates of
                        # the bounding box followed by the boxes' width and
                        # height
                        box = detection[0:4] * np.array([W, H, W, H])
                        (centerX, centerY, width, height) = box.astype("int")

                        # use the center (x, y)-coordinates to derive the top
                        # left corner of the bounding box
                        x = int(centerX - (width / 2))
                        y = int(centerY - (height / 2))

                        boxes.append([x, y, int(width), int(height)])
                        confidences.append(float(confidence))
                        classIDs.append(classID)

            # apply non-maxima suppression to suppress weak, overlapping
            # bounding boxes
            idxs = cv.dnn.NMSBoxes(boxes, confidences, 0.5, 0.3)

            # ensure at least one detection exists
            if len(idxs) > 0:
                # asarray().flatten() accepts both the (N, 1) and the (N,) index layouts
                for i in np.asarray(idxs).flatten():
                    # extract the bounding box coordinates
                    (x, y) = (boxes[i][0], boxes[i][1])
                    (w, h) = (boxes[i][2], boxes[i][3])

                    # draw a bounding box rectangle and label on the frame
                    color = [int(c) for c in COLORS[classIDs[i]]]
                    cv.rectangle(frame_yolo, (x, y), (x + w, y + h), color, 2)
                    text = "{}: {:.4f}".format(LABELS[classIDs[i]], confidences[i])
                    cv.putText(frame_yolo, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    if classIDs[i] == 45:  # 0 - person, 65 - remote 45 - bowl
                        detection_flag = 1
                        tl_coor = (x, y)  # top left coordinates
                        br_coor = ((x + w), (y + h))  # bottom right coordinates
                        cv.rectangle(frame_yolo, tl_coor, br_coor, (255, 255, 255), 2)
                        return tl_coor, br_coor, detection_flag, frame_yolo
            return (0, 0), (0, 0), 0, frame_yolo

        # Init a detection flag; stop_yolo becomes 1 once the tracker has been seeded
        det_flag = 0
        stop_yolo = 0

        # load the COCO class labels our YOLO model was trained on
        # and the classes that wont be used (coco.names contains the names)
        # NOTE(review): the YOLO files are read from hard-coded absolute paths.
        labelsPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/coco.names"
        with open(labelsPath) as labels_file:
            LABELS = labels_file.read().strip().split("\n")
        outdoor_classes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                           17, 18, 19, 20, 21, 22, 23,
                           29, 30, 31, 32, 33, 34, 35, 36, 37, 38]

        # initialize a list of colors to represent each possible class label
        np.random.seed(42)
        COLORS = np.random.randint(0, 255, size=(len(LABELS), 3), dtype="uint8")

        # derive the paths to the YOLO weights and model configuration
        weightsPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/yolov3.weights"
        configPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/yolov3.cfg"

        # load our YOLO object detector trained on COCO dataset (80 classes)
        print("[INFO] loading YOLO from disk...")
        net = cv.dnn.readNetFromDarknet(configPath, weightsPath)

        # determine only the output layers from yolo. The 1-based layer indices come back as
        # an (N, 1) or an (N,) array depending on the OpenCV build, so flatten before indexing.
        layer_names = net.getLayerNames()
        ln = [layer_names[int(i) - 1] for i in np.asarray(net.getUnconnectedOutLayers()).flatten()]

        (W, H) = (None, None)

        params = self.get_parameters()

        debug_ = debug
        if debug is None:
            debug_ = getattr(params, 'debug', 0)
        params.debug = debug_

        params.tracker_name = self.name
        params.param_name = self.parameter_name
        self._init_visdom(visdom_info, debug_)

        multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))

        if multiobj_mode == 'default':
            tracker = self.create_tracker(params)
            if hasattr(tracker, 'initialize_features'):
                tracker.initialize_features()

        elif multiobj_mode == 'parallel':
            tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
        else:
            raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

        # The second half of this message used to be a stray, no-op string statement.
        assert os.path.isfile(videofilepath), \
            "Invalid param {}, videofilepath must be a valid videofile".format(videofilepath)

        output_boxes = []

        def _build_init_info(box):
            return {'init_bbox': OrderedDict({1: box}), 'init_object_ids': [1, ], 'object_ids': [1, ],
                    'sequence_object_ids': [1, ]}

        cap = cv.VideoCapture(videofilepath)
        try:
            display_name = 'Display: ' + tracker.params.tracker_name
            cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)
            success, frame = cap.read()

            # Check the read before showing the frame; imshow fails on a missing frame.
            if success is not True:
                print("Read frame from {} failed.".format(videofilepath))
                exit(-1)
            cv.imshow(display_name, frame)

            while True:
                ret, frame = cap.read()

                if frame is None:
                    break

                frame_disp = frame.copy()

                if W is None or H is None:
                    (H, W) = frame_disp.shape[:2]

                if stop_yolo == 0:
                    tl_yolo, br_yolo, det_flag, frame_disp = yolo_search(W, H, frame.copy())
                    cv.putText(frame_disp, "Searching: BOWL", (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

                if det_flag == 1:
                    if stop_yolo == 0:
                        # First detection: seed the tracker with the YOLO box and stop searching.
                        stop_yolo = 1
                        x = tl_yolo[0]
                        y = tl_yolo[1]
                        w = abs(br_yolo[0] - tl_yolo[0])
                        h = abs(br_yolo[1] - tl_yolo[1])
                        init_state = [x, y, w, h]
                        tracker.initialize(frame, _build_init_info(init_state))
                        output_boxes.append(init_state)

                    # Draw box
                    out = tracker.track(frame)
                    state = [int(s) for s in out['target_bbox'][1]]
                    output_boxes.append(state)

                    tl = (state[0], state[1])
                    br = (state[2] + state[0], state[3] + state[1])
                    w = state[2]
                    h = state[3]
                    cv.rectangle(frame_disp, tl, br, (0, 255, 0), 5)

                    center = (int(tl[0] + w/2), int(tl[1] + h/2))
                    cv.circle(frame_disp, center, 3, (0, 0, 255), -1)
                    cv.putText(frame_disp, "FOUND BOWL", (50, 50), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

                    # Hints: the target centre should sit in the middle 40-60% band of the frame
                    # and the box should cover between 5% and 15% of the frame area.
                    if center[0] < W*0.40:
                        cv.putText(frame_disp, "MOVE LEFT", (50, 150), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
                    elif center[0] > W*0.60:
                        cv.putText(frame_disp, "MOVE RIGHT", (450, 150), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
                    if center[1] < H*0.40:
                        cv.putText(frame_disp, "MOVE UP", (200, 50), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
                    elif center[1] > H*0.60:
                        cv.putText(frame_disp, "MOVE DOWN", (200, 300), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
                    if w*h < W*H*0.05:
                        cv.putText(frame_disp, "MOVE FORWARD", (int(W/2), int(H/2)), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
                    elif w*h > W*H*0.15:
                        cv.putText(frame_disp, "MOVE BACK", (int(W/2), int(H/2)), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                key = cv.waitKey(1)

                if key == ord('q'):
                    break
        finally:
            # Release the capture and close the windows even if the loop raised.
            cap.release()
            cv.destroyAllWindows()

        if save_results:
            os.makedirs(self.results_dir, exist_ok=True)
            video_name = Path(videofilepath).stem
            base_results_path = os.path.join(self.results_dir, 'video_{}'.format(video_name))

            tracked_bb = np.array(output_boxes).astype(int)
            bbox_file = '{}.txt'.format(base_results_path)
            np.savetxt(bbox_file, tracked_bb, delimiter='\t', fmt='%d')

    def run_webcam(self, debug=None, visdom_info=None):
        """Run the tracker on the webcam (device 0), seeding it from a YOLO detection.

        YOLO runs on every frame and its annotated output is shown in a separate 'YOLO' window.
        The first detection of class id 0 (a human, per the inline comment) initializes the
        tracker. Pressing 'r' resets the tracker so that the next such detection seeds it again;
        pressing 'q' quits.

        args:
            debug: Debug level. If None, the 'debug' attribute of the parameters is used (default 0).
            visdom_info: Visdom settings passed on to _init_visdom.
        """

        def yolo_search(W, H, frame_yolo):
            """Run YOLO on frame_yolo, draw the kept detections on it, and look for class id 0.

            returns:
                (top-left, bottom-right, detection_flag, frame_yolo). The corners are (0, 0) and
                the flag is 0 when no such detection was kept after non-maxima suppression.
            """
            # if the frame dimensions are empty, grab them
            if W is None or H is None:
                (H, W) = frame_yolo.shape[:2]

            # construct a blob from the input frame and then perform a forward
            # pass of the YOLO object detector, giving us our bounding boxes
            # and associated probabilities
            blob = cv.dnn.blobFromImage(frame_yolo, 1 / 255.0, (416, 416), swapRB=True, crop=False)
            net.setInput(blob)
            layerOutputs = net.forward(ln)

            # detected bounding boxes, confidences, and class IDs, respectively
            boxes = []
            confidences = []
            classIDs = []

            for output in layerOutputs:
                for detection in output:
                    # extract the class ID and confidence (i.e., probability)
                    # of the current object detection
                    scores = detection[5:]
                    classID = np.argmax(scores)
                    confidence = scores[classID]

                    # filter weak prediction and unrelated classes
                    if classID not in outdoor_classes and confidence > 0.5:
                        # scale the bounding box coordinates back relative to
                        # the size of the image, keeping in mind that YOLO
                        # actually returns the center (x, y)-coordinates of
                        # the bounding box followed by the boxes' width and
                        # height
                        box = detection[0:4] * np.array([W, H, W, H])
                        (centerX, centerY, width, height) = box.astype("int")

                        # use the center (x, y)-coordinates to derive the top
                        # left corner of the bounding box
                        x = int(centerX - (width / 2))
                        y = int(centerY - (height / 2))

                        boxes.append([x, y, int(width), int(height)])
                        confidences.append(float(confidence))
                        classIDs.append(classID)

            # apply non-maxima suppression to suppress weak, overlapping
            # bounding boxes
            idxs = cv.dnn.NMSBoxes(boxes, confidences, 0.5, 0.3)

            # ensure at least one detection exists
            if len(idxs) > 0:
                # asarray().flatten() accepts both the (N, 1) and the (N,) index layouts
                for i in np.asarray(idxs).flatten():
                    # extract the bounding box coordinates
                    (x, y) = (boxes[i][0], boxes[i][1])
                    (w, h) = (boxes[i][2], boxes[i][3])

                    # draw a bounding box rectangle and label on the frame
                    color = [int(c) for c in COLORS[classIDs[i]]]
                    cv.rectangle(frame_yolo, (x, y), (x + w, y + h), color, 2)
                    text = "{}: {:.4f}".format(LABELS[classIDs[i]], confidences[i])
                    cv.putText(frame_yolo, text, (x, y - 5), cv.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    if classIDs[i] == 0:
                        detection_flag = 1
                        tl_coor = (x, y)  # top left coordinates
                        br_coor = ((x + w), (y + h))  # bottom right coordinates
                        cv.rectangle(frame_yolo, tl_coor, br_coor, (255, 255, 255), 2)
                        return tl_coor, br_coor, detection_flag, frame_yolo
            return (0, 0), (0, 0), 0, frame_yolo

        # Init a detection flag
        det_flag = 0

        # load the COCO class labels our YOLO model was trained on
        # and the classes that wont be used (coco.names contains the names)
        # NOTE(review): the YOLO files are read from hard-coded absolute paths.
        labelsPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/coco.names"
        with open(labelsPath) as labels_file:
            LABELS = labels_file.read().strip().split("\n")
        outdoor_classes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                           17, 18, 19, 20, 21, 22, 23,
                           29, 30, 31, 32, 33, 34, 35, 36, 37, 38]

        # initialize a list of colors to represent each possible class label
        np.random.seed(42)
        COLORS = np.random.randint(0, 255, size=(len(LABELS), 3), dtype="uint8")

        # derive the paths to the YOLO weights and model configuration
        weightsPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/yolov3.weights"
        configPath = "/home/ebo/Parrot_CV_Project/local_yolo/yolo-coco/yolov3.cfg"

        # load our YOLO object detector trained on COCO dataset (80 classes)
        print("[INFO] loading YOLO from disk...")
        net = cv.dnn.readNetFromDarknet(configPath, weightsPath)

        # determine only the output layers from yolo. The 1-based layer indices come back as
        # an (N, 1) or an (N,) array depending on the OpenCV build, so flatten before indexing.
        layer_names = net.getLayerNames()
        ln = [layer_names[int(i) - 1] for i in np.asarray(net.getUnconnectedOutLayers()).flatten()]

        (W, H) = (None, None)
        # temp_flag is 1 while the tracker has been seeded from a detection
        temp_flag = 0

        params = self.get_parameters()

        debug_ = debug
        if debug is None:
            debug_ = getattr(params, 'debug', 0)
        params.debug = debug_

        params.tracker_name = self.name
        params.param_name = self.parameter_name

        self._init_visdom(visdom_info, debug_)

        multiobj_mode = getattr(params, 'multiobj_mode', getattr(self.tracker_class, 'multiobj_mode', 'default'))

        if multiobj_mode == 'default':
            tracker = self.create_tracker(params)
        elif multiobj_mode == 'parallel':
            tracker = MultiObjectWrapper(self.tracker_class, params, self.visdom, fast_load=True)
        else:
            raise ValueError('Unknown multi object mode {}'.format(multiobj_mode))

        class UIControl:
            def __init__(self):
                self.mode = 'init'  # init, select, track
                self.new_init = False

            def get_bb(self):
                # Box [x, y, w, h] built from the corners of the latest YOLO detection, which
                # are read from the enclosing run_webcam scope.
                tl = tl_yolo
                br = br_yolo
                return [min(tl[0], br[0]), min(tl[1], br[1]), abs(br[0] - tl[0]), abs(br[1] - tl[1])]

        ui_control = UIControl()
        cap = cv.VideoCapture(0)
        try:
            display_name = 'Display: ' + self.name
            cv.namedWindow(display_name, cv.WINDOW_NORMAL | cv.WINDOW_KEEPRATIO)
            cv.resizeWindow(display_name, 960, 720)

            next_object_id = 1
            sequence_object_ids = []
            prev_output = OrderedDict()
            while True:
                # Capture frame-by-frame
                ret, frame = cap.read()
                if not ret or frame is None:
                    # No frame delivered (camera missing or disconnected); frame.copy() would fail.
                    print("Read frame from webcam failed.")
                    break
                frame_disp = frame.copy()

                tl_yolo, br_yolo, det_flag, frame_yolo = yolo_search(W, H, frame.copy())

                info = OrderedDict()
                info['previous_output'] = prev_output

                # If there's a human detection and the tracker is not seeded yet, seed it
                if det_flag == 1 and temp_flag == 0:
                    init_state = ui_control.get_bb()
                    info['init_object_ids'] = [next_object_id, ]
                    info['init_bbox'] = OrderedDict({next_object_id: init_state})
                    sequence_object_ids.append(next_object_id)
                    next_object_id += 1
                    temp_flag = 1

                if len(sequence_object_ids) > 0:
                    info['sequence_object_ids'] = sequence_object_ids
                    out = tracker.track(frame, info)
                    prev_output = OrderedDict(out)

                    if 'segmentation' in out:
                        frame_disp = overlay_mask(frame_disp, out['segmentation'])

                    if 'target_bbox' in out:
                        for obj_id, state in out['target_bbox'].items():
                            state = [int(s) for s in state]
                            cv.rectangle(frame_disp, (state[0], state[1]), (state[2] + state[0], state[3] + state[1]),
                                         _tracker_disp_colors[obj_id], 5)

                # Put text
                font_color = (0, 0, 0)
                cv.putText(frame_disp, 'Press r to reset', (20, 25), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)
                cv.putText(frame_disp, 'Press q to quit', (20, 55), cv.FONT_HERSHEY_COMPLEX_SMALL, 1,
                           font_color, 1)

                # Display the resulting frame
                cv.imshow(display_name, frame_disp)
                cv.imshow("YOLO", frame_yolo)
                key = cv.waitKey(1)
                if key == ord('q'):
                    break
                elif key == ord('r'):
                    next_object_id = 1
                    sequence_object_ids = []
                    prev_output = OrderedDict()
                    # Without clearing this flag the tracker could never be seeded again after
                    # a reset, leaving it idle for the rest of the session.
                    temp_flag = 0

                    info = OrderedDict()

                    info['object_ids'] = []
                    info['init_object_ids'] = []
                    info['init_bbox'] = OrderedDict()
                    tracker.initialize(frame, info)
                    ui_control.mode = 'init'
        finally:
            # Release the capture and close the windows even if the loop raised.
            cap.release()
            cv.destroyAllWindows()

    def get_parameters(self):
        """Get parameters."""
        param_module = importlib.import_module('pytracking.parameter.{}.{}'.format(self.name, self.parameter_name))
        params = param_module.parameters()
        return params

    def init_visualization(self):
        """Create the matplotlib figure and axes and route key presses to self.press."""
        self.pause_mode = False
        figure, axes = plt.subplots(1)
        self.fig = figure
        self.ax = axes
        self.fig.canvas.mpl_connect('key_press_event', self.press)
        plt.tight_layout()

    def visualize(self, image, state, segmentation=None):
        """Redraw the axes with the image, an optional mask overlay and the tracked boxes.

        args:
            image: Image shown as the background.
            state: One box, or a dict of boxes; each box is indexed as [x, y, w, h].
            segmentation: Optional mask drawn over the image with alpha 0.5.
        """
        self.ax.cla()
        self.ax.imshow(image)
        if segmentation is not None:
            self.ax.imshow(segmentation, alpha=0.5)

        # A dict holds one box per object; anything else is treated as a single box.
        boxes = list(state.values()) if isinstance(state, (OrderedDict, dict)) else (state,)

        # Display colours are looked up by 1-based position and scaled from 0-255 to 0-1.
        for color_idx, box in enumerate(boxes, start=1):
            edge = [float(channel) / 255.0 for channel in _tracker_disp_colors[color_idx]]
            corner = (box[0], box[1])
            self.ax.add_patch(
                patches.Rectangle(corner, box[2], box[3], linewidth=1, edgecolor=edge, facecolor='none'))

        # The ground-truth box, when one has been set on this object, is drawn in green.
        gt_state = getattr(self, 'gt_state', None)
        if gt_state is not None:
            gt_corner = (gt_state[0], gt_state[1])
            self.ax.add_patch(
                patches.Rectangle(gt_corner, gt_state[2], gt_state[3], linewidth=1, edgecolor='g', facecolor='none'))

        self.ax.set_axis_off()
        self.ax.axis('equal')
        draw_figure(self.fig)

        if self.pause_mode:
            # Keep waiting until waitforbuttonpress returns a truthy value.
            while not plt.waitforbuttonpress():
                pass

    def reset_tracker(self):
        pass

    def press(self, event):
        if event.key == 'p':
            self.pause_mode = not self.pause_mode
            print("Switching pause mode!")
        elif event.key == 'r':
            self.reset_tracker()
            print("Resetting target pos to gt!")

    def _read_image(self, image_file: str):
        """Load an image with OpenCV and return it converted from BGR to RGB channel order."""
        return cv.cvtColor(cv.imread(image_file), cv.COLOR_BGR2RGB)