Beispiel #1
0
    def __init__(self,
                 prop_net,
                 fuse_net,
                 s2m_net,
                 images,
                 num_objects,
                 device='cuda:0'):
        """Wrap an ``InferenceCore`` plus a scribble-to-mask network for one video.

        Args:
            prop_net: propagation network, forwarded to ``InferenceCore``.
            fuse_net: fusion network, forwarded to ``InferenceCore``.
            s2m_net: scribble-to-mask network; moved to ``device`` here.
            images: video tensor. Dim 1 is read as the frame count and the
                last two dims as (height, width); presumably shaped
                (1, T, C, H, W) -- TODO confirm against the caller.
            num_objects: number of objects to segment.
            device: device used for ``s2m_net`` and the ``InferenceCore``.

        Sets ``self.t``/``self.h``/``self.w`` to the original (unpadded)
        frame count and size, and ``self.nh``/``self.nw`` to the padded size.
        """
        self.s2m_net = s2m_net.to(device, non_blocking=True)
        self.device = device

        # True dimensions. These must be read BEFORE padding; reading them
        # afterwards would make h/w silently identical to nh/nw.
        t = images.shape[1]
        h, w = images.shape[-2:]

        # Presumably pads H and W up to multiples of 16 and returns the pad
        # amounts (kept in ``self.pad`` so results can be cropped back).
        images, self.pad = pad_divide_by(images, 16, images.shape[-2:])

        # Padded dimensions
        nh, nw = images.shape[-2:]
        self.nh, self.nw = nh, nw

        self.k = num_objects
        self.t, self.h, self.w = t, h, w

        self.interacted_count = 0
        # NOTE(review): how this schedule is consumed is not visible here.
        self.davis_schedule = [2, 5, 7]

        self.processor = InferenceCore(prop_net,
                                       fuse_net,
                                       images,
                                       num_objects,
                                       mem_profile=0,
                                       device=device)
Beispiel #2
0
    def __init__(self, prop_net, fuse_net, s2m_ctrl: S2MController,
                 fbrs_ctrl: FBRSController, images, masks, num_objects,
                 mem_freq, mem_profile):
        """Build the MiVOS interactive annotation window and show it.

        Args:
            prop_net: propagation network, forwarded to ``InferenceCore``.
            fuse_net: fusion network, forwarded to ``InferenceCore``.
            s2m_ctrl: scribble-to-mask controller (the 'Scribble' mode).
            fbrs_ctrl: click controller (the 'Click' mode).
            images: video frames; ``images.shape[:3]`` is read as
                (num_frames, height, width). Presumably an RGB uint8 numpy
                array of shape (T, H, W, 3) -- TODO confirm.
            masks: optional reference masks; when not ``None`` per-frame IoU
                buffers are allocated.
            num_objects: number of objects; number keys ``1``..``num_objects``
                are bound as shortcuts that pick the active object.
            mem_freq: forwarded to ``InferenceCore``.
            mem_profile: forwarded to ``InferenceCore``.

        Statement order matters: some ``connect`` calls are followed by
        setters (``setMinimum``/``setValue``/``toggle``) that fire the
        connected handler immediately, so every attribute such a handler
        reads must already exist at that point.
        """
        super().__init__()

        self.images = images
        self.masks = masks
        self.num_objects = num_objects
        self.s2m_controller = s2m_ctrl
        self.fbrs_controller = fbrs_ctrl
        self.processor = InferenceCore(prop_net,
                                       fuse_net,
                                       images_to_torch(images, device='cpu'),
                                       num_objects,
                                       mem_freq=mem_freq,
                                       mem_profile=mem_profile)

        self.num_frames, self.height, self.width = self.images.shape[:3]

        # IOU computation (only possible when reference masks are given)
        if self.masks is not None:
            self.ious = np.zeros(self.num_frames)
            self.iou_curve = []

        # set window
        self.setWindowTitle('MiVOS')
        self.setGeometry(100, 100, self.width, self.height + 100)

        # some buttons
        self.play_button = QPushButton('Play')
        self.play_button.clicked.connect(self.on_play)
        self.run_button = QPushButton('Propagate')
        self.run_button.clicked.connect(self.on_run)
        self.commit_button = QPushButton('Commit')
        self.commit_button.clicked.connect(self.on_commit)

        self.undo_button = QPushButton('Undo')
        self.undo_button.clicked.connect(self.on_undo)
        self.reset_button = QPushButton('Reset Frame')
        self.reset_button.clicked.connect(self.on_reset)
        self.save_button = QPushButton('Save')
        self.save_button.clicked.connect(self.save)

        # LCD -- uses the same format string as show_current_frame, which
        # rewrites this text at the end of __init__.
        self.lcd = QTextEdit()
        self.lcd.setReadOnly(True)
        self.lcd.setMaximumHeight(28)
        self.lcd.setMaximumWidth(120)
        self.lcd.setText('{: 3d} / {: 3d}'.format(0, self.num_frames - 1))

        # timeline slider
        self.tl_slider = QSlider(Qt.Horizontal)
        self.tl_slider.valueChanged.connect(self.tl_slide)
        self.tl_slider.setMinimum(0)
        self.tl_slider.setMaximum(self.num_frames - 1)
        self.tl_slider.setValue(0)
        self.tl_slider.setTickPosition(QSlider.TicksBelow)
        self.tl_slider.setTickInterval(1)

        # brush size slider
        self.brush_label = QLabel()
        self.brush_label.setAlignment(Qt.AlignCenter)
        self.brush_label.setMinimumWidth(100)

        # brush_label must exist before this: the value changes below fire
        # brush_slide during construction (before self.interaction exists).
        self.brush_slider = QSlider(Qt.Horizontal)
        self.brush_slider.valueChanged.connect(self.brush_slide)
        self.brush_slider.setMinimum(1)
        self.brush_slider.setMaximum(100)
        self.brush_slider.setValue(3)
        self.brush_slider.setTickPosition(QSlider.TicksBelow)
        self.brush_slider.setTickInterval(2)
        self.brush_slider.setMinimumWidth(300)

        # combobox
        self.combo = QComboBox(self)
        self.combo.addItem("davis")
        self.combo.addItem("fade")
        self.combo.addItem("light")
        self.combo.currentTextChanged.connect(self.set_viz_mode)

        # Radio buttons for type of interactions
        self.curr_interaction = 'Click'
        self.interaction_group = QButtonGroup()
        self.radio_fbrs = QRadioButton('Click')
        self.radio_s2m = QRadioButton('Scribble')
        self.radio_free = QRadioButton('Free')
        self.interaction_group.addButton(self.radio_fbrs)
        self.interaction_group.addButton(self.radio_s2m)
        self.interaction_group.addButton(self.radio_free)
        self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
        self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
        self.radio_free.toggled.connect(self.interaction_radio_clicked)
        # Fires interaction_radio_clicked now; the brush slider and the
        # commit button it touches have already been created above.
        self.radio_fbrs.toggle()

        # Main canvas -> QLabel
        self.main_canvas = QLabel()
        self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
                                       QSizePolicy.Expanding)
        self.main_canvas.setAlignment(Qt.AlignCenter)
        self.main_canvas.setMinimumSize(100, 100)

        self.main_canvas.mousePressEvent = self.on_press
        self.main_canvas.mouseMoveEvent = self.on_motion
        self.main_canvas.setMouseTracking(
            True)  # Required for all-time tracking
        self.main_canvas.mouseReleaseEvent = self.on_release

        # Minimap -> Also a QLabel
        self.minimap = QLabel()
        self.minimap.setSizePolicy(QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        self.minimap.setAlignment(Qt.AlignTop)
        self.minimap.setMinimumSize(100, 100)

        # Zoom-in buttons
        self.zoom_p_button = QPushButton('Zoom +')
        self.zoom_p_button.clicked.connect(self.on_zoom_plus)
        self.zoom_m_button = QPushButton('Zoom -')
        self.zoom_m_button.clicked.connect(self.on_zoom_minus)
        self.finish_local_button = QPushButton('Finish Local')
        self.finish_local_button.clicked.connect(self.on_finish_local)
        self.finish_local_button.setDisabled(True)

        # Console on the GUI
        self.console = QPlainTextEdit()
        self.console.setReadOnly(True)
        self.console.setMinimumHeight(100)
        self.console.setMaximumHeight(100)

        # progress bar
        self.progress = QProgressBar(self)
        self.progress.setGeometry(0, 0, 300, 25)
        self.progress.setMinimumWidth(300)
        self.progress.setMinimum(0)
        self.progress.setMaximum(100)
        self.progress.setFormat('Idle')
        self.progress.setStyleSheet("QProgressBar{color: black;}")
        self.progress.setAlignment(Qt.AlignCenter)

        # navigator
        navi = QHBoxLayout()
        navi.addWidget(self.lcd)
        navi.addWidget(self.play_button)

        interact_subbox = QVBoxLayout()
        interact_topbox = QHBoxLayout()
        interact_botbox = QHBoxLayout()
        interact_topbox.setAlignment(Qt.AlignCenter)
        interact_topbox.addWidget(self.radio_s2m)
        interact_topbox.addWidget(self.radio_fbrs)
        interact_topbox.addWidget(self.radio_free)
        interact_topbox.addWidget(self.brush_label)
        interact_botbox.addWidget(self.brush_slider)
        interact_subbox.addLayout(interact_topbox)
        interact_subbox.addLayout(interact_botbox)
        navi.addLayout(interact_subbox)

        navi.addStretch(1)
        navi.addWidget(self.undo_button)
        navi.addWidget(self.reset_button)

        navi.addStretch(1)
        navi.addWidget(self.progress)
        navi.addWidget(QLabel('Overlay Mode'))
        navi.addWidget(self.combo)
        navi.addStretch(1)
        navi.addWidget(self.commit_button)
        navi.addWidget(self.run_button)
        navi.addWidget(self.save_button)

        # Drawing area, main canvas and minimap
        draw_area = QHBoxLayout()
        draw_area.addWidget(self.main_canvas, 4)

        # Minimap area
        minimap_area = QVBoxLayout()
        minimap_area.setAlignment(Qt.AlignTop)
        mini_label = QLabel('Minimap')
        mini_label.setAlignment(Qt.AlignTop)
        minimap_area.addWidget(mini_label)
        # Minimap zooming
        minimap_ctrl = QHBoxLayout()
        minimap_ctrl.setAlignment(Qt.AlignTop)
        minimap_ctrl.addWidget(self.zoom_p_button)
        minimap_ctrl.addWidget(self.zoom_m_button)
        minimap_ctrl.addWidget(self.finish_local_button)
        minimap_area.addLayout(minimap_ctrl)
        minimap_area.addWidget(self.minimap)
        minimap_area.addWidget(QLabel('Overall procedure: '))
        minimap_area.addWidget(
            QLabel('1. Label a frame (all objects) with whatever means'))
        minimap_area.addWidget(QLabel('2. Propagate'))
        minimap_area.addWidget(
            QLabel(
                '3. Find a frame with error, correct it and propagate again'))
        minimap_area.addWidget(QLabel('4. Repeat'))
        minimap_area.addWidget(QLabel('Tips: '))
        minimap_area.addWidget(
            QLabel(
                '1: Use Ctrl+Left-click to drag-select a local control region.'
            ))
        minimap_area.addWidget(QLabel('Click finish local to go back.'))
        minimap_area.addWidget(
            QLabel('2: Use Right-click to label background.'))
        minimap_area.addWidget(
            QLabel('3: Use Num-keys to change the object id. '))
        minimap_area.addWidget(QLabel('(1-Red, 2-Green, 3-Blue, ...)'))
        minimap_area.addWidget(
            QLabel('4: \"Commit\" only works for S2M, it clears the buffer.'))
        minimap_area.addWidget(self.console)

        draw_area.addLayout(minimap_area, 1)

        layout = QVBoxLayout()
        layout.addLayout(draw_area)
        layout.addWidget(self.tl_slider)
        layout.addLayout(navi)
        self.setLayout(layout)

        # timer
        self.timer = QTimer()
        self.timer.setSingleShot(False)
        self.timer.timeout.connect(self.on_time)

        # Local mode related states
        self.ctrl_key = False
        self.in_local_mode = False
        self.local_bb = None
        self.local_interactions = {}
        self.this_local_interactions = []
        self.local_interaction = None

        # initialize visualization
        self.viz_mode = 'davis'
        self.current_mask = np.zeros(
            (self.num_frames, self.height, self.width), dtype=np.uint8)
        self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        self.vis_alpha = np.zeros((self.height, self.width, 1),
                                  dtype=np.float32)
        self.brush_vis_map = np.zeros((self.height, self.width, 3),
                                      dtype=np.uint8)
        self.brush_vis_alpha = np.zeros((self.height, self.width, 1),
                                        dtype=np.float32)
        self.vis_hist = deque(maxlen=100)
        self.cursur = 0
        self.on_showing = None

        # initialize local visualization (which is mostly unknown at this point)
        self.local_vis_map = None
        self.local_vis_alpha = None
        self.local_brush_vis_map = None
        self.local_brush_vis_alpha = None
        self.local_vis_hist = deque(maxlen=100)

        # Zoom parameters
        self.zoom_pixels = 150

        # initialize action
        self.interactions = {}
        self.interactions['interact'] = [[] for _ in range(self.num_frames)]
        self.interactions['annotated_frame'] = []
        self.this_frame_interactions = []
        self.interaction = None
        # Needs the visualization buffers, undo_button and fbrs_controller
        # created above.
        self.reset_this_interaction()
        self.pressed = False
        self.right_click = False
        self.ctrl_size = False
        self.current_object = 1
        self.last_ex = self.last_ey = 0

        # Objects shortcuts (functools.partial binds i now; a bare lambda in
        # this loop would late-bind to the last i)
        for i in range(1, num_objects + 1):
            QShortcut(QKeySequence(str(i)), self).activated.connect(
                functools.partial(self.hit_number_key, i))

        # <- and -> shortcuts
        QShortcut(QKeySequence(Qt.Key_Left),
                  self).activated.connect(self.on_prev)
        QShortcut(QKeySequence(Qt.Key_Right),
                  self).activated.connect(self.on_next)

        self.interacted_mask = None

        self.show_current_frame()
        self.show()

        # Timers start on the first timeline interaction, not here.
        self.waiting_to_start = True
        self.global_timer = Timer().start()
        self.algo_timer = Timer()
        self.user_timer = Timer()
        self.console_push_text('Initialized.')
Beispiel #3
0
class App(QWidget):
    def __init__(self, prop_net, fuse_net, s2m_ctrl: S2MController,
                 fbrs_ctrl: FBRSController, images, masks, num_objects,
                 mem_freq, mem_profile):
        """Build the MiVOS interactive annotation window and show it.

        Args:
            prop_net: propagation network, forwarded to ``InferenceCore``.
            fuse_net: fusion network, forwarded to ``InferenceCore``.
            s2m_ctrl: scribble-to-mask controller (the 'Scribble' mode).
            fbrs_ctrl: click controller (the 'Click' mode).
            images: video frames; ``images.shape[:3]`` is read as
                (num_frames, height, width). Presumably an RGB uint8 numpy
                array of shape (T, H, W, 3) -- TODO confirm.
            masks: optional reference masks; when not ``None`` per-frame IoU
                buffers are allocated.
            num_objects: number of objects; number keys ``1``..``num_objects``
                are bound as shortcuts that pick the active object.
            mem_freq: forwarded to ``InferenceCore``.
            mem_profile: forwarded to ``InferenceCore``.

        Statement order matters: some ``connect`` calls are followed by
        setters (``setMinimum``/``setValue``/``toggle``) that fire the
        connected handler immediately, so every attribute such a handler
        reads must already exist at that point.
        """
        super().__init__()

        self.images = images
        self.masks = masks
        self.num_objects = num_objects
        self.s2m_controller = s2m_ctrl
        self.fbrs_controller = fbrs_ctrl
        self.processor = InferenceCore(prop_net,
                                       fuse_net,
                                       images_to_torch(images, device='cpu'),
                                       num_objects,
                                       mem_freq=mem_freq,
                                       mem_profile=mem_profile)

        self.num_frames, self.height, self.width = self.images.shape[:3]

        # IOU computation (only possible when reference masks are given)
        if self.masks is not None:
            self.ious = np.zeros(self.num_frames)
            self.iou_curve = []

        # set window
        self.setWindowTitle('MiVOS')
        self.setGeometry(100, 100, self.width, self.height + 100)

        # some buttons
        self.play_button = QPushButton('Play')
        self.play_button.clicked.connect(self.on_play)
        self.run_button = QPushButton('Propagate')
        self.run_button.clicked.connect(self.on_run)
        self.commit_button = QPushButton('Commit')
        self.commit_button.clicked.connect(self.on_commit)

        self.undo_button = QPushButton('Undo')
        self.undo_button.clicked.connect(self.on_undo)
        self.reset_button = QPushButton('Reset Frame')
        self.reset_button.clicked.connect(self.on_reset)
        self.save_button = QPushButton('Save')
        self.save_button.clicked.connect(self.save)

        # LCD -- uses the same format string as show_current_frame, which
        # rewrites this text at the end of __init__.
        self.lcd = QTextEdit()
        self.lcd.setReadOnly(True)
        self.lcd.setMaximumHeight(28)
        self.lcd.setMaximumWidth(120)
        self.lcd.setText('{: 3d} / {: 3d}'.format(0, self.num_frames - 1))

        # timeline slider
        self.tl_slider = QSlider(Qt.Horizontal)
        self.tl_slider.valueChanged.connect(self.tl_slide)
        self.tl_slider.setMinimum(0)
        self.tl_slider.setMaximum(self.num_frames - 1)
        self.tl_slider.setValue(0)
        self.tl_slider.setTickPosition(QSlider.TicksBelow)
        self.tl_slider.setTickInterval(1)

        # brush size slider
        self.brush_label = QLabel()
        self.brush_label.setAlignment(Qt.AlignCenter)
        self.brush_label.setMinimumWidth(100)

        # brush_label must exist before this: the value changes below fire
        # brush_slide during construction (before self.interaction exists).
        self.brush_slider = QSlider(Qt.Horizontal)
        self.brush_slider.valueChanged.connect(self.brush_slide)
        self.brush_slider.setMinimum(1)
        self.brush_slider.setMaximum(100)
        self.brush_slider.setValue(3)
        self.brush_slider.setTickPosition(QSlider.TicksBelow)
        self.brush_slider.setTickInterval(2)
        self.brush_slider.setMinimumWidth(300)

        # combobox
        self.combo = QComboBox(self)
        self.combo.addItem("davis")
        self.combo.addItem("fade")
        self.combo.addItem("light")
        self.combo.currentTextChanged.connect(self.set_viz_mode)

        # Radio buttons for type of interactions
        self.curr_interaction = 'Click'
        self.interaction_group = QButtonGroup()
        self.radio_fbrs = QRadioButton('Click')
        self.radio_s2m = QRadioButton('Scribble')
        self.radio_free = QRadioButton('Free')
        self.interaction_group.addButton(self.radio_fbrs)
        self.interaction_group.addButton(self.radio_s2m)
        self.interaction_group.addButton(self.radio_free)
        self.radio_fbrs.toggled.connect(self.interaction_radio_clicked)
        self.radio_s2m.toggled.connect(self.interaction_radio_clicked)
        self.radio_free.toggled.connect(self.interaction_radio_clicked)
        # Fires interaction_radio_clicked now; the brush slider and the
        # commit button it touches have already been created above.
        self.radio_fbrs.toggle()

        # Main canvas -> QLabel
        self.main_canvas = QLabel()
        self.main_canvas.setSizePolicy(QSizePolicy.Expanding,
                                       QSizePolicy.Expanding)
        self.main_canvas.setAlignment(Qt.AlignCenter)
        self.main_canvas.setMinimumSize(100, 100)

        self.main_canvas.mousePressEvent = self.on_press
        self.main_canvas.mouseMoveEvent = self.on_motion
        self.main_canvas.setMouseTracking(
            True)  # Required for all-time tracking
        self.main_canvas.mouseReleaseEvent = self.on_release

        # Minimap -> Also a QLabel
        self.minimap = QLabel()
        self.minimap.setSizePolicy(QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        self.minimap.setAlignment(Qt.AlignTop)
        self.minimap.setMinimumSize(100, 100)

        # Zoom-in buttons
        self.zoom_p_button = QPushButton('Zoom +')
        self.zoom_p_button.clicked.connect(self.on_zoom_plus)
        self.zoom_m_button = QPushButton('Zoom -')
        self.zoom_m_button.clicked.connect(self.on_zoom_minus)
        self.finish_local_button = QPushButton('Finish Local')
        self.finish_local_button.clicked.connect(self.on_finish_local)
        self.finish_local_button.setDisabled(True)

        # Console on the GUI
        self.console = QPlainTextEdit()
        self.console.setReadOnly(True)
        self.console.setMinimumHeight(100)
        self.console.setMaximumHeight(100)

        # progress bar
        self.progress = QProgressBar(self)
        self.progress.setGeometry(0, 0, 300, 25)
        self.progress.setMinimumWidth(300)
        self.progress.setMinimum(0)
        self.progress.setMaximum(100)
        self.progress.setFormat('Idle')
        self.progress.setStyleSheet("QProgressBar{color: black;}")
        self.progress.setAlignment(Qt.AlignCenter)

        # navigator
        navi = QHBoxLayout()
        navi.addWidget(self.lcd)
        navi.addWidget(self.play_button)

        interact_subbox = QVBoxLayout()
        interact_topbox = QHBoxLayout()
        interact_botbox = QHBoxLayout()
        interact_topbox.setAlignment(Qt.AlignCenter)
        interact_topbox.addWidget(self.radio_s2m)
        interact_topbox.addWidget(self.radio_fbrs)
        interact_topbox.addWidget(self.radio_free)
        interact_topbox.addWidget(self.brush_label)
        interact_botbox.addWidget(self.brush_slider)
        interact_subbox.addLayout(interact_topbox)
        interact_subbox.addLayout(interact_botbox)
        navi.addLayout(interact_subbox)

        navi.addStretch(1)
        navi.addWidget(self.undo_button)
        navi.addWidget(self.reset_button)

        navi.addStretch(1)
        navi.addWidget(self.progress)
        navi.addWidget(QLabel('Overlay Mode'))
        navi.addWidget(self.combo)
        navi.addStretch(1)
        navi.addWidget(self.commit_button)
        navi.addWidget(self.run_button)
        navi.addWidget(self.save_button)

        # Drawing area, main canvas and minimap
        draw_area = QHBoxLayout()
        draw_area.addWidget(self.main_canvas, 4)

        # Minimap area
        minimap_area = QVBoxLayout()
        minimap_area.setAlignment(Qt.AlignTop)
        mini_label = QLabel('Minimap')
        mini_label.setAlignment(Qt.AlignTop)
        minimap_area.addWidget(mini_label)
        # Minimap zooming
        minimap_ctrl = QHBoxLayout()
        minimap_ctrl.setAlignment(Qt.AlignTop)
        minimap_ctrl.addWidget(self.zoom_p_button)
        minimap_ctrl.addWidget(self.zoom_m_button)
        minimap_ctrl.addWidget(self.finish_local_button)
        minimap_area.addLayout(minimap_ctrl)
        minimap_area.addWidget(self.minimap)
        minimap_area.addWidget(QLabel('Overall procedure: '))
        minimap_area.addWidget(
            QLabel('1. Label a frame (all objects) with whatever means'))
        minimap_area.addWidget(QLabel('2. Propagate'))
        minimap_area.addWidget(
            QLabel(
                '3. Find a frame with error, correct it and propagate again'))
        minimap_area.addWidget(QLabel('4. Repeat'))
        minimap_area.addWidget(QLabel('Tips: '))
        minimap_area.addWidget(
            QLabel(
                '1: Use Ctrl+Left-click to drag-select a local control region.'
            ))
        minimap_area.addWidget(QLabel('Click finish local to go back.'))
        minimap_area.addWidget(
            QLabel('2: Use Right-click to label background.'))
        minimap_area.addWidget(
            QLabel('3: Use Num-keys to change the object id. '))
        minimap_area.addWidget(QLabel('(1-Red, 2-Green, 3-Blue, ...)'))
        minimap_area.addWidget(
            QLabel('4: \"Commit\" only works for S2M, it clears the buffer.'))
        minimap_area.addWidget(self.console)

        draw_area.addLayout(minimap_area, 1)

        layout = QVBoxLayout()
        layout.addLayout(draw_area)
        layout.addWidget(self.tl_slider)
        layout.addLayout(navi)
        self.setLayout(layout)

        # timer
        self.timer = QTimer()
        self.timer.setSingleShot(False)
        self.timer.timeout.connect(self.on_time)

        # Local mode related states
        self.ctrl_key = False
        self.in_local_mode = False
        self.local_bb = None
        self.local_interactions = {}
        self.this_local_interactions = []
        self.local_interaction = None

        # initialize visualization
        self.viz_mode = 'davis'
        self.current_mask = np.zeros(
            (self.num_frames, self.height, self.width), dtype=np.uint8)
        self.vis_map = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        self.vis_alpha = np.zeros((self.height, self.width, 1),
                                  dtype=np.float32)
        self.brush_vis_map = np.zeros((self.height, self.width, 3),
                                      dtype=np.uint8)
        self.brush_vis_alpha = np.zeros((self.height, self.width, 1),
                                        dtype=np.float32)
        self.vis_hist = deque(maxlen=100)
        self.cursur = 0
        self.on_showing = None

        # initialize local visualization (which is mostly unknown at this point)
        self.local_vis_map = None
        self.local_vis_alpha = None
        self.local_brush_vis_map = None
        self.local_brush_vis_alpha = None
        self.local_vis_hist = deque(maxlen=100)

        # Zoom parameters
        self.zoom_pixels = 150

        # initialize action
        self.interactions = {}
        self.interactions['interact'] = [[] for _ in range(self.num_frames)]
        self.interactions['annotated_frame'] = []
        self.this_frame_interactions = []
        self.interaction = None
        # Needs the visualization buffers, undo_button and fbrs_controller
        # created above.
        self.reset_this_interaction()
        self.pressed = False
        self.right_click = False
        self.ctrl_size = False
        self.current_object = 1
        self.last_ex = self.last_ey = 0

        # Objects shortcuts (functools.partial binds i now; a bare lambda in
        # this loop would late-bind to the last i)
        for i in range(1, num_objects + 1):
            QShortcut(QKeySequence(str(i)), self).activated.connect(
                functools.partial(self.hit_number_key, i))

        # <- and -> shortcuts
        QShortcut(QKeySequence(Qt.Key_Left),
                  self).activated.connect(self.on_prev)
        QShortcut(QKeySequence(Qt.Key_Right),
                  self).activated.connect(self.on_next)

        self.interacted_mask = None

        self.show_current_frame()
        self.show()

        # Timers start on the first timeline interaction (see tl_slide).
        self.waiting_to_start = True
        self.global_timer = Timer().start()
        self.algo_timer = Timer()
        self.user_timer = Timer()
        self.console_push_text('Initialized.')

    def resizeEvent(self, event):
        """Qt resize hook: re-render so the frame is rescaled to the new size.

        ``event`` itself is not inspected; the render path reads the current
        canvas size directly.
        """
        self.show_current_frame()

    def save(self):
        """Ask for a directory and write every frame's mask and overlay into it.

        Masks are written to ``<dir>/mask`` as palettised PNGs (using the
        module-level ``palette``), and overlays rendered by ``overlay_davis``
        to ``<dir>/overlay``. Files are named by zero-padded frame index
        (``00000.png``, ``00001.png``, ...). Nothing is written when the
        directory dialog is cancelled.
        """
        folder_path = str(
            QFileDialog.getExistingDirectory(self, "Select Save Directory"))
        if not folder_path:
            # The dialog returns an empty string on cancel. Without this guard
            # path.join('', 'mask') == 'mask' and everything would silently be
            # written into the current working directory.
            self.console_push_text('Save cancelled.')
            return

        self.console_push_text('Saving masks and overlays...')
        mask_dir = path.join(folder_path, 'mask')
        overlay_dir = path.join(folder_path, 'overlay')

        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(overlay_dir, exist_ok=True)

        for i in range(self.num_frames):
            file_name = '{:05d}.png'.format(i)

            # Save mask
            mask = Image.fromarray(self.current_mask[i]).convert('P')
            mask.putpalette(palette)
            mask.save(path.join(mask_dir, file_name))

            # Save overlay
            overlay = overlay_davis(self.images[i], self.current_mask[i])
            Image.fromarray(overlay).save(path.join(overlay_dir, file_name))
        self.console_push_text('Done.')

    def console_push_text(self, text):
        """Log ``text`` to the on-screen console and to stdout.

        The message is prefixed with the current algorithm (A) and user (U)
        timer readings, and the console is scrolled to its last line.
        """
        algo_elapsed = self.algo_timer.format()
        user_elapsed = self.user_timer.format()
        stamped = '[A: %s, U: %s]: %s' % (algo_elapsed, user_elapsed, text)

        console = self.console
        console.appendPlainText(stamped)
        console.moveCursor(QTextCursor.End)
        print(stamped)

    def interaction_radio_clicked(self, event):
        """Sync the interaction mode with whichever radio button is checked.

        Connected to the ``toggled`` signal of all three radio buttons;
        ``event`` receives the signal's checked flag and is ignored. Click and
        Scribble use a fixed brush size of 3 with the slider disabled, while
        Free enables the slider and adopts its value. The Commit button is
        only enabled in Scribble mode.
        """
        self.last_interaction = self.curr_interaction

        scribble_on = self.radio_s2m.isChecked()
        if scribble_on or self.radio_fbrs.isChecked():
            self.curr_interaction = 'Scribble' if scribble_on else 'Click'
            self.brush_size = 3
            self.brush_slider.setDisabled(True)
        elif self.radio_free.isChecked():
            self.brush_slider.setDisabled(False)
            self.brush_slide()
            self.curr_interaction = 'Free'

        self.commit_button.setEnabled(self.curr_interaction == 'Scribble')

    def compose_current_im(self):
        """Render the mask overlay for what is currently shown into ``self.viz``.

        In local mode the local crop (``local_np_im`` / ``local_np_mask``) is
        drawn; otherwise the frame at ``self.cursur`` with its mask.
        ``viz_mode`` selects the renderer: 'fade' uses ``overlay_davis_fade``;
        'davis' uses ``overlay_davis``; 'light' uses ``overlay_davis`` with an
        extra argument of 0.9 (presumably the blending strength).

        Raises:
            NotImplementedError: if ``viz_mode`` is not one of those three.
        """
        renderers = {
            'fade': overlay_davis_fade,
            'davis': overlay_davis,
            'light': lambda image, mask: overlay_davis(image, mask, 0.9),
        }
        render = renderers.get(self.viz_mode)
        if render is None:
            raise NotImplementedError

        if self.in_local_mode:
            image, mask = self.local_np_im, self.local_np_mask
        else:
            image = self.images[self.cursur]
            mask = self.current_mask[self.cursur]
        self.viz = render(image, mask)

    def update_interact_vis(self):
        """Blend stroke layers onto ``self.viz`` and paint the main canvas.

        The mask overlay in ``self.viz`` is reused as-is (not recomputed).
        The stroke layer (``vis_map``/``vis_alpha``) is alpha-blended on top,
        then the brush layer (``brush_vis_map``/``brush_vis_alpha``), using
        the ``local_`` variants in local mode. The result is stored in
        ``self.viz_with_stroke``. The canvas and image sizes are recorded for
        later coordinate mapping.
        """
        img_h, img_w, _ = self.viz.shape
        row_bytes = 3 * img_w  # packed RGB888 rows

        if self.in_local_mode:
            layers = (self.local_vis_map, self.local_vis_alpha,
                      self.local_brush_vis_map, self.local_brush_vis_alpha)
        else:
            layers = (self.vis_map, self.vis_alpha,
                      self.brush_vis_map, self.brush_vis_alpha)
        stroke_map, stroke_alpha, brush_map, brush_alpha = layers

        blended = self.viz * (1 - stroke_alpha) + stroke_map * stroke_alpha
        blended = blended * (1 - brush_alpha) + brush_map * brush_alpha
        # Kept on self: the QImage below wraps this buffer without copying.
        self.viz_with_stroke = blended.astype(np.uint8)

        frame = QImage(self.viz_with_stroke.data, img_w, img_h, row_bytes,
                       QImage.Format_RGB888)
        fitted = frame.scaled(self.main_canvas.size(), Qt.KeepAspectRatio,
                              Qt.FastTransformation)
        self.main_canvas.setPixmap(QPixmap(fitted))

        self.main_canvas_size = self.main_canvas.size()
        self.image_size = frame.size()

    def update_minimap(self):
        """Refresh the minimap label.

        Outside local mode it shows a ``zoom_pixels``-wide crop of
        ``viz_with_stroke`` around the last cursor position. The crop centre
        is clamped so the window stays inside the frame. In local mode it
        shows ``minimap_in_local`` once and skips redraws afterwards.
        """
        if self.in_local_mode:
            if self.minimap_in_local_drawn:
                return  # static in local mode; already painted
            self.minimap_in_local_drawn = True
            patch = self.minimap_in_local.astype(np.uint8)
        else:
            half = self.zoom_pixels // 2
            # Clamp the centre into [half, size - half] in each axis.
            cx = int(round(max(half, min(self.width - half, self.last_ex))))
            cy = int(round(max(half, min(self.height - half, self.last_ey))))
            window = self.viz_with_stroke[cy - half:cy + half,
                                          cx - half:cx + half, :]
            # astype copies into a fresh contiguous buffer for QImage.
            patch = window.astype(np.uint8)

        rows, cols, _ = patch.shape
        thumb = QImage(patch.data, cols, rows, 3 * cols, QImage.Format_RGB888)
        fitted = thumb.scaled(self.minimap.size(), Qt.KeepAspectRatio,
                              Qt.FastTransformation)
        self.minimap.setPixmap(QPixmap(fitted))

    def show_current_frame(self):
        """Recompute the overlay for the current frame and refresh all views.

        This also updates the frame counter and moves the timeline slider to
        ``self.cursur``.
        """
        self.compose_current_im()    # mask overlay -> self.viz
        self.update_interact_vis()   # strokes + main canvas
        self.update_minimap()

        last_index = self.num_frames - 1
        self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, last_index))
        self.tl_slider.setValue(self.cursur)

    def get_scaled_pos(self, x, y):
        """Map a point on ``main_canvas`` to (possibly fractional) image coords.

        The displayed image is scaled with a preserved aspect ratio and
        centred, so the point is first divided by the smaller of the two
        canvas/image ratios. The centring margin is then removed. The result
        is clamped to the valid pixel range of the local crop (local mode) or
        of the full frame. Coordinates are returned unrounded.
        """
        img_h, img_w = self.image_size.height(), self.image_size.width()
        canvas_h = self.main_canvas_size.height()
        canvas_w = self.main_canvas_size.width()

        scale = min(canvas_h / img_h, canvas_w / img_w)

        # Undo the scale, then subtract half the letterbox margin.
        x = x / scale - (canvas_w / scale - img_w) / 2
        y = y / scale - (canvas_h / scale - img_h) / 2

        if self.in_local_mode:
            max_x, max_y = self.local_width - 1, self.local_height - 1
        else:
            max_x, max_y = self.width - 1, self.height - 1
        return max(0, min(max_x, x)), max(0, min(max_y, y))

    def clear_visualization(self):
        """Wipe the stroke layer and restart its undo history.

        It operates on the ``local_`` buffers in local mode and on the global
        ones otherwise. After clearing, a copy of the blank state is pushed as
        the sole history entry.
        """
        if self.in_local_mode:
            stroke_map = self.local_vis_map
            stroke_alpha = self.local_vis_alpha
            history = self.local_vis_hist
        else:
            stroke_map = self.vis_map
            stroke_alpha = self.vis_alpha
            history = self.vis_hist

        for layer in (stroke_map, stroke_alpha):
            layer.fill(0)
        history.clear()
        history.append((stroke_map.copy(), stroke_alpha.copy()))

    def reset_this_interaction(self):
        """Abandon the in-progress interaction on the current frame/region.

        This finishes any pending interaction, clears the stroke layer and
        drops this frame's interaction list. In local mode it instead keeps
        only the first entry of ``local_interactions['interact']``. It then
        disables Undo and unanchors the click controller, if one exists.
        """
        self.complete_interaction()
        self.clear_visualization()

        if not self.in_local_mode:
            self.interaction = None
            self.this_frame_interactions = []
        else:
            self.local_interaction = None
            local_history = self.local_interactions['interact']
            self.local_interactions['interact'] = local_history[:1]

        self.undo_button.setDisabled(True)
        if self.fbrs_controller is not None:
            self.fbrs_controller.unanchor()

    def set_viz_mode(self):
        """Adopt the visualisation mode selected in ``combo`` and redraw."""
        chosen = self.combo.currentText()
        self.viz_mode = chosen
        self.show_current_frame()

    def tl_slide(self):
        """React to the timeline slider moving to a new frame.

        If the session is still ``waiting_to_start``, the algorithm and user
        timers are started first. The interaction on the old frame is reset,
        the cursor jumps to the slider value and the view is redrawn.
        """
        if self.waiting_to_start:
            self.waiting_to_start = False
            for timer in (self.algo_timer, self.user_timer):
                timer.start()
            self.console_push_text('Timers started.')

        self.reset_this_interaction()
        self.cursur = self.tl_slider.value()
        self.show_current_frame()

    def brush_slide(self):
        """Adopt the brush-slider value as the brush size.

        Updates ``brush_size`` and its label, and resizes the active
        interaction when it is exactly a ``FreeInteraction``. An
        AttributeError (as happens when this runs during initialisation) is
        ignored.
        """
        size = self.brush_slider.value()
        self.brush_size = size
        self.brush_label.setText('Brush size: %d' % size)
        try:
            current = self.interaction
            if type(current) is FreeInteraction:
                current.set_size(size)
        except AttributeError:
            # Happens during initialisation (e.g. self.interaction not set yet)
            pass

    def progress_step_cb(self):
        """Advance the progress bar by one step (processor callback).

        Increments ``progress_num``, shows it as a percentage of
        ``progress_max`` on ``self.progress`` and processes pending Qt events
        so the UI can update during the blocking propagation.

        A non-positive ``progress_max`` (nothing to propagate) is displayed
        as complete instead of raising ZeroDivisionError.
        """
        self.progress_num += 1
        if self.progress_max > 0:
            ratio = self.progress_num / self.progress_max
        else:
            ratio = 1.0
        self.progress.setValue(int(ratio * 100))
        self.progress.setFormat('%2.1f%%' % (ratio * 100))
        QApplication.processEvents()

    def progress_total_cb(self, total):
        """Begin a progress run of ``total`` steps (processor callback).

        The counter is primed at -1 so the immediate ``progress_step_cb``
        call brings it to 0 and draws the initial state of the bar.
        """
        self.progress_max, self.progress_num = total, -1
        self.progress_step_cb()

    def on_run(self):
        """Propagate the current frame's interacted mask through the video.

        Without a pending ``interacted_mask`` only a console message is
        shown. The check happens before pausing ``user_timer``, so a no-op
        click does not leave the user timer paused. Otherwise:
        - the processor propagates from ``cursur``, reporting through the
          progress callbacks;
        - the result becomes ``current_mask``;
        - the interaction on this frame is reset;
        - the progress bar returns to 'Idle'.
        """
        if self.interacted_mask is None:
            self.console_push_text('Cannot propagate! No interacted mask!')
            return

        self.user_timer.pause()
        self.console_push_text('Propagation started.')
        self.current_mask = self.processor.interact(self.interacted_mask,
                                                    self.cursur,
                                                    self.progress_total_cb,
                                                    self.progress_step_cb)
        self.interacted_mask = None
        # Clear the scribbles and start afresh on this frame
        self.show_current_frame()
        self.reset_this_interaction()
        self.progress.setFormat('Idle')
        self.progress.setValue(0)
        self.console_push_text('Propagation finished!')
        self.user_timer.start()

    def on_commit(self):
        """Commit the pending interaction and write its mask to the frame."""
        # History first, so update_interacted_mask sees the committed state
        self.complete_interaction()
        self.update_interacted_mask()

    def on_prev(self):
        """Move the cursor one frame back, never before frame 0.

        Only the slider is updated here; its ``setValue`` triggers
        ``tl_slide``, which does the redraw.
        """
        previous = self.cursur - 1
        self.cursur = previous if previous > 0 else 0
        self.tl_slider.setValue(self.cursur)

    def on_next(self):
        """Move the cursor one frame forward, never past the last frame.

        Only the slider is updated here; its ``setValue`` triggers
        ``tl_slide``, which does the redraw.
        """
        following = self.cursur + 1
        last = self.num_frames - 1
        self.cursur = following if following < last else last
        self.tl_slider.setValue(self.cursur)

    def on_time(self):
        """Playback tick: advance one frame, wrapping to 0 after the last."""
        following = self.cursur + 1
        self.cursur = 0 if following > self.num_frames - 1 else following
        self.tl_slider.setValue(self.cursur)

    def on_play(self):
        """Toggle playback.

        Stops ``self.timer`` if it is running. Otherwise starts it with an
        interval of 40 ms (25 fps).

        ``QTimer.start`` takes an integer millisecond interval. Newer PyQt
        versions reject the float ``1000 / 25`` with a TypeError, so integer
        division is used.
        """
        if self.timer.isActive():
            self.timer.stop()
        else:
            self.timer.start(1000 // 25)

    def on_undo(self):
        """Undo the most recent step of the current interaction.

        Local mode:
          * no active local interaction: drop the last committed local step
            if more than the initial crop (index 0) remains, otherwise reset;
            the mask becomes the prediction of the last remaining step;
          * active and undoable: undo one step inside it;
          * active but not undoable: discard it (or reset if only the crop
            remains) and predict from the last committed local step.
          The previous local overlay is then restored from its history.

        Global mode follows the same pattern with ``this_frame_interactions``.
        When nothing is left, the mask falls back to a clone of the
        processor's soft mask ``processor.prob[1:, cursur]``. The previous
        overlay is restored.

        Finally the resulting mask is applied via ``update_interacted_mask``.
        """
        if self.in_local_mode:
            if self.local_interaction is None:
                if len(self.local_interactions['interact']) > 1:
                    self.local_interactions['interact'] = \
                        self.local_interactions['interact'][:-1]
                else:
                    self.reset_this_interaction()
                self.local_interacted_mask = \
                    self.local_interactions['interact'][-1].predict()
            elif self.local_interaction.can_undo():
                self.local_interacted_mask = self.local_interaction.undo()
            else:
                if len(self.local_interactions['interact']) > 1:
                    self.local_interaction = None
                else:
                    self.reset_this_interaction()
                self.local_interacted_mask = \
                    self.local_interactions['interact'][-1].predict()

            # Might be empty if the entire interaction was undone
            if len(self.local_vis_hist) > 0:
                self.local_vis_map, self.local_vis_alpha = \
                    self.local_vis_hist.pop()
        else:
            if self.interaction is None:
                if len(self.this_frame_interactions) > 1:
                    self.this_frame_interactions = \
                        self.this_frame_interactions[:-1]
                    self.interacted_mask = \
                        self.this_frame_interactions[-1].predict()
                else:
                    self.reset_this_interaction()
                    # Same soft-mask slice that on_press / on_finish_local use
                    self.interacted_mask = \
                        self.processor.prob[1:, self.cursur].clone()
            elif self.interaction.can_undo():
                self.interacted_mask = self.interaction.undo()
            elif len(self.this_frame_interactions) > 0:
                self.interaction = None
                self.interacted_mask = \
                    self.this_frame_interactions[-1].predict()
            else:
                self.reset_this_interaction()
                self.interacted_mask = \
                    self.processor.prob[1:, self.cursur].clone()

            # Might be empty if the entire interaction was undone
            if len(self.vis_hist) > 0:
                self.vis_map, self.vis_alpha = self.vis_hist.pop()

        # Commit changes
        self.update_interacted_mask()

    def on_reset(self):
        """Blank the current frame's mask and drop its pending interaction.

        The hard masks (``processor.masks``, ``processor.np_masks``) and the
        displayed ``current_mask`` for this frame are zeroed in place;
        ``processor.prob`` is deliberately left untouched.
        """
        frame = self.cursur
        # Do not edit prob -- it is still needed for the mask diff
        self.processor.masks[frame].zero_()
        self.processor.np_masks[frame].fill(0)
        self.current_mask[frame].fill(0)
        self.reset_this_interaction()
        self.show_current_frame()

    def on_zoom_plus(self):
        """Decrease ``zoom_pixels`` by 25 (floor 50) and redraw the minimap."""
        narrower = self.zoom_pixels - 25
        self.zoom_pixels = narrower if narrower > 50 else 50
        self.update_minimap()

    def on_zoom_minus(self):
        """Increase ``zoom_pixels`` by 25 (cap 300) and redraw the minimap."""
        wider = self.zoom_pixels + 25
        self.zoom_pixels = wider if wider < 300 else 300
        self.update_minimap()

    def set_navi_enable(self, boolean):
        """Enable (True) or disable (False) the navigation/propagation widgets.

        Affects the two zoom buttons, the run and play buttons, the timeline
        slider and the frame counter.
        """
        for widget in (self.zoom_p_button, self.zoom_m_button,
                       self.run_button, self.tl_slider, self.play_button,
                       self.lcd):
            widget.setEnabled(boolean)

    def on_finish_local(self):
        """Leave local mode and fold the local edits into the full frame.

        The pending local interaction is committed, local mode is switched
        off and navigation re-enabled. A ``LocalInteraction`` then combines
        two inputs:
        - the newest global soft mask (last committed interaction on this
          frame, else ``processor.prob[1:, cursur]``);
        - the last local result inside ``local_bb``.
        It keeps ``local_interactions`` as its ``storage``, is predicted into
        ``interacted_mask`` and committed as a regular global interaction
        before the frame is redrawn.
        """
        # Commit the last local step while still in local mode
        self.complete_interaction()
        self.finish_local_button.setDisabled(True)
        self.in_local_mode = False
        self.set_navi_enable(True)

        history = self.this_frame_interactions
        if len(history) > 0:
            prev_soft_mask = history[-1].out_prob
        else:
            prev_soft_mask = self.processor.prob[1:, self.cursur]
        image = self.processor.images[:, self.cursur]
        local_result = self.local_interactions['interact'][-1].out_prob

        merged = LocalInteraction(image, prev_soft_mask,
                                  (self.height, self.width), self.local_bb,
                                  local_result, self.processor.pad,
                                  self.local_pad)
        merged.storage = self.local_interactions
        self.interaction = merged
        self.interacted_mask = merged.predict()

        # Now in global mode, so this records the merged interaction
        self.complete_interaction()
        self.update_interacted_mask()
        self.show_current_frame()

        self.console_push_text('Finished local control.')

    def hit_number_key(self, number):
        """Make ``number`` the object id used for new annotations.

        Nothing happens if it already is the current object. Otherwise the
        f-BRS anchor is released (when an f-BRS controller exists), and the
        brush cursor is redrawn at the last mouse position before the frame
        is refreshed.
        """
        if number != self.current_object:
            self.current_object = number
            if self.fbrs_controller is not None:
                self.fbrs_controller.unanchor()
            self.console_push_text('Current object changed to %d!' % number)
            self.clear_brush()
            self.vis_brush(self.last_ex, self.last_ey)
            self.update_interact_vis()
            self.show_current_frame()

    def clear_brush(self):
        """Erase the brush-cursor layers (global, and local when allocated)."""
        layers = [self.brush_vis_map, self.brush_vis_alpha]
        if self.local_brush_vis_map is not None:
            layers += [self.local_brush_vis_map, self.local_brush_vis_alpha]
        for layer in layers:
            layer.fill(0)

    def vis_brush(self, ex, ey):
        """Draw the cursor overlay at image position ``(ex, ey)``.

        While ctrl is held, a filled rectangle from ``local_start`` to the
        cursor previews the local-control region on the global brush layer.
        Otherwise a filled circle of radius ``brush_size // 2 + 1`` in the
        current object's colour is drawn on the local or global brush layer.
        Alpha layers receive 0.5 inside the shape.
        """
        if self.ctrl_key:
            start_x, start_y = self.local_start[0], self.local_start[1]
            corner_a = (int(round(min(start_x, ex))),
                        int(round(min(start_y, ey))))
            corner_b = (int(round(max(start_x, ex))),
                        int(round(max(start_y, ey))))
            self.brush_vis_map = cv2.rectangle(self.brush_vis_map, corner_a,
                                               corner_b, (128, 255, 128),
                                               thickness=-1)
            self.brush_vis_alpha = cv2.rectangle(self.brush_vis_alpha,
                                                 corner_a, corner_b, 0.5,
                                                 thickness=-1)
            return

        centre = (int(round(ex)), int(round(ey)))
        radius = self.brush_size // 2 + 1
        colour = color_map[self.current_object]
        if self.in_local_mode:
            self.local_brush_vis_map = cv2.circle(self.local_brush_vis_map,
                                                  centre, radius, colour,
                                                  thickness=-1)
            self.local_brush_vis_alpha = cv2.circle(
                self.local_brush_vis_alpha, centre, radius, 0.5,
                thickness=-1)
        else:
            self.brush_vis_map = cv2.circle(self.brush_vis_map, centre,
                                            radius, colour, thickness=-1)
            self.brush_vis_alpha = cv2.circle(self.brush_vis_alpha, centre,
                                              radius, 0.5, thickness=-1)

    def enter_local_control(self):
        """Start local (cropped) editing inside the ctrl-dragged rectangle.

        The rectangle between ``local_start`` and ``local_end`` (rounded to
        ints, inclusive) becomes ``local_bb = (lx, ux, ly, uy)``. A
        ``CropperInteraction`` on the current frame is stored as the first
        entry of ``local_interactions['interact']``. It supplies the cropped
        image, the initial local mask and the crop padding. Fresh local
        overlay buffers are allocated at the crop size. Navigation and undo
        are disabled, and the finish-local button is enabled.
        """
        self.in_local_mode = True
        lx = int(round(min(self.local_start[0], self.local_end[0])))
        ux = int(round(max(self.local_start[0], self.local_end[0])))
        ly = int(round(min(self.local_start[1], self.local_end[1])))
        uy = int(round(max(self.local_start[1], self.local_end[1])))

        # Reset local state
        self.local_bb = (lx, ux, ly, uy)
        self.local_interactions = {'interact': []}
        self.local_interaction = None

        # Seed the crop from the newest committed interaction on this frame
        # (the same source on_press and on_finish_local use), falling back to
        # the processor's soft mask.
        if len(self.this_frame_interactions) == 0:
            prev_soft_mask = self.processor.prob[1:, self.cursur]
        else:
            prev_soft_mask = self.this_frame_interactions[-1].out_prob
        self.local_interactions['bounding_box'] = self.local_bb
        self.local_interactions['cursur'] = self.cursur
        init_interaction = CropperInteraction(
            self.processor.images[:, self.cursur], prev_soft_mask,
            self.processor.pad, self.local_bb)
        self.local_interactions['interact'].append(init_interaction)

        self.local_interacted_mask = init_interaction.out_mask
        self.local_torch_im = init_interaction.im_crop
        # Bounding-box limits are inclusive, hence the +1
        self.local_np_im = self.images[self.cursur][ly:uy + 1, lx:ux + 1, :]
        self.local_pad = init_interaction.pad

        # Local visualisation buffers at the crop size
        h, w = init_interaction.h, init_interaction.w
        self.local_vis_map = np.zeros((h, w, 3), dtype=np.uint8)
        self.local_vis_alpha = np.zeros((h, w, 1), dtype=np.float32)
        self.local_brush_vis_map = np.zeros((h, w, 3), dtype=np.uint8)
        self.local_brush_vis_alpha = np.zeros((h, w, 1), dtype=np.float32)
        self.local_vis_hist = deque(maxlen=100)
        self.local_height, self.local_width = h, w

        # Reset the local minimap state and redraw
        self.minimap_in_local_drawn = False
        self.minimap_in_local = self.viz_with_stroke
        self.update_interacted_mask()
        self.finish_local_button.setEnabled(True)
        self.undo_button.setEnabled(False)
        self.set_navi_enable(False)

        self.console_push_text('Entered local control.')

    def on_press(self, event):
        """Mouse-press handler for the main canvas.

        Starts the algorithm/user timers on the very first press, converts
        the event position to image coordinates and pushes a copy of the
        current overlay onto the undo history. Then one of two things happens:
        - ctrl held (global mode only): the pending interaction is committed
          and the first corner of a local-control rectangle is stored in
          ``local_start``.
        - otherwise: the interaction object matching ``self.curr_interaction``
          ('Scribble', 'Free' or 'Click') is created, or the active one is
          reused, for the global frame or the local crop.
        Finally the event is forwarded to ``on_motion`` so the press itself
        is drawn. ``user_timer`` is paused for the duration of the handler.

        Args:
            event: Qt mouse event; ``x()``, ``y()`` and ``button()`` are read.
        """
        if self.waiting_to_start:
            self.waiting_to_start = False
            self.algo_timer.start()
            self.user_timer.start()
            self.console_push_text('Timers started.')

        self.user_timer.pause()
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        # Check for ctrl key
        modifiers = QApplication.keyboardModifiers()
        if not self.in_local_mode and modifiers == QtCore.Qt.ControlModifier:
            # Start specifying the local mode
            self.ctrl_key = True
        else:
            self.ctrl_key = False

        self.pressed = True
        # Qt.LeftButton has value 1; any other button counts as "right click"
        self.right_click = (event.button() != 1)
        # Push last vis map into history
        if self.in_local_mode:
            self.local_vis_hist.append(
                (self.local_vis_map.copy(), self.local_vis_alpha.copy()))
        else:
            self.vis_hist.append((self.vis_map.copy(), self.vis_alpha.copy()))
        if self.ctrl_key:
            # Wrap up the last interaction
            self.complete_interaction()
            # Labeling a local control field
            self.local_start = ex, ey
        else:
            # Ordinary interaction (might be in local mode)
            # Gather the inputs for a possible new interaction: the previous
            # soft mask (used by Free/Click), the previous hard mask (used by
            # Scribble), the image and its size.
            if self.in_local_mode:
                if self.local_interaction is None:
                    prev_soft_mask = self.local_interactions['interact'][
                        -1].out_prob
                else:
                    prev_soft_mask = self.local_interaction.out_prob
                # local_max_mask is the argmax set by update_interacted_mask
                prev_hard_mask = self.local_max_mask
                image = self.local_torch_im
                h, w = self.local_height, self.local_width
            else:
                if self.interaction is None:
                    if len(self.this_frame_interactions) > 0:
                        prev_soft_mask = self.this_frame_interactions[
                            -1].out_prob
                    else:
                        # Processor soft mask without index 0 (presumably
                        # background) -- TODO confirm layout of prob
                        prev_soft_mask = self.processor.prob[1:, self.cursur]
                else:
                    # Not used if the previous interaction is still valid
                    # Don't worry about stacking effects here
                    prev_soft_mask = self.interaction.out_prob
                prev_hard_mask = self.processor.masks[self.cursur]
                image = self.processor.images[:, self.cursur]
                h, w = self.height, self.width

            # A new interaction is created only when the type changes (or,
            # for Click, the target object changes); otherwise the active
            # one keeps receiving input.
            last_interaction = self.local_interaction if self.in_local_mode else self.interaction
            new_interaction = None
            if self.curr_interaction == 'Scribble':
                if last_interaction is None or type(
                        last_interaction) != ScribbleInteraction:
                    self.complete_interaction()
                    new_interaction = ScribbleInteraction(
                        image, prev_hard_mask, (h, w), self.s2m_controller,
                        self.num_objects)
            elif self.curr_interaction == 'Free':
                if last_interaction is None or type(
                        last_interaction) != FreeInteraction:
                    self.complete_interaction()
                    if self.in_local_mode:
                        new_interaction = FreeInteraction(
                            image, prev_soft_mask, (h, w), self.num_objects,
                            self.local_pad)
                    else:
                        new_interaction = FreeInteraction(
                            image, prev_soft_mask, (h, w), self.num_objects,
                            self.processor.pad)
                    new_interaction.set_size(self.brush_size)
            elif self.curr_interaction == 'Click':
                if (last_interaction is None
                        or type(last_interaction) != ClickInteraction
                        or last_interaction.tar_obj != self.current_object):
                    self.complete_interaction()
                    # NOTE(review): unlike reset_this_interaction and
                    # hit_number_key, fbrs_controller is not None-checked
                    # here; presumably Click mode requires f-BRS -- confirm.
                    self.fbrs_controller.unanchor()
                    # NOTE(review): Free uses local_pad in local mode, but
                    # Click always passes processor.pad -- confirm intended.
                    new_interaction = ClickInteraction(image, prev_soft_mask,
                                                       (h, w),
                                                       self.fbrs_controller,
                                                       self.current_object,
                                                       self.processor.pad)

            if new_interaction is not None:
                if self.in_local_mode:
                    self.local_interaction = new_interaction
                else:
                    self.interaction = new_interaction

        # Just motion it as the first step
        self.on_motion(event)
        self.user_timer.start()

    def on_motion(self, event):
        """Mouse-move handler: track the cursor and extend an active stroke.

        Always records the position in ``last_ex``/``last_ey`` and redraws
        the brush cursor. While the button is held (and ctrl is not), a
        'Scribble' or 'Free' interaction receives the point. Right-button
        strokes are pushed with object 0 (presumably background), others
        with ``current_object``.
        """
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        self.last_ex, self.last_ey = ex, ey
        self.clear_brush()
        self.vis_brush(ex, ey)

        drawing = (self.pressed and not self.ctrl_key
                   and self.curr_interaction in ('Scribble', 'Free'))
        if drawing:
            obj = 0 if self.right_click else self.current_object
            if self.in_local_mode:
                layers = (self.local_vis_map, self.local_vis_alpha)
                self.local_vis_map, self.local_vis_alpha = \
                    self.local_interaction.push_point(ex, ey, obj, layers)
            else:
                layers = (self.vis_map, self.vis_alpha)
                self.vis_map, self.vis_alpha = \
                    self.interaction.push_point(ex, ey, obj, layers)

        self.update_interact_vis()
        self.update_minimap()

    def update_interacted_mask(self):
        """Apply the latest interaction output and redraw the frame.

        Global mode: hands ``interacted_mask`` to
        ``processor.update_mask_only`` for the current frame and copies the
        resulting ``processor.np_masks`` entry into ``current_mask``.
        Local mode: the argmax over axis 0 of ``local_interacted_mask``
        becomes ``local_max_mask``. After ``local_pad`` is removed, its first
        slice is stored as the uint8 array ``local_np_mask``.
        """
        if not self.in_local_mode:
            frame = self.cursur
            self.processor.update_mask_only(self.interacted_mask, frame)
            self.current_mask[frame] = self.processor.np_masks[frame]
        else:
            self.local_max_mask = torch.argmax(self.local_interacted_mask, 0)
            unpadded = unpad_3dim(self.local_max_mask, self.local_pad)
            as_numpy = unpadded.detach().cpu().numpy()
            self.local_np_mask = as_numpy[0].astype(np.uint8)
        self.show_current_frame()

    def complete_interaction(self):
        """Move the in-progress interaction (if any) into the history.

        Local mode: it is appended to ``local_interactions['interact']``.
        Global mode: the frame index is appended to
        ``interactions['annotated_frame']``, and the interaction to both
        ``interactions['interact'][cursur]`` and ``this_frame_interactions``.
        In both cases the overlay is cleared first, the pending slot is
        emptied and undo is enabled. Nothing happens when no interaction is
        pending.
        """
        if self.in_local_mode:
            if self.local_interaction is None:
                return
            self.clear_visualization()
            self.local_interactions['interact'].append(
                self.local_interaction)
            self.local_interaction = None
        else:
            if self.interaction is None:
                return
            self.clear_visualization()
            frame = self.cursur
            self.interactions['annotated_frame'].append(frame)
            self.interactions['interact'][frame].append(self.interaction)
            self.this_frame_interactions.append(self.interaction)
            self.interaction = None
        self.undo_button.setDisabled(False)

    def on_release(self, event):
        """Mouse-release handler for the main canvas.

        After a ctrl-drag, the release point becomes ``local_end`` and local
        control is entered. Undo stays disabled afterwards, as set by
        ``enter_local_control``.

        Otherwise the active (local or global) interaction is finished:
        - 'Scribble'/'Free' get a final motion event and ``end_path()``; the
          'Free' overlay is then cleared.
        - 'Click' pushes the release position, with ``right_click`` as the
          flag.
        The interaction's prediction becomes the interacted mask, it is
        applied through ``update_interacted_mask``, and undo is enabled.

        The press/ctrl/right-click flags are always reset. ``user_timer`` is
        paused while the handler runs.
        """
        self.user_timer.pause()
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        if self.ctrl_key:
            # Enter local control mode (it disables undo itself)
            self.clear_visualization()
            self.local_end = ex, ey
            self.enter_local_control()
        else:
            self.console_push_text('Interaction %s at frame %d.' %
                                   (self.curr_interaction, self.cursur))
            # Ordinary interaction (might be in local mode)
            interaction = (self.local_interaction if self.in_local_mode
                           else self.interaction)

            if self.curr_interaction in ('Scribble', 'Free'):
                self.on_motion(event)
                interaction.end_path()
                if self.curr_interaction == 'Free':
                    self.clear_visualization()
            elif self.curr_interaction == 'Click':
                # ex, ey computed above are reused
                if self.in_local_mode:
                    self.local_vis_map, self.local_vis_alpha = interaction.push_point(
                        ex, ey, self.right_click,
                        (self.local_vis_map, self.local_vis_alpha))
                else:
                    self.vis_map, self.vis_alpha = interaction.push_point(
                        ex, ey, self.right_click,
                        (self.vis_map, self.vis_alpha))

            if self.in_local_mode:
                self.local_interacted_mask = interaction.predict()
            else:
                self.interacted_mask = interaction.predict()
            self.update_interacted_mask()
            self.undo_button.setDisabled(False)

        self.pressed = self.ctrl_key = self.right_click = False
        self.user_timer.start()

    def debug_pressed(self):
        """Debug helper: swap ``interacted_mask`` with ``debug_mask`` and apply it.

        The swapped-in mask is written to the current frame through
        ``processor.update_mask_only`` and mirrored into ``current_mask``.
        """
        previous = self.interacted_mask
        self.interacted_mask = self.debug_mask
        self.debug_mask = previous

        frame = self.cursur
        self.processor.update_mask_only(self.interacted_mask, frame)
        self.current_mask[frame] = self.processor.np_masks[frame]
        self.show_current_frame()

    def wheelEvent(self, event):
        """Qt wheel handler: resize the brush in 'Free' mode, then redraw the cursor.

        The brush slider moves by the vertical wheel delta floor-divided by
        30 (one standard 120-unit notch gives 4).
        """
        ex, ey = self.get_scaled_pos(event.x(), event.y())
        if self.curr_interaction == 'Free':
            current = self.brush_slider.value()
            self.brush_slider.setValue(current + event.angleDelta().y() // 30)

        self.clear_brush()
        self.vis_brush(ex, ey)
        self.update_interact_vis()
        self.update_minimap()
Beispiel #4
0
class DAVISProcessor:
    """
    Acts as the junction between DAVIS interactive track and our inference_core

    Each DAVIS scribble is turned into a mask by the S2M network and handed
    to ``InferenceCore``. ``davis_schedule`` lists the interaction counts at
    which the mask is propagated through the whole video; other interactions
    only update the interacted frame.
    """
    def __init__(self,
                 prop_net,
                 fuse_net,
                 s2m_net,
                 images,
                 num_objects,
                 device='cuda:0'):
        """
        Args:
            prop_net: propagation network passed to InferenceCore.
            fuse_net: fusion network passed to InferenceCore.
            s2m_net: scribble-to-mask network; moved to ``device``.
            images: video tensor. ``shape[1]`` is the frame count and the
                last two axes are (height, width) -- presumably
                (1, T, 3, H, W), TODO confirm with the caller.
            num_objects: number of objects; scribble labels 1..num_objects.
            device: device for the S2M network and its inputs.
        """
        self.s2m_net = s2m_net.to(device, non_blocking=True)
        self.device = device

        # True dimensions. They must be read *before* padding: scribbles are
        # rasterised at this size in to_mask and only then padded like the
        # images, so the strokes line up with the padded frames.
        t = images.shape[1]
        h, w = images.shape[-2:]

        self.k = num_objects
        self.t, self.h, self.w = t, h, w

        images, self.pad = pad_divide_by(images, 16, images.shape[-2:])

        # Padded dimensions
        nh, nw = images.shape[-2:]
        self.nh, self.nw = nh, nw

        # Interaction counts (0-based) at which we propagate to all frames
        self.interacted_count = 0
        self.davis_schedule = [2, 5, 7]

        self.processor = InferenceCore(prop_net,
                                       fuse_net,
                                       images,
                                       num_objects,
                                       mem_profile=0,
                                       device=device)

    @staticmethod
    def _select_scribbled_frame(scribble):
        """Return ``(idx, single)`` for the first frame of ``scribble`` with strokes.

        ``single`` is a shallow copy of ``scribble`` whose ``'scribbles'`` list
        holds only that frame's strokes; the caller's dict is not modified.

        Raises:
            ValueError: if no frame contains any stroke.
        """
        for idx, strokes in enumerate(scribble['scribbles']):
            if len(strokes) != 0:
                single = dict(scribble)
                single['scribbles'] = [strokes]
                return idx, single
        raise ValueError('The scribble does not contain any stroke.')

    def to_mask(self, scribble):
        """Turn a DAVIS scribble into a hard multi-object mask with S2M.

        Args:
            scribble: DAVIS scribble dict; ``scribble['scribbles']`` holds one
                list of strokes per frame. Only the first frame with strokes
                is used. The dict itself is not modified.

        Returns:
            (mask, idx): the mask produced by ``aggregate_wbg`` at the padded
            frame size, and the index of the scribbled frame.

        Raises:
            ValueError: if no frame carries a stroke.
        """
        idx, scribble = self._select_scribbled_frame(scribble)

        # Rasterise the strokes at the true frame size; pixels equal to -1
        # carry no stroke.
        scr_mask = scribbles2mask(scribble, (self.h, self.w))[0]

        # Run our S2M
        kernel = np.ones((3, 3), np.uint8)
        mask = torch.zeros((self.k, 1, self.nh, self.nw),
                           dtype=torch.float32,
                           device=self.device)
        for ki in range(1, self.k + 1):
            # Positive strokes: this object; negative: any other label
            p_srb = (scr_mask == ki).astype(np.uint8)
            p_srb = cv2.dilate(p_srb, kernel).astype(bool)

            n_srb = ((scr_mask != ki) * (scr_mask != -1)).astype(np.uint8)
            n_srb = cv2.dilate(n_srb, kernel).astype(bool)

            Rs = torch.from_numpy(np.stack(
                [p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device)
            # Pad exactly like the images so strokes align with the frame
            Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

            # Use hard mask because we train S2M with such
            inputs = torch.cat([
                self.processor.get_image_buffered(idx),
                (self.processor.masks[idx] == ki).to(
                    self.device).float().unsqueeze(0), Rs
            ], 1)
            mask[ki - 1] = torch.sigmoid(self.s2m_net(inputs))
        mask = aggregate_wbg(mask, keep_bg=True, hard=True)
        return mask, idx

    def _consume_schedule(self):
        """Return True if the current interaction should be propagated.

        A matching ``davis_schedule`` entry is consumed. Once the schedule is
        exhausted, every further interaction is propagated instead of
        failing with IndexError.
        """
        if not self.davis_schedule:
            return True
        if self.interacted_count == self.davis_schedule[0]:
            self.davis_schedule = self.davis_schedule[1:]
            return True
        return False

    @staticmethod
    def _unpad(masks, pad):
        """Remove ``pad = (left, right, top, bottom)`` from the last two axes.

        Explicit end indices are used so a zero right/bottom pad does not
        become an empty ``-0`` slice.
        """
        left, right, top, bottom = pad
        height, width = masks.shape[-2:]
        return masks[..., top:height - bottom, left:width - right]

    def interact(self, scribble):
        """Apply one DAVIS interaction.

        Returns:
            (out_masks, next_interact, idx):
            - out_masks: the processor's output with the padding removed
              (presumably one mask per frame, (T, H, W) -- TODO confirm);
            - next_interact: ``[idx]`` when the result was only applied to
              the interacted frame, or ``None`` after a full propagation;
            - idx: the interacted frame.
        """
        mask, idx = self.to_mask(scribble)

        if self._consume_schedule():
            # Finish the instant interaction loop for this frame
            next_interact = None
            out_masks = self.processor.interact(mask, idx)
        else:
            next_interact = [idx]
            out_masks = self.processor.update_mask_only(mask, idx)

        self.interacted_count += 1

        return self._unpad(out_masks, self.pad), next_interact, idx
Beispiel #5
0
# Start eval
for data in progressbar(test_loader,
                        max_value=len(test_loader),
                        redirect_stdout=True):

    rgb = data['rgb'].cuda()
    msk = data['gt'][0].cuda()
    info = data['info']
    name = info['name'][0]
    k = len(info['labels'][0])
    size = info['size_480p']

    torch.cuda.synchronize()
    process_begin = time.time()

    processor = InferenceCore(prop_model, rgb, k)
    processor.interact(msk[:, 0], 0, rgb.shape[1])

    # Do unpad -> upsample to original size
    out_masks = torch.zeros((processor.t, 1, *size),
                            dtype=torch.uint8,
                            device='cuda')
    for ti in range(processor.t):
        prob = processor.prob[:, ti]

        if processor.pad[2] + processor.pad[3] > 0:
            prob = prob[:, :, processor.pad[2]:-processor.pad[3], :]
        if processor.pad[0] + processor.pad[1] > 0:
            prob = prob[:, :, :, processor.pad[0]:-processor.pad[1]]

        prob = F.interpolate(prob, size, mode='bilinear', align_corners=False)