Example #1
File: interaction.py Project: nnmhuy/MiVOS
    def __init__(self, image, prev_mask, pad, bounding_box):
        lx, ux, ly, uy = bounding_box
        true_size = (uy-ly+1, ux-lx+1)
        super().__init__(image, prev_mask, true_size, None)

        self.bounding_box = bounding_box # UN-PADDED
        unpad_prev_mask = unpad(self.prev_mask, pad)
        self.out_prob = unpad_prev_mask[:, :, ly:uy+1, lx:ux+1]
        self.out_prob, self.pad = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
        self.out_mask = aggregate_sbg(self.out_prob, keep_bg=True)

        unpad_image = unpad(self.image, pad)
        self.im_crop = unpad_image[:, :, ly:uy+1, lx:ux+1]
        self.im_crop, _ = pad_divide_by(self.im_crop, 16, self.im_crop.shape[-2:])
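
Note: none of these snippets include the definitions of pad_divide_by and unpad themselves. The sketch below is a minimal reconstruction, consistent with how the examples on this page use them (symmetric padding of the last two dimensions up to a multiple of d, returning the padded tensor plus a (left, right, top, bottom) pad tuple); the actual implementations in the MiVOS repository may differ in detail.

import torch.nn.functional as F

def pad_divide_by(in_img, d, in_size=None):
    # Pad the last two dims up to the next multiple of d,
    # splitting the padding as evenly as possible between the two sides.
    if in_size is None:
        h, w = in_img.shape[-2:]
    else:
        h, w = in_size
    new_h = h + (d - h % d) % d
    new_w = w + (d - w % d) % d
    lh, uh = (new_h - h) // 2, (new_h - h) - (new_h - h) // 2
    lw, uw = (new_w - w) // 2, (new_w - w) - (new_w - w) // 2
    pad_array = (lw, uw, lh, uh)  # (left, right, top, bottom), as F.pad expects
    return F.pad(in_img, pad_array), pad_array

def unpad(img, pad):
    # Inverse of pad_divide_by: trim the recorded padding off again
    if pad[2] + pad[3] > 0:
        img = img[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        img = img[:, :, :, pad[0]:-pad[1]]
    return img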
Example #2
    def interact(self, mask, frame_idx, end_idx, obj_idx):
        """
        mask - Input one-hot encoded mask WITH the background class at channel 0
        frame_idx, end_idx - Start and end idx of propagation
        obj_idx - list of object IDs that first appear on this frame
        """

        # In YouTube mode, we interact with a subset of object IDs at a time
        mask, _ = pad_divide_by(mask.cuda(), 16)

        # update objects that have been labeled
        self.enabled_obj.extend(obj_idx)

        # Zero out all object probabilities inside the newly labeled mask regions
        mask_regions = (mask[1:].sum(0) > 0.5)
        self.prob[:, frame_idx, mask_regions] = 0
        self.prob[obj_idx, frame_idx] = mask[obj_idx]

        self.prob[:, frame_idx] = aggregate_wbg(self.prob[1:, frame_idx], keep_bg=True)

        # KV pair for the interacting frame
        key_k, key_v = self.prop_net.memorize(self.images[:,frame_idx].cuda(), self.prob[self.enabled_obj,frame_idx].cuda())

        # Propagate
        self.do_pass(key_k, key_v, frame_idx, end_idx)
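
aggregate_wbg, used here and in several later examples, is also not defined in these snippets. A plausible sketch, following the soft-aggregation scheme of STM-style models (background estimated as the product of per-object complements, then a softmax over the logits), is given below; treat the exact signature and the hard-mask behavior as assumptions.

import torch
import torch.nn.functional as F

def aggregate_wbg(prob, keep_bg=False, hard=False):
    # prob: (k, 1, H, W) per-object probabilities, background NOT included
    new_prob = torch.cat([
        torch.prod(1 - prob, dim=0, keepdim=True),  # P(background) = all objects absent
        prob,
    ], 0).clamp(1e-7, 1 - 1e-7)
    logits = torch.log(new_prob / (1 - new_prob))  # inverse sigmoid
    if hard:
        logits *= 1000  # sharpen the softmax towards a near one-hot output
    out = F.softmax(logits, dim=0)
    return out if keep_bg else out[1:]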
Example #3
    def __init__(self,
                 prop_net,
                 fuse_net,
                 s2m_net,
                 images,
                 num_objects,
                 device='cuda:0'):
        self.s2m_net = s2m_net.to(device, non_blocking=True)

        # True dimensions, captured before padding (cf. Examples #8 and #12)
        t = images.shape[1]
        h, w = images.shape[-2:]

        images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        self.device = device

        # Padded dimensions
        nh, nw = images.shape[-2:]
        self.nh, self.nw = nh, nw

        self.k = num_objects
        self.t, self.h, self.w = t, h, w

        self.interacted_count = 0
        self.davis_schedule = [2, 5, 7]

        self.processor = InferenceCore(prop_net,
                                       fuse_net,
                                       images,
                                       num_objects,
                                       mem_profile=0,
                                       device=device)
Example #4
    def interact(self, mask, idx, total_cb=None, step_cb=None):
        """
        Interact -> Propagate -> Fuse

        mask - One-hot mask of the interacted frame, background included
        idx - Frame index of the interacted frame
        total_cb, step_cb - Callback functions for the GUI

        Return: all mask results in np format for DAVIS evaluation
        """
        self.interacted.add(idx)

        mask = mask.to(self.device)
        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])
        self.mask_diff = mask - self.prob[:, idx].to(self.device)
        self.pos_mask_diff = self.mask_diff.clamp(0, 1)
        self.neg_mask_diff = (-self.mask_diff).clamp(0, 1)

        self.prob[:, idx] = mask
        key_k, key_v = self.prop_net.memorize(self.get_image_buffered(idx),
                                              mask[1:])

        if self.certain_mem_k is None:
            self.certain_mem_k = key_k
            self.certain_mem_v = key_v
        else:
            self.certain_mem_k = torch.cat([self.certain_mem_k, key_k], 2)
            self.certain_mem_v = torch.cat([self.certain_mem_v, key_v], 2)

        if total_cb is not None:
            # Find the total number of frames to process
            front_limit = min([ti for ti in self.interacted if ti > idx] +
                              [self.t])
            back_limit = max([ti for ti in self.interacted if ti < idx] + [-1])
            total_num = front_limit - back_limit - 2  # -1 for shift, -1 for center frame
            total_cb(total_num)

        self.do_pass(key_k, key_v, idx, True, step_cb=step_cb)
        self.do_pass(key_k, key_v, idx, False, step_cb=step_cb)

        # This is a more memory-efficient argmax
        for ti in range(self.t):
            self.masks[ti] = torch.argmax(self.prob[:, ti], dim=0)
        out_masks = self.masks

        # Trim paddings
        if self.pad[2] + self.pad[3] > 0:
            out_masks = out_masks[:, :, self.pad[2]:-self.pad[3], :]
        if self.pad[0] + self.pad[1] > 0:
            out_masks = out_masks[:, :, :, self.pad[0]:-self.pad[1]]

        self.np_masks = (out_masks.detach().cpu().numpy()[:, 0]).astype(np.uint8)

        return self.np_masks
Example #5
    def __init__(self, model, image, mask):
        self.model = model

        self.image = im_normalization(TF.to_tensor(image)).unsqueeze(0).cuda()
        self.mask = TF.to_tensor(mask).unsqueeze(0).cuda()

        h, w = self.image.shape[-2:]
        self.image, self.pad = pad_divide_by(self.image, 16)
        self.mask, _ = pad_divide_by(self.mask, 16)
        self.last_mask = None

        # Positive and negative scribbles
        self.p_srb = np.zeros((h, w), dtype=np.uint8)
        self.n_srb = np.zeros((h, w), dtype=np.uint8)

        # Used for drawing
        self.pressed = False
        self.last_ex = self.last_ey = None
        self.positive_mode = True
        self.need_update = True
Example #6
    def interact(self, mask, frame_idx, end_idx):
        """
        mask - Input one-hot encoded mask WITHOUT the background class
        frame_idx, end_idx - Start and end idx of propagation
        """
        mask, _ = pad_divide_by(mask.cuda(), 16)

        self.prob[:, frame_idx] = aggregate_wbg(mask, keep_bg=True)

        # KV pair for the interacting frame
        key_k, key_v = self.prop_net.memorize(self.images[:, frame_idx].cuda(),
                                              self.prob[1:, frame_idx].cuda())

        # Propagate
        self.do_pass(key_k, key_v, frame_idx, end_idx)
Example #7
    def interact(self, image, prev_mask, scr_mask):
        image = image.to(self.device, non_blocking=True)        

        h, w = image.shape[-2:]
        unaggre_mask = torch.zeros((self.num_objects, 1, h, w), dtype=torch.float32, device=image.device)

        for ki in range(1, self.num_objects+1):
            p_srb = (scr_mask==ki).astype(np.uint8)
            n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8)

            Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device)
            Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

            inputs = torch.cat([image, (prev_mask==ki).float().unsqueeze(0), Rs], 1)
            unaggre_mask[ki-1] = torch.sigmoid(self.s2m_net(inputs))

        return unaggre_mask
Example #8
    def __init__(self, prop_net: PropagationNetwork, images, mem_freq):
        self.prop_net = prop_net
        self.mem_freq = mem_freq

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]

        # Pad each side to multiple of 16
        images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        # Padded dimensions
        nh, nw = images.shape[-2:]

        self.images = images
        self.device = self.images.device

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
Example #9
File: interaction.py Project: nnmhuy/MiVOS
    def __init__(self, image, prev_mask, true_size, bounding_box, region_prob, pad, local_pad):
        super().__init__(image, prev_mask, true_size, None)
        lx, ux, ly, uy = bounding_box
        self.out_prob = unpad(self.prev_mask, pad)
        region_prob = unpad(region_prob, local_pad)

        # Trim the margin since results at the boundary are not that trustworthy
        if (ux-lx) > 6 and (uy-ly) > 6:
            lx += 3
            ux -= 3
            ly += 3
            uy -= 3
            self.out_prob[:,:,ly:uy+1, lx:ux+1] = region_prob[:,:,3:-3,3:-3]
        else:
            self.out_prob[:,:,ly:uy+1, lx:ux+1] = region_prob
        self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
        self.out_mask = aggregate_sbg(self.out_prob, keep_bg=True)
        self.storage = None # Might be used outside
Example #10
    def interact_mask(self, mask, idx, left_limit, right_limit):

        mask, _ = pad_divide_by(mask, 16, mask.shape[-2:])
        mask = aggregate_wbg(mask, keep_bg=True)

        self.prob[:, idx] = mask
        key_k, key_v = self.prop_net.memorize(self.get_im(idx), mask[1:])

        self.do_pass(key_k, key_v, idx, left_limit, right_limit, True)
        self.do_pass(key_k, key_v, idx, left_limit, right_limit, False)

        # Prepare output
        out_prob = self.prob[:, :, 0, :, :]

        if self.pad[2] + self.pad[3] > 0:
            out_prob = out_prob[:, :, self.pad[2]:-self.pad[3], :]
        if self.pad[0] + self.pad[1] > 0:
            out_prob = out_prob[:, :, :, self.pad[0]:-self.pad[1]]

        return out_prob
Example #11
    def run_s2m(self):
        # Convert scribbles to tensors
        Rsp = torch.from_numpy(self.p_srb).unsqueeze(0).unsqueeze(0).float().cuda()
        Rsn = torch.from_numpy(self.n_srb).unsqueeze(0).unsqueeze(0).float().cuda()
        Rs = torch.cat([Rsp, Rsn], 1)
        Rs, _ = pad_divide_by(Rs, 16)

        # Run the S2M network on the image, the current mask, and the scribble maps
        inputs = torch.cat([self.image, self.mask, Rs], 1)
        _, mask = aggregate(torch.sigmoid(self.model(inputs)))

        # We don't overwrite current mask until commit
        self.last_mask = mask
        np_mask = (mask.detach().cpu().numpy()[0,0] * 255).astype(np.uint8)

        if self.pad[2]+self.pad[3] > 0:
            np_mask = np_mask[self.pad[2]:-self.pad[3],:]
        if self.pad[0]+self.pad[1] > 0:
            np_mask = np_mask[:,self.pad[0]:-self.pad[1]]

        return np_mask
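
The trimming at the end of this example (and of Examples #4 and #10) depends on the pad tuple being ordered (left, right, top, bottom). A quick round-trip check with made-up shapes, reusing the pad_divide_by sketch from Example #1, illustrates the convention:

import torch

x = torch.randn(1, 3, 480, 854)   # e.g. a DAVIS-480p-sized frame
xp, pad = pad_divide_by(x, 16)    # xp: (1, 3, 480, 864); 480 is already divisible by 16
left, right, top, bottom = pad    # here (5, 5, 0, 0)

# Trim height then width, guarding against zero padding (x[a:-0] would be empty)
trimmed = xp
if top + bottom > 0:
    trimmed = trimmed[:, :, top:-bottom, :]
if left + right > 0:
    trimmed = trimmed[:, :, :, left:-right]
assert trimmed.shape == x.shape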
Example #12
    def __init__(self, prop_net:PropagationNetwork, images, num_objects, mem_freq=5):
        self.prop_net = prop_net
        self.mem_freq = mem_freq

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]

        # Pad each side to multiple of 16
        images, self.pad = pad_divide_by(images, 16)
        # Padded dimensions
        nh, nw = images.shape[-2:]

        self.images = images
        self.device = 'cuda'

        self.k = num_objects
        self.masks = torch.zeros((t, 1, nh, nw), dtype=torch.uint8, device=self.device)
        self.out_masks = np.zeros((t, h, w), dtype=np.uint8)

        # Background included; not always normalized (i.e. may not sum to 1)
        self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.device)
        self.prob[0] = 1e-7

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh//16
        self.kw = self.nw//16

        # The keys/values are always preserved in YouTube testing.
        # We still consider it a single propagation pass; some objects
        # simply arrive later than usual.
        self.keys = dict()
        self.values = dict()

        # list of objects with usable memory
        self.enabled_obj = []
Example #13
    def __init__(self,
                 prop_net: PropagationNetwork,
                 images,
                 num_objects,
                 mem_freq=5):
        self.prop_net = prop_net
        self.mem_freq = mem_freq

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]

        # Pad each side to multiple of 16
        images, self.pad = pad_divide_by(images, 16)
        # Padded dimensions
        nh, nw = images.shape[-2:]

        self.images = images
        self.device = 'cuda'

        self.k = num_objects
        self.masks = torch.zeros((t, 1, nh, nw),
                                 dtype=torch.uint8,
                                 device=self.device)
        self.out_masks = np.zeros((t, h, w), dtype=np.uint8)

        # Background included; not always normalized (i.e. may not sum to 1)
        self.prob = torch.zeros((self.k + 1, t, 1, nh, nw),
                                dtype=torch.float32,
                                device=self.device)
        self.prob[0] = 1e-7

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh // 16
        self.kw = self.nw // 16
Example #14
    def to_mask(self, scribble):
        # First we select the only frame with scribble
        all_scr = scribble['scribbles']
        for idx, s in enumerate(all_scr):
            if len(s) != 0:
                scribble['scribbles'] = [s]
                break

        # Pass to DAVIS to change the path to an array
        scr_mask = scribbles2mask(scribble, (self.h, self.w))[0]

        # Run our S2M
        kernel = np.ones((3, 3), np.uint8)
        mask = torch.zeros((self.k, 1, self.nh, self.nw),
                           dtype=torch.float32,
                           device=self.device)
        for ki in range(1, self.k + 1):
            p_srb = (scr_mask == ki).astype(np.uint8)
            p_srb = cv2.dilate(p_srb, kernel).astype(bool)

            n_srb = ((scr_mask != ki) * (scr_mask != -1)).astype(np.uint8)
            n_srb = cv2.dilate(n_srb, kernel).astype(bool)

            Rs = torch.from_numpy(np.stack(
                [p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device)
            Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

            # Use hard masks because S2M was trained with hard masks
            inputs = torch.cat([
                self.processor.get_image_buffered(idx),
                (self.processor.masks[idx] == ki).to(
                    self.device).float().unsqueeze(0), Rs
            ], 1)
            mask[ki - 1] = torch.sigmoid(self.s2m_net(inputs))
        mask = aggregate_wbg(mask, keep_bg=True, hard=True)
        return mask, idx
Example #15
File: interaction.py Project: nnmhuy/MiVOS
    def predict(self):
        self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
        self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
        self.out_mask = aggregate_sbg(self.out_prob, keep_bg=True)
        return self.out_mask
Example #16
    def __init__(self,
                 prop_net: PropagationNetwork,
                 fuse_net: FusionNet,
                 images,
                 num_objects,
                 mem_profile=0,
                 mem_freq=5,
                 device='cuda:0'):
        self.prop_net = prop_net.to(device, non_blocking=True)
        if fuse_net is not None:
            self.fuse_net = fuse_net.to(device, non_blocking=True)
        self.mem_profile = mem_profile
        self.mem_freq = mem_freq
        self.device = device

        if mem_profile == 0:
            self.data_dev = device
            self.result_dev = device
            self.q_buf_size = 105
            self.i_buf_size = -1  # no need to buffer image
        elif mem_profile == 1:
            self.data_dev = 'cpu'
            self.result_dev = device
            self.q_buf_size = 105
            self.i_buf_size = 105
        elif mem_profile == 2:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.q_buf_size = 3
            self.i_buf_size = 3
        else:
            self.data_dev = 'cpu'
            self.result_dev = 'cpu'
            self.q_buf_size = 1
            self.i_buf_size = 1

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]
        self.k = num_objects

        # Pad each side to multiples of 16
        self.images, self.pad = pad_divide_by(images, 16, images.shape[-2:])
        # Padded dimensions
        nh, nw = self.images.shape[-2:]
        self.images = self.images.to(self.data_dev, non_blocking=False)

        # These two store the same information in different formats
        self.masks = torch.zeros((t, 1, nh, nw),
                                 dtype=torch.uint8,
                                 device=self.result_dev)
        self.np_masks = np.zeros((t, h, w), dtype=np.uint8)

        # Object probabilities, background included
        self.prob = torch.zeros((self.k + 1, t, 1, nh, nw),
                                dtype=torch.float32,
                                device=self.result_dev)
        self.prob[0] = 1e-7

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh // 16
        self.kw = self.nw // 16

        self.query_buf = {}
        self.image_buf = {}
        self.interacted = set()

        self.certain_mem_k = None
        self.certain_mem_v = None
Example #17
def image_to_tensor(im):
    im = im_transform(im).unsqueeze(0)
    return im.cuda()


# Read the source and target images
src_image = Image.open(args.src_image).convert('RGB')
tar_image = Image.open(args.tar_image).convert('RGB')
src_im_th = image_to_tensor(src_image)
tar_im_th = image_to_tensor(tar_image)
""" 
Compute W
"""
# Inputs need to have dimensions as multiples of 16
src_im_th, pads = pad_divide_by(src_im_th, 16)
tar_im_th, _ = pad_divide_by(tar_im_th, 16)

# Mask input is not crucial to getting a good correspondence
# we are just using an empty mask here
b, _, h, w = src_im_th.shape
empty_mask = torch.zeros((b, 1, h, w), device=src_im_th.device)

# We can precompute the affinity matrix of shape (H/16 * W/16) x (H/16 * W/16),
# where 16 is the encoder stride
qk16 = corr_net.get_query_key(tar_im_th)
mk16 = corr_net.get_mem_key(src_im_th, empty_mask, empty_mask)
W = corr_net.get_W(mk16, qk16)
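
For scale (hypothetical numbers, not taken from this script): a padded 480x864 input with the stride-16 encoder gives a 30x54 key map, i.e. 1620 positions, so W here would be a 1620x1620 affinity matrix per image.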

# Generate the transfer mask
# This mask is treated as our "feature" to be transferred through the affinity matrix