Exemplo n.º 1
0
    def interact(self, mask, frame_idx, end_idx, obj_idx):
        """
        Write a newly provided mask into frame_idx, memorize that frame and
        propagate from it up to end_idx.

        mask - Input mask; it is moved to CUDA and padded with pad_divide_by(..., 16).
               NOTE(review): the previous docstring described it as one-hot
               WITHOUT the background class, yet the code below skips channel 0
               (mask[1:]) and indexes channels by obj_idx -- confirm the layout.
        frame_idx, end_idx - Start and end idx of propagation
        obj_idx - list of object IDs that first appear on this frame
        """

        # In youtube mode, we interact with a subset of object id at a time
        padded_mask, _ = pad_divide_by(mask.cuda(), 16)

        # Remember which objects have been labelled so far
        self.enabled_obj.extend(obj_idx)

        # Wherever channels 1.. of the new mask are active, first clear the
        # stored probability of every channel on this frame ...
        covered = (padded_mask[1:].sum(0) > 0.5)
        self.prob[:, frame_idx, covered] = 0
        # ... then copy in the channels of the newly labelled objects
        self.prob[obj_idx, frame_idx] = padded_mask[obj_idx]

        # Rebuild all channels of the frame from channels 1..
        frame_fg = self.prob[1:, frame_idx]
        self.prob[:, frame_idx] = aggregate_wbg(frame_fg, keep_bg=True)

        # KV pair for the interacting frame, restricted to the enabled objects
        frame_image = self.images[:, frame_idx].cuda()
        enabled_prob = self.prob[self.enabled_obj, frame_idx].cuda()
        key_k, key_v = self.prop_net.memorize(frame_image, enabled_prob)

        # Propagate
        self.do_pass(key_k, key_v, frame_idx, end_idx)
Exemplo n.º 2
0
    def do_pass(self, key_k, key_v, idx, end_idx):
        """
        Propagate masks frame by frame from idx + 1 up to, but not including,
        end_idx, writing each result into self.prob[:, ti].

        key_k, key_v - Memory feature of the starting frame
        idx - Frame index of the starting frame
        end_idx - Frame index at which we stop the propagation

        Returns end_idx unchanged.
        """
        # The sizes are not used afterwards, but the unpacking raises unless
        # both tensors are 5-D, so it is kept as a rank check.
        K, CK, _, H, W = key_k.shape
        _, CV, _, _, _ = key_v.shape

        stop_ti = end_idx
        # Note that we never reach stop_ti, just the frame before it
        final_ti = stop_ti - 1

        # Permanent memory bank, grown along dim 2
        mem_k, mem_v = key_k, key_v
        # Features of the most recently segmented frame
        pending_k = pending_v = None
        # True while the pending features are already part of the permanent bank
        pending_committed = True
        last_commit_ti = idx

        for ti in range(idx + 1, stop_ti):
            if pending_committed:
                read_k, read_v = mem_k, mem_v
            else:
                # Read from the permanent bank plus the previous frame
                read_k = torch.cat([mem_k, pending_k], 2)
                read_v = torch.cat([mem_v, pending_v], 2)

            query = self.get_query_kv_buffered(ti)
            out_mask = self.prop_net.segment_with_query(read_k, read_v, *query)
            out_mask = aggregate_wbg(out_mask, keep_bg=True)
            self.prob[:, ti] = out_mask

            if ti == final_ti:
                # Last frame of the pass: its features would never be read
                continue

            # Memorize this frame (channels 1.. of the aggregated output)
            pending_k, pending_v = self.prop_net.memorize(
                self.images[:, ti].cuda(), out_mask[1:])
            pending_committed = abs(ti - last_commit_ti) >= self.mem_freq
            if pending_committed:
                # Make the temporary memory permanent
                mem_k = torch.cat([mem_k, pending_k], 2)
                mem_v = torch.cat([mem_v, pending_v], 2)
                last_commit_ti = ti

        return stop_ti
Exemplo n.º 3
0
    def interact(self, mask, frame_idx, end_idx):
        """
        Store the given mask on frame_idx, memorize that frame and propagate
        from it up to end_idx.

        mask - Input one-hot encoded mask WITHOUT the background class
        frame_idx, end_idx - Start and end idx of propagation
        """
        padded_mask, _ = pad_divide_by(mask.cuda(), 16)

        # aggregate_wbg(keep_bg=True) produces the full set of channels stored for the frame
        self.prob[:, frame_idx] = aggregate_wbg(padded_mask, keep_bg=True)

        # KV pair for the interacting frame, built from channels 1.. of the stored result
        frame_image = self.images[:, frame_idx].cuda()
        frame_fg = self.prob[1:, frame_idx].cuda()
        key_k, key_v = self.prop_net.memorize(frame_image, frame_fg)

        # Propagate
        self.do_pass(key_k, key_v, frame_idx, end_idx)
Exemplo n.º 4
0
    def fuse_one_frame(self, tc, tr, ti, mk16, qk16):
        """
        Run the fusion network once per object for frame ti and aggregate the
        sigmoid outputs.

        tc, tr - Two frame indices that must strictly bracket ti (in either order)
        ti - Index of the frame being processed
        mk16 - Sliced per object (mk16[k-1:k]) and handed to prop_net.get_attention
        qk16 - Passed unchanged to prop_net.get_attention

        Returns aggregate_wbg(..., keep_bg=True) of the per-object outputs.
        """
        # This check is also what rules out tc == tr, i.e. a zero divisor below
        assert(tc<ti<tr or tr<ti<tc)

        fused = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device)

        # Compute linear coefficients: distance of ti to each end, as a fraction of the span
        span = abs(tc-tr)
        nc = abs(tc-ti) / span
        nr = abs(tr-ti) / span
        dist = torch.FloatTensor([nc, nr]).to(self.device).unsqueeze(0)

        for obj in range(1, self.k+1):
            # Row obj-1 of `fused` / mk16 pairs with channel obj of the stored tensors
            row = obj - 1
            attn_map = self.prop_net.get_attention(
                mk16[row:obj], self.pos_mask_diff[obj:obj+1], self.neg_mask_diff[obj:obj+1], qk16)

            image = self.get_image_buffered(ti)
            first_prob = self.prob1[obj:obj+1,ti].to(self.device)
            second_prob = self.prob2[obj:obj+1,ti].to(self.device)
            logits = self.fuse_net(image, first_prob, second_prob, attn_map, dist)
            fused[row] = torch.sigmoid(logits)

        return aggregate_wbg(fused, keep_bg=True)
Exemplo n.º 5
0
    def do_pass(self, key_k, key_v, idx, left_limit, right_limit, forward=True):
        """
        Propagate masks away from frame idx in one direction, writing each
        result into self.prob[:, ti].

        key_k, key_v - Memory features of the starting frame
        idx - Frame index of the starting frame
        left_limit, right_limit - Inclusive bounds of the propagation; the
            forward pass ends on right_limit, the backward pass on left_limit
        forward - True walks idx+1 .. right_limit, False walks idx-1 .. left_limit

        Nothing is returned.
        """
        # Permanent memory bank, grown along dim 2
        mem_k, mem_v = key_k, key_v
        # Features of the previous frame while they are not (yet) in the bank
        temp_k = temp_v = None
        last_commit_ti = idx

        # Alias captured once, as in the original implementation
        prob = self.prob

        if forward:
            frames = range(idx + 1, right_limit + 1)
            final_ti = right_limit
        else:
            frames = range(idx - 1, left_limit - 1, -1)
            final_ti = left_limit

        for ti in frames:
            if temp_k is None:
                read_k, read_v = mem_k, mem_v
            else:
                read_k = torch.cat([mem_k, temp_k], 2)
                read_v = torch.cat([mem_v, temp_v], 2)

            query = self.get_query_buf(ti)
            out_mask = self.prop_net.segment_with_query(read_k, read_v, *query)
            out_mask = aggregate_wbg(out_mask, keep_bg=True)
            prob[:, ti] = out_mask

            if ti == final_ti:
                # Nothing is memorized for the last frame of the pass
                continue

            temp_k, temp_v = self.prop_net.memorize(self.get_im(ti), out_mask[1:])
            if abs(ti - last_commit_ti) >= self.mem_freq:
                # Far enough from the last stored frame: make it permanent
                last_commit_ti = ti
                mem_k = torch.cat([mem_k, temp_k], 2)
                mem_v = torch.cat([mem_v, temp_v], 2)
                temp_k = temp_v = None
Exemplo n.º 6
0
    def interact_mask(self, mask, idx, left_limit, right_limit):
        """
        Store a mask on frame idx, propagate it with two do_pass calls and
        return the stored probabilities with the padding cropped away.

        mask - Mask for frame idx; padded with pad_divide_by(..., 16) and passed
               through aggregate_wbg(keep_bg=True) before being stored
        idx - Index of the annotated frame
        left_limit, right_limit - Bounds forwarded unchanged to do_pass

        Returns self.prob indexed at 0 along its third axis, cropped on the
        last two axes according to self.pad.
        """

        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])
        mask = aggregate_wbg(mask, keep_bg=True)

        self.prob[:, idx] = mask
        key_k, key_v = self.prop_net.memorize(self.get_im(idx), mask[1:])

        # Same starting memory for both calls; only the final flag differs
        # (presumably forward first, then backward -- see do_pass)
        self.do_pass(key_k, key_v, idx, left_limit, right_limit, True)
        self.do_pass(key_k, key_v, idx, left_limit, right_limit, False)

        # Prepare output
        out_prob = self.prob[:, :, 0, :, :]

        # Crop with explicit end indices. A negative stop (`a:-b`) turns into
        # `a:0` -- an empty slice -- whenever the trailing pad b is 0 while the
        # leading pad a is not.
        # NOTE(review): self.pad is read as (pad[0], pad[1]) for the last axis
        # and (pad[2], pad[3]) for the second-to-last -- confirm against pad_divide_by.
        if self.pad[2] + self.pad[3] > 0:
            height = out_prob.shape[2]
            out_prob = out_prob[:, :, self.pad[2]:height - self.pad[3], :]
        if self.pad[0] + self.pad[1] > 0:
            width = out_prob.shape[3]
            out_prob = out_prob[:, :, :, self.pad[0]:width - self.pad[1]]

        return out_prob
Exemplo n.º 7
0
    def to_mask(self, scribble):
        """
        Turn the scribbles of the one annotated frame into a mask with the
        S2M network.

        scribble - dict with a 'scribbles' list that is enumerated by frame
            index. NOTE: the dict is modified in place -- 'scribbles' is
            replaced by a one-element list holding the first non-empty entry.

        Returns (mask, idx): the aggregate_wbg(..., keep_bg=True, hard=True)
        result and the index of the frame whose scribble was used.
        """
        # First we select the only frame with scribble
        all_scr = scribble['scribbles']
        # NOTE(review): if no entry is non-empty the loop falls through with
        # idx at the last index and `scribble` untouched; if the list is empty
        # idx is never bound. Callers are presumably expected to always pass
        # exactly one annotated frame -- confirm.
        for idx, s in enumerate(all_scr):
            if len(s) != 0:
                scribble['scribbles'] = [s]
                break

        # Pass to DAVIS to change the path to an array
        scr_mask = scribbles2mask(scribble, (self.h, self.w))[0]

        # Run our S2M
        kernel = np.ones((3, 3), np.uint8)
        mask = torch.zeros((self.k, 1, self.nh, self.nw),
                           dtype=torch.float32,
                           device=self.device)
        for ki in range(1, self.k + 1):
            # Positive scribble: pixels labelled with this object, dilated by 3x3.
            # Builtin `bool` is used because the `np.bool` alias was deprecated
            # in NumPy 1.20 and is missing from NumPy 1.24-1.26.
            p_srb = (scr_mask == ki).astype(np.uint8)
            p_srb = cv2.dilate(p_srb, kernel).astype(bool)

            # Negative scribble: pixels labelled with anything else, where -1
            # is treated as "not scribbled" and excluded
            n_srb = ((scr_mask != ki) * (scr_mask != -1)).astype(np.uint8)
            n_srb = cv2.dilate(n_srb, kernel).astype(bool)

            Rs = torch.from_numpy(np.stack(
                [p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device)
            Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

            # Use hard mask because we train S2M with such
            inputs = torch.cat([
                self.processor.get_image_buffered(idx),
                (self.processor.masks[idx] == ki).to(
                    self.device).float().unsqueeze(0), Rs
            ], 1)
            mask[ki - 1] = torch.sigmoid(self.s2m_net(inputs))
        mask = aggregate_wbg(mask, keep_bg=True, hard=True)
        return mask, idx
Exemplo n.º 8
0
    def do_pass(self, key_k, key_v, idx, forward=True, step_cb=None):
        """
        Do a complete pass that includes propagation and fusion

        key_k/key_v - memory feature of the starting frame. In this method they
            only supply the tensor sizes for the memory bank (key_k is also
            forwarded to fuse_one_frame); the bank itself is seeded from
            self.certain_mem_k / self.certain_mem_v.
        idx - Frame index of the starting frame
        forward - forward/backward propagation
        step_cb - Callback function used for GUI (progress bar) only

        Returns the index at which the pass stopped: the nearest interacted
        frame in the travel direction, or self.t (forward) / -1 (backward)
        when there is none.
        """
        num_certain_keys = self.certain_mem_k.shape[2]

        # Find where to stop, how many frames lie in between, and the frame
        # range to visit. Note that we never reach closest_ti, just the frame
        # before it.
        if forward:
            closest_ti = min([ti for ti in self.interacted if ti > idx] +
                             [self.t])
            gap = closest_ti - idx - 1
            this_range = range(idx + 1, closest_ti)
            end = closest_ti - 1
        else:
            closest_ti = max([ti for ti in self.interacted if ti < idx] + [-1])
            gap = idx - closest_ti - 1
            this_range = range(idx - 1, closest_ti, -1)
            end = closest_ti + 1

        # Required size of the memory bank along dim 2
        total_m = gap // self.mem_freq + 1 + num_certain_keys

        K, CK, _, H, W = key_k.shape
        _, CV, _, _, _ = key_v.shape

        # Pre-allocate keys/values memory and copy the certain memory to the front
        keys = torch.empty((K, CK, total_m, H, W),
                           dtype=torch.float32,
                           device=self.device)
        values = torch.empty((K, CV, total_m, H, W),
                             dtype=torch.float32,
                             device=self.device)
        keys[:, :, 0:num_certain_keys] = self.certain_mem_k
        values[:, :, 0:num_certain_keys] = self.certain_mem_v

        # Pointer in the memory bank: slots [0, m_front) are permanent, slot
        # m_front holds the previous frame while it is only temporary
        m_front = num_certain_keys
        prev_in_mem = True
        last_ti = idx

        for ti in this_range:
            # Include the temporary slot only when the previous frame was not
            # made permanent
            visible = m_front if prev_in_mem else m_front + 1
            this_k = keys[:, :, :visible]
            this_v = values[:, :, :visible]

            query = self.get_query_kv_buffered(ti)
            out_mask = self.prop_net.segment_with_query(this_k, this_v, *query)
            out_mask = aggregate_wbg(out_mask, keep_bg=True)

            if ti != end:
                new_k, new_v = self.prop_net.memorize(
                    self.get_image_buffered(ti), out_mask[1:])
                keys[:, :, m_front:m_front + 1] = new_k
                values[:, :, m_front:m_front + 1] = new_v

                prev_in_mem = abs(ti - last_ti) >= self.mem_freq
                if prev_in_mem:
                    # Memorize the frame: advancing the pointer keeps the slot
                    m_front += 1
                    last_ti = ti

            # In-place fusion, maximizes the use of queried buffer
            # esp. for long sequence where the buffer will be flushed
            if (closest_ti != self.t) and (closest_ti != -1):
                fused = self.fuse_one_frame(
                    closest_ti, idx, ti, self.prob[:, ti], out_mask, key_k,
                    query[3])
                self.prob[:, ti] = fused.to(self.result_dev, non_blocking=True)
            else:
                # The sentinel was hit, so the stored result is the propagated mask
                self.prob[:, ti] = out_mask.to(self.result_dev,
                                               non_blocking=True)

            # Callback function for the GUI
            if step_cb is not None:
                step_cb()

        return closest_ti
Exemplo n.º 9
0
    def do_pass(self, key_k, key_v, idx, end_idx):
        """
        Propagate the enabled objects from idx + 1 up to, but not including,
        end_idx, keeping one memory bank per object in self.keys / self.values.

        key_k, key_v - Memory feature of the starting frame; row i belongs to
            self.enabled_obj[i]
        idx - Frame index of the starting frame
        end_idx - Frame index at which we stop the propagation

        Returns end_idx unchanged.
        """
        closest_ti = end_idx

        # The sizes are not used afterwards, but the unpacking raises unless
        # both tensors are 5-D, so it is kept as a rank check.
        K, CK, _, H, W = key_k.shape
        _, CV, _, _, _ = key_v.shape

        # Add the starting frame to each enabled object's bank, creating the
        # bank on first sight and appending along dim 2 otherwise
        for i, oi in enumerate(self.enabled_obj):
            obj_k = key_k[i:i+1]
            obj_v = key_v[i:i+1]
            if oi in self.keys:
                self.keys[oi] = torch.cat([self.keys[oi], obj_k], 2)
                self.values[oi] = torch.cat([self.values[oi], obj_v], 2)
            else:
                self.keys[oi] = obj_k
                self.values[oi] = obj_v

        prev_in_mem = True
        # Features of the previous frame; assigned by memorize() before first use
        prev_key = prev_value = None
        last_ti = idx

        # Note that we never reach closest_ti, just the frame before it
        end = closest_ti - 1

        for ti in range(idx+1, closest_ti):
            if prev_in_mem:
                # the previous frame is already in the permanent banks
                this_k, this_v = self.keys, self.values
            else:
                # otherwise read from temporary banks that also hold it;
                # everything has to be done independently for each object
                this_k, this_v = {}, {}
                for i, oi in enumerate(self.enabled_obj):
                    this_k[oi] = torch.cat([self.keys[oi], prev_key[i:i+1]], 2)
                    this_v[oi] = torch.cat([self.values[oi], prev_value[i:i+1]], 2)

            query = self.get_query_kv_buffered(ti)

            # One segmentation call per enabled object, stacked along dim 0
            per_object = [
                self.prop_net.segment_with_query(this_k[oi], this_v[oi], *query)
                for oi in self.enabled_obj
            ]
            out_mask = aggregate_wbg(torch.cat(per_object, 0), keep_bg=True)

            self.prob[0,ti] = out_mask[0]
            # output mapping to the full object id space: channel 0 was stored
            # above, channel c (c >= 1) belongs to enabled_obj[c-1]
            for channel, oi in enumerate(self.enabled_obj, start=1):
                self.prob[oi,ti] = out_mask[channel]

            if ti == end:
                # nothing is memorized for the last frame of the pass
                continue

            # memorize this frame
            prev_key, prev_value = self.prop_net.memorize(self.images[:,ti].cuda(), out_mask[1:])
            prev_in_mem = abs(ti-last_ti) >= self.mem_freq
            if prev_in_mem:
                for i, oi in enumerate(self.enabled_obj):
                    self.keys[oi] = torch.cat([self.keys[oi], prev_key[i:i+1]], 2)
                    self.values[oi] = torch.cat([self.values[oi], prev_value[i:i+1]], 2)
                last_ti = ti

        return closest_ti
Exemplo n.º 10
0
 def predict(self):
     """
     Run one interaction through the controller and cache both results.

     Stores the controller output in self.out_prob and its
     aggregate_wbg(..., keep_bg=True, hard=True) form in self.out_mask.
     Returns the latter.
     """
     prob = self.controller.interact(self.image, self.prev_mask, self.drawn_map)
     self.out_prob = prob

     hard_mask = aggregate_wbg(prob, keep_bg=True, hard=True)
     self.out_mask = hard_mask
     return hard_mask