# Shared imports for the examples below. Project-specific names such as
# get_video_mIoU, mask_inference, ol_aug, Run_video_sp, TIANCHI_FUSAI, MODEL,
# MODEL_PATH, _loss and the ALL-CAPS constants (MODE, MAX_NUM, IOU1, PALETTE,
# INTER_PATH, DATA_ROOT, OL_*, ...) come from the surrounding repository; the
# final example additionally expects `ipex` (Intel Extension for PyTorch).
import os
import random
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from PIL import Image
from torch.utils import data


def merge_result(result):
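    """Merge instance tracks from several runs into the first (base) run.

    Each element of `result` is a list of (mask_tensor, instance_id) pairs.
    A track from a later run is folded into a base instance when their best
    IoU is at least 0.9; all matched masks are averaged at the end.
    """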
    base_result = result[0]
    final_result = {}

    for i in range(1, len(result)):
        ious = {}
        r = result[i]
        # IoU of every (mask, instance) pair in this run against every base instance
        for es, ins in r:
            es = torch.round(es)
            for bs, bi in base_result:
                bs = torch.round(bs)
                iou = get_video_mIoU(es, bs)
                ious.setdefault(bi, {}).setdefault(ins, iou)

        for bs, bi in base_result:
            final_result.setdefault(bi, []).append(bs)
            if bi in ious:
                best_match_iou = max(ious[bi].values())
                if best_match_iou >= 0.9:
                    for k, v in ious[bi].items():
                        if v == best_match_iou:
                            matched = k
                    for es, ins in r:
                        if ins == matched:
                            final_result[bi].append(es)
    # Average all masks collected for each base instance into one soft mask
    fr = []
    for ins, masks in final_result.items():
        es = torch.mean(torch.cat([m.unsqueeze(0) for m in masks]), dim=0)
        fr.append((es, ins))
    return fr


def Run_video(model, Fs, seg_results, num_frames, Mem_every=None, model_name='standard'):
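    """Greedy multi-instance video segmentation built on an STM-style model.

    Repeatedly seeds from the highest-scoring remaining SOLO mask, propagates
    it forward and backward through the clip with the space-time memory
    network, and snaps the propagated mask back onto a SOLO detection whenever
    their IoU reaches IOU1. Returns a list of (Es, instance_idx) tuples.
    """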
    seg_result_idx = [i[3] for i in seg_results]

    instance_idx = 1
    b, c, T, h, w = Fs.shape
    results = []

    if np.all([len(i[0]) == 0 for i in seg_results]):
        print('No segmentation results from SOLO!')
        pred = torch.zeros((b, 1, T, h, w)).float().cuda()
        return [(pred, 1)]

    while True:
        if np.all([len(i[0]) == 0 for i in seg_results]):
            print('Run video over!')
            break
        if instance_idx > MAX_NUM:
            print('Reached the maximum instance number!')
            break
        start_frame_idx = np.argmax([max(i[2]) if i[2] != [] else 0 for i in seg_results])
        start_frame = seg_result_idx[start_frame_idx]
        start_mask = seg_results[start_frame_idx][0][0].astype(np.uint8)
        # start_mask = cv2.resize(start_mask, (w, h))
        start_mask = torch.from_numpy(start_mask).cuda()

        if model_name in ('enhanced', 'enhanced_motion'):
            Os = torch.zeros((b, c, int(h / 4), int(w / 4)))
            first_frame = Fs[:, :, start_frame]
            first_mask = start_mask.cpu()
            if len(first_mask.shape) == 2:
                first_mask = first_mask.unsqueeze(0).unsqueeze(0)
            elif len(first_mask.shape) == 3:
                first_mask = first_mask.unsqueeze(0)
            first_frame = first_frame * first_mask.repeat(1, 3, 1, 1).type(torch.float)
            for i in range(b):
                mask_ = first_mask[i]
                mask_ = mask_.squeeze(0).cpu().numpy().astype(np.uint8)
                assert np.any(mask_)
                x, y, w_, h_ = cv2.boundingRect(mask_)
                patch = first_frame[i, :, y:(y + h_), x:(x + w_)].cpu().numpy()
                patch = patch.transpose(1, 2, 0)
                patch = cv2.resize(patch, (int(w / 4), int(h / 4)))
                patch = patch.transpose(2, 0, 1)
                patch = torch.from_numpy(patch)
                Os[i, :, :, :] = patch

        if model_name == 'varysize':
            oss = []
            first_frame = Fs[:, :, start_frame]
            first_mask = start_mask.cpu()
            if len(first_mask.shape) == 2:
                first_mask = first_mask.unsqueeze(0).unsqueeze(0)
            elif len(first_mask.shape) == 3:
                first_mask = first_mask.unsqueeze(0)
            first_frame = first_frame * first_mask.repeat(1, 3, 1, 1).type(torch.float)
            for i in range(b):
                mask_ = first_mask[i]
                mask_ = mask_.squeeze(0).cpu().numpy().astype(np.uint8)
                assert np.any(mask_)
                x, y, w_, h_ = cv2.boundingRect(mask_)
                patch = first_frame[i, :, y:(y + h_), x:(x + w_)].cpu().numpy()
                Os = torch.zeros((1, c, h_, w_))
                # the two original transposes cancelled out, so the patch is used as-is
                patch = torch.from_numpy(patch)
                Os[0, :, :, :] = patch
                oss.append(Os)  # fixed from `os.append(Os)`; `os` is the stdlib module

        Es = torch.zeros((b, 1, T, h, w)).float().cuda()
        Es[:, :, start_frame] = start_mask
        # to_memorize = [int(i) for i in np.arange(start_frame, num_frames, step=Mem_every)]
        to_memorize = [start_frame]
        for t in range(start_frame + 1, num_frames):  # frames after
            # memorize
            pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t - 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            if model_name == 'enhanced':
                logits, _, _ = model([Fs[:, :, t], Os, this_keys_m, this_values_m])
            elif model_name == 'motion':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, Es[:, :, t - 1]])
            elif model_name == 'aspp':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t - 1])])
            elif model_name == 'sp':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t - 1])])
            elif model_name == 'standard':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])
            elif model_name == 'enhanced_motion':
                logits, _, _ = model([Fs[:, :, t], Os, this_keys_m, this_values_m, torch.round(Es[:, :, t - 1])])
            elif model_name == 'varysize':
                logits, _, _ = model([Fs[:, :, t], oss, this_keys_m, this_values_m])
            else:
                raise NotImplementedError
            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if MODE == 'offline':
                save_path = os.path.join(INTER_PATH, 'STM', '{}_{}.png'.format(instance_idx, t + 1))
                img_array = pred.cpu().squeeze().numpy().astype(np.uint8)
                img_s = Image.fromarray(img_array)
                img_s.putpalette(PALETTE)
                img_s.save(save_path)
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx]).cuda()
                        # if get_video_mIoU(mask, torch.round(Es[:, 0, t - 1])) \
                        #     > get_video_mIoU(pred, torch.round(Es[:, 0, t - 1])):
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    # for i, iou in enumerate(ious):
                    #     if iou >= IOU2:
                    #         if i in reserve:
                    #             reserve.remove(i)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append([this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t - 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        # to_memorize = [start_frame - int(i) for i in np.arange(0, start_frame + 1, step=Mem_every)]
        to_memorize = [start_frame]
        for t in list(range(0, start_frame))[::-1]:  # frames before
            # memorize
            pre_key, pre_value = model([Fs[:, :, t + 1], Es[:, :, t + 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t + 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            if model_name == 'enhanced':
                logits, _, _ = model([Fs[:, :, t], Os, this_keys_m, this_values_m])
            elif model_name == 'motion':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, Es[:, :, t + 1]])
            elif model_name == 'aspp':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t + 1])])
            elif model_name == 'sp':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t + 1])])
            elif model_name == 'standard':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])
            elif model_name == 'enhanced_motion':
                logits, _, _ = model([Fs[:, :, t], Os, this_keys_m, this_values_m, torch.round(Es[:, :, t + 1])])
            elif model_name == 'varysize':
                logits, _, _ = model([Fs[:, :, t], oss, this_keys_m, this_values_m])
            else:
                raise NotImplementedError
            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if MODE == 'offline':
                save_path = os.path.join(INTER_PATH, 'STM', '{}_{}.png'.format(instance_idx, t + 1))
                img_array = pred.cpu().squeeze().numpy().astype(np.uint8)
                img_s = Image.fromarray(img_array)
                img_s.putpalette(PALETTE)
                img_s.save(save_path)
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx]).cuda()
                        # if get_video_mIoU(mask, torch.round(Es[:, 0, t + 1])) \
                        #         > get_video_mIoU(pred, torch.round(Es[:, 0, t + 1])):
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    # for i, iou in enumerate(ious):
                    #     if iou >= IOU2:
                    #         if i in reserve:
                    #             reserve.remove(i)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append([this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t + 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        for j in range(3):
            seg_results[start_frame_idx][j].pop(0)

        # pred = torch.round(Es.float())
        results.append((Es, instance_idx))

        instance_idx += 1

    return results
Example #3
def online_learning():
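    """Finetune the model on the test videos themselves (online learning).

    For each video, SOLO detections are chained across `ol_clip_frames`
    consecutive detection frames (linked by IoU >= 0.5) to build a pseudo
    ground-truth clip, which is then used for one optimizer step.
    """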
    print('online learning...')
    Testset = TIANCHI_FUSAI(DATA_ROOT,
                            imset='test.txt',
                            target_size=OL_TARGET_SHAPE)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=True)

    model = nn.DataParallel(MODEL)

    optimizer = torch.optim.Adam(model.parameters(), OL_LR, betas=(0.9, 0.99))

    print('Loading weights:', MODEL_PATH)
    model_ = torch.load(MODEL_PATH)
    if 'state_dict' in model_.keys():
        state_dict = model_['state_dict']
    else:
        state_dict = model_
    model.load_state_dict(state_dict, strict=True)

    if 'optimizer' in model_.keys():
        try:
            optimizer.load_state_dict(model_['optimizer'])
        except Exception as e:
            print(e)

    if torch.cuda.is_available():
        model.cuda()
    # model.eval()  # turn-off BN
    model.train()
    # freeze bn
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    print('Start online learning...')
    progressbar = tqdm.tqdm(Testloader)

    for V in progressbar:
        Fs, info = V
        seq_name = info['name'][0]
        ori_shape = info['ori_shape']
        target_shape = info['target_shape']
        target_shape = (target_shape[0].cpu().numpy()[0],
                        target_shape[1].cpu().numpy()[0])
        num_frames = info['num_frames'][0].item()
        if '_' in seq_name:
            video_name = seq_name.split('_')[0]
        else:
            video_name = seq_name
        seg_results = mask_inference(video_name, OL_TARGET_SHAPE, 2)
        ol_clip_frames = 3
        seg_result_idx = [i[3] for i in seg_results]
        start_frame_idx = np.argmax(
            [max(i[2]) if i[2] != [] else 0 for i in seg_results])
        start_frame = seg_result_idx[start_frame_idx]
        start_mask = seg_results[start_frame_idx][0][0].astype(np.uint8)

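        # Pseudo-GT (Ms) and previous-mask (Ps) clip tensors; OL_TARGET_SHAPE
        # is reversed here, presumably from (W, H) to (H, W)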
        Ms = torch.empty((1, 1, ol_clip_frames) +
                         OL_TARGET_SHAPE[::-1]).cuda().long()
        Ps = torch.empty((1, 1, ol_clip_frames - 1) +
                         OL_TARGET_SHAPE[::-1]).cuda()
        complete_flag = True
        masks = []
        masks.append(start_mask)
        if start_frame_idx + ol_clip_frames <= len(seg_results):
            # train after
            frames = [
                seg_result_idx[start_frame_idx + i]
                for i in range(ol_clip_frames)
            ]
            result_idxs = [start_frame_idx + i for i in range(ol_clip_frames)]
            for i in range(1, ol_clip_frames):
                seg_result = seg_results[result_idxs[i]]
                if seg_result[0] == []:
                    complete_flag = False
                    break
                else:
                    ious = []
                    for mask in seg_result[0]:
                        iou = get_video_mIoU(start_mask, mask)
                        ious.append(iou)
                    if np.max(ious) >= 0.5:
                        mi = np.argmax(ious)
                        masks.append(seg_result[0][mi])
                    else:
                        complete_flag = False
                        break
            if complete_flag and len(masks) == ol_clip_frames:
                for i, mask in enumerate(masks):
                    Ms[:, :, i] = torch.from_numpy(mask).cuda()
                    if i != 0:
                        Ps[:, :, i - 1] = torch.from_numpy(mask).cuda()
                Fs = Fs[:, :, frames].cuda()
                optimizer.zero_grad()
                loss_video, video_mIou = Run_video_motion(model, {
                    'Fs': Fs,
                    'Ms': Ms,
                    'Ps': Ps,
                    'info': info
                },
                                                          Mem_every=1,
                                                          mode='train')
                print('finetune loss: {:.3f}, miou: {:.2f}'.format(
                    loss_video, video_mIou))
                # backward
                loss_video.backward()
                optimizer.step()
            else:
                continue

    return model
Example #4
def online_learning():
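    """Online-learning variant with two finetuning passes per video.

    The first pass builds clips by augmenting a single SOLO mask (ol_aug);
    the second chains SOLO masks across consecutive detection frames, as
    above. Both run OL_ITER_PER_VIDEO optimizer steps via Run_video_sp.
    """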
    print('online learning...')
    Testset = TIANCHI_FUSAI(DATA_ROOT, imset='test.txt', target_size=OL_TARGET_SHAPE)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

    model = nn.DataParallel(MODEL)

    optimizer = torch.optim.Adam(model.parameters(), OL_LR, betas=(0.9, 0.99))

    print('Loading weights:', MODEL_PATH)
    model_ = torch.load(MODEL_PATH)
    if 'state_dict' in model_.keys():
        state_dict = model_['state_dict']
    else:
        state_dict = model_
    model.load_state_dict(state_dict, strict=True)

    # if 'optimizer' in model_.keys():
    #     try:
    #         optimizer.load_state_dict(model_['optimizer'])
    #     except Exception as e:
    #         print(e)

    if torch.cuda.is_available():
        model.cuda()
    # model.eval()  # turn-off BN
    model.train()
    # freeze bn
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    print('Start online learning...')
    progressbar = tqdm.tqdm(Testloader)

    for V in progressbar:
        Fv, info = V  # renamed from `F`, which shadowed torch.nn.functional
        seq_name = info['name'][0]
        ori_shape = info['ori_shape']
        target_shape = info['target_shape']
        target_shape = (target_shape[0].cpu().numpy()[0], target_shape[1].cpu().numpy()[0])
        num_frames = info['num_frames'][0].item()
        if '_' in seq_name:
            video_name = seq_name.split('_')[0]
        else:
            video_name = seq_name
        seg_results = mask_inference(video_name, OL_TARGET_SHAPE, OL_SOLO_INTERVAL, OL_SCORE_THR)
        ol_clip_frames = OL_CLIPS

        # online learning by aug
        for _ in range(OL_ITER_PER_VIDEO):
            count = 1
            while True:
                frame = random.choice(seg_results)
                if len(frame[1]) > 0 or count >= 10:
                    break
                else:
                    count += 1
            if len(frame[1]) == 0:
                continue

            select_idx = random.choice(list(range(len(frame[0]))))
            start_mask = frame[0][select_idx][np.newaxis, np.newaxis, :, :].transpose(0, 2, 3, 1)
            score = frame[2][select_idx]
            frame_idx = frame[3]
            start_frame = Fv[:, :, frame_idx].cpu().numpy().transpose(0, 2, 3, 1)

            frames = []
            masks = []
            frames.append(start_frame)
            masks.append(start_mask)

            for _ in range(ol_clip_frames - 1):
                img_aug, mask_aug = ol_aug((start_frame * 255).astype(np.uint8), start_mask)
                frames.append(img_aug / 255)
                masks.append(mask_aug)
            prevs = masks[1:]
            Fs = torch.from_numpy(
                np.concatenate([f[np.newaxis, ...] for f in frames], axis=0).transpose(1, 4, 0, 2, 3)).float()
            Ms = torch.from_numpy(
                np.concatenate([m[np.newaxis, ...] for m in masks], axis=0).transpose(1, 4, 0, 2, 3)).long()
            Ps = torch.from_numpy(
                np.concatenate([p[np.newaxis, ...] for p in prevs], axis=0).transpose(1, 4, 0, 2, 3)).float()

            optimizer.zero_grad()
            loss_video, video_mIou = Run_video_sp(model, {'Fs': Fs, 'Ms': Ms, 'Ps': Ps, 'info': info},
                                                  Mem_every=1,
                                                  mode='train')
            print('aug finetune loss: {:.3f}, miou: {:.2f}'.format(loss_video, video_mIou))
            # backward
            loss_video.backward()
            optimizer.step()

        # online learning by sequence
        for _ in range(OL_ITER_PER_VIDEO):
            count = 1
            while True:
                idx, frame = random.choice(list(enumerate(seg_results)))
                if len(frame[1]) > 0 or count >= 10:
                    break
                else:
                    count += 1
            if len(frame[1]) == 0:
                continue

            seg_result_idx = [i[3] for i in seg_results]
            # start_frame_idx = np.argmax([max(i[2]) if i[2] != [] else 0 for i in seg_results])
            start_frame_idx = idx
            start_frame = frame[3]
            num_mask = len(frame[0])
            start_mask = frame[0][random.choice(range(num_mask))].astype(np.uint8)

            Ms = torch.empty((1, 1, ol_clip_frames) + OL_TARGET_SHAPE[::-1]).cuda().long()
            Ps = torch.empty((1, 1, ol_clip_frames - 1) + OL_TARGET_SHAPE[::-1]).cuda()
            complete_flag = True
            masks = []
            masks.append(start_mask)
            if start_frame_idx + ol_clip_frames <= len(seg_results):
                # train after
                frames = [seg_result_idx[start_frame_idx + i] for i in range(ol_clip_frames)]
                result_idxs = [start_frame_idx + i for i in range(ol_clip_frames)]
                for i in range(1, ol_clip_frames):
                    seg_result = seg_results[result_idxs[i]]
                    if seg_result[0] == []:
                        complete_flag = False
                        break
                    else:
                        ious = []
                        for mask in seg_result[0]:
                            iou = get_video_mIoU(masks[-1], mask)
                            ious.append(iou)
                        if np.max(ious) >= 0.5:
                            mi = np.argmax(ious)
                            masks.append(seg_result[0][mi])
                        else:
                            complete_flag = False
                            break
                if complete_flag and len(masks) == ol_clip_frames:
                    for i, mask in enumerate(masks):
                        Ms[:, :, i] = torch.from_numpy(mask).cuda()
                        if i != 0:
                            Ps[:, :, i - 1] = torch.from_numpy(mask).cuda()
                    Fs = Fv[:, :, frames].cuda()
                    optimizer.zero_grad()
                    loss_video, video_mIou = Run_video_sp(model, {'Fs': Fs, 'Ms': Ms, 'Ps': Ps, 'info': info},
                                                          Mem_every=1,
                                                          mode='train')
                    print('sequence finetune loss: {:.3f}, miou: {:.2f}'.format(loss_video, video_mIou))
                    # backward
                    loss_video.backward()
                    optimizer.step()
                else:
                    continue

    return model
Example #5
def Run_video_enhanced_motion(model,
                              batch,
                              Mem_every=None,
                              Mem_number=None,
                              mode='train'):
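    """Training/eval driver for the 'enhanced_motion' model.

    Feeds a quarter-resolution template patch of the first object (Os) plus
    the rounded previous-frame mask into the model at every step, and sums a
    multi-scale loss over logits, p_m2 and p_m3 (weights 1, 0.5, 0.25).
    """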
    Fs, Ms, info = batch['Fs'], batch['Ms'], batch['info']
    num_frames = info['num_frames'][0].item()
    intervals = info['intervals']
    if Mem_every:
        to_memorize = [
            int(i) for i in np.arange(0, num_frames, step=Mem_every)
        ]
    elif Mem_number:
        to_memorize = [
            int(round(i))
            for i in np.linspace(0, num_frames, num=Mem_number + 2)[:-1]
        ]
    else:
        raise NotImplementedError

    b, c, f, h, w = Fs.shape
    Es = torch.zeros(
        (b, 1, f, h, w)).float().cuda()  # [1,1,50,480,864][b,c,t,h,w]
    Es[:, :, 0] = Ms[:, :, 0]

    loss_video = torch.tensor(0.0).cuda()
    loss_total = torch.tensor(0.0).cuda()

    Os = torch.zeros((b, c, int(h / 4), int(w / 4)))
    first_frame = Fs[:, :, 0].detach()
    first_mask = Ms[:, :, 0].detach()
    first_frame = first_frame * first_mask.repeat(1, 3, 1, 1).type(torch.float)
    for i in range(b):
        mask_ = first_mask[i]
        mask_ = mask_.squeeze(0).cpu().numpy().astype(np.uint8)
        assert np.any(mask_)
        x, y, w_, h_ = cv2.boundingRect(mask_)
        # c_x = x + w_ / 2
        # c_y = y + h_ / 2
        # c_x = np.clip(c_x, h / 8, 7 * h / 8)
        # c_y = np.clip(c_y, w / 8, 7 * w / 8)
        patch = first_frame[i, :, y:(y + h_), x:(x + w_)].cpu().numpy()
        patch = patch.transpose(1, 2, 0)
        # patch = cv2.resize(patch, (template_size, template_size))
        # patch = patch.transpose(2, 1, 0)
        patch = cv2.resize(patch, (int(w / 4), int(h / 4)))
        patch = patch.transpose(2, 0, 1)
        patch = torch.from_numpy(patch)
        Os[i, :, :, :] = patch

    for t in range(1, num_frames):
        interval = intervals[t][0].item()
        if mode == 'train':
            if interval != 1:
                model.module.Memory.eval()
            else:
                model.module.Memory.train()
        # memorize
        pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
        pre_key = pre_key.unsqueeze(2)
        pre_value = pre_value.unsqueeze(2)

        if t - 1 == 0:  # the first frame
            this_keys_m, this_values_m = pre_key, pre_value
        else:  # other frame
            this_keys_m = torch.cat([keys, pre_key], dim=2)
            this_values_m = torch.cat([values, pre_value], dim=2)

        # segment
        prev_mask = torch.round(Es[:, :, t - 1].detach()).float()
        logits, p_m2, p_m3 = model(
            [Fs[:, :, t], Os, this_keys_m, this_values_m, prev_mask])
        em = F.softmax(logits, dim=1)[:, 1]  # B h w
        Es[:, 0, t] = em

        #  calculate loss on cuda
        if mode == 'train' or mode == 'val':
            Ms_cuda = Ms[:, 0, t].cuda()
            loss_video += _loss(logits, Ms_cuda) + 0.5 * _loss(
                p_m2, Ms_cuda) + 0.25 * _loss(p_m3, Ms_cuda)
            loss_total = loss_video

        # update key and value
        if t - 1 in to_memorize:
            keys, values = this_keys_m, this_values_m

    #  calculate mIOU on cuda
    pred = torch.round(Es.float().cuda())
    if mode == 'train' or mode == 'val':
        video_mIoU = 0
        for n in range(len(Ms)):  # Nth batch
            video_mIoU = video_mIoU + get_video_mIoU(
                pred[n],
                Ms[n].cuda())  # mIOU of video(t frames) for each batch
        video_mIoU = video_mIoU / len(Ms)  # mean IoU among batch

        return loss_total / num_frames, video_mIoU

    elif mode == 'test':
        return pred, Es
Example #6
def Run_video_motion(model,
                     batch,
                     Mem_every=None,
                     Mem_number=None,
                     mode='train'):
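    """Training/eval driver for the 'motion' model: the standard STM loop,
    but with the rounded previous-frame mask passed as an extra model input.
    """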
    Fs, Ms, info = batch['Fs'], batch['Ms'], batch['info']
    num_frames = Fs.shape[2]
    # intervals = info['intervals']
    if Mem_every:
        to_memorize = [
            int(i) for i in np.arange(0, num_frames, step=Mem_every)
        ]
    elif Mem_number:
        to_memorize = [
            int(round(i))
            for i in np.linspace(0, num_frames, num=Mem_number + 2)[:-1]
        ]
    else:
        raise NotImplementedError

    B, _, f, H, W = Fs.shape
    Es = torch.zeros(
        (B, 1, f, H, W)).float().cuda()  # [1,1,50,480,864][b,c,t,h,w]
    Es[:, :, 0] = Ms[:, :, 0]

    loss_video = torch.tensor(0.0).cuda()

    for t in range(1, num_frames):
        # interval = intervals[t][0].item()
        # if mode == 'train':
        #     if interval != 1:
        #         model.module.Memory.eval()
        #     else:
        #         model.module.Memory.train()
        # memorize
        pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
        pre_key = pre_key.unsqueeze(2)
        pre_value = pre_value.unsqueeze(2)

        if t - 1 == 0:  # the first frame
            this_keys_m, this_values_m = pre_key, pre_value
        else:  # other frame
            this_keys_m = torch.cat([keys, pre_key], dim=2)
            this_values_m = torch.cat([values, pre_value], dim=2)

        # segment
        prev_mask = torch.round(Es[:, :, t - 1].detach()).float()
        logits, p_m2, p_m3 = model(
            [Fs[:, :, t], this_keys_m, this_values_m, prev_mask])
        em = F.softmax(logits, dim=1)[:, 1]  # B h w
        Es[:, 0, t] = em

        #  calculate loss on cuda
        if mode == 'train' or mode == 'val':
            Ms_cuda = Ms[:, 0, t].cuda()
            loss_video += _loss(logits, Ms_cuda) + 0.5 * _loss(
                p_m2, Ms_cuda) + 0.25 * _loss(p_m3, Ms_cuda)

        # update key and value
        if t - 1 in to_memorize:
            keys, values = this_keys_m, this_values_m

    #  calculate mIOU on cuda
    pred = torch.round(Es.float().cuda())
    if mode == 'train' or mode == 'val':
        video_mIoU = 0
        for n in range(len(Ms)):  # Nth batch
            video_mIoU = video_mIoU + get_video_mIoU(
                pred[n],
                Ms[n].cuda())  # mIOU of video(t frames) for each batch
        video_mIoU = video_mIoU / len(Ms)  # mean IoU among batch

        return loss_video / num_frames, video_mIoU

    elif mode == 'test':
        return pred, Es
Example #7
def Run_video_enhanced_varysize(model,
                                batch,
                                Mem_every=None,
                                Mem_number=None,
                                mode='train'):
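    """Training/eval driver for the 'varysize' model: one template patch per
    batch element is kept at its native bounding-box resolution instead of
    being resized to (h/4, w/4).
    """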
    Fs, Ms, info = batch['Fs'], batch['Ms'], batch['info']
    num_frames = info['num_frames'][0].item()
    if Mem_every:
        to_memorize = [
            int(i) for i in np.arange(0, num_frames, step=Mem_every)
        ]
    elif Mem_number:
        to_memorize = [
            int(round(i))
            for i in np.linspace(0, num_frames, num=Mem_number + 2)[:-1]
        ]
    else:
        raise NotImplementedError

    b, c, f, h, w = Fs.shape
    Es = torch.zeros(
        (b, 1, f, h, w)).float().cuda()  # [1,1,50,480,864][b,c,t,h,w]
    Es[:, :, 0] = Ms[:, :, 0]

    oss = []  # renamed from `os`, which shadowed the os module
    first_frame = Fs[:, :, 0].detach()
    first_mask = Ms[:, :, 0].detach()
    first_frame = first_frame * first_mask.repeat(1, 3, 1, 1).type(torch.float)
    for i in range(b):
        mask_ = first_mask[i]
        mask_ = mask_.squeeze(0).cpu().numpy().astype(np.uint8)
        assert np.any(mask_)
        x, y, w_, h_ = cv2.boundingRect(mask_)
        patch = first_frame[i, :, y:(y + h_), x:(x + w_)].cpu().numpy()
        Os = torch.zeros((1, c, h_, w_))
        # the two original transposes cancelled out, so the patch is used as-is
        patch = torch.from_numpy(patch)
        Os[0, :, :, :] = patch
        oss.append(Os)

    loss_video = torch.tensor(0.0).cuda()

    for t in range(1, num_frames):
        # memorize
        pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
        pre_key = pre_key.unsqueeze(2)
        pre_value = pre_value.unsqueeze(2)

        if t - 1 == 0:  # the first frame
            this_keys_m, this_values_m = pre_key, pre_value
        else:  # other frame
            this_keys_m = torch.cat([keys, pre_key], dim=2)
            this_values_m = torch.cat([values, pre_value], dim=2)

        # segment
        logits, p_m2, p_m3 = model(
            [Fs[:, :, t], oss, this_keys_m, this_values_m])  # B 2 h w
        em = F.softmax(logits, dim=1)[:, 1]  # B h w
        Es[:, 0, t] = em

        # update key and value
        if t - 1 in to_memorize:
            keys, values = this_keys_m, this_values_m

        #  calculate loss on cuda
        if mode == 'train' or mode == 'val':
            Ms_cuda = Ms[:, 0, t].cuda()
            loss_video += (_loss(logits, Ms_cuda) +
                           0.5 * _loss(p_m2, Ms_cuda) +
                           0.25 * _loss(p_m3, Ms_cuda))

    #  calculate mIOU on cuda
    pred = torch.round(Es.float().cuda())
    if mode == 'train' or mode == 'val':
        video_mIoU = 0
        for n in range(len(Ms)):  # Nth batch
            video_mIoU = video_mIoU + get_video_mIoU(
                pred[n],
                Ms[n].float().cuda())  # mIOU of video(t frames) for each batch
        video_mIoU = video_mIoU / len(Ms)  # mean IoU among batch

        return loss_video / num_frames, video_mIoU

    elif mode == 'test':
        return pred, Es
Example #8
def Run_video_hkf(model, batch, Mem_every=1, Mem_number=None, mode='train'):
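    """Training/eval driver for a memory-free ('hkf') model: each step sees
    only the current frame, the previous frame and the previous soft mask.
    """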
    Fs, Ms, info = batch['Fs'], batch['Ms'], batch['info']
    # if random.random() < 0.5:
    #     Fs = Fs[:, :, ::-1, ...]
    #     Ms = Ms[:, :, ::-1, ...]
    num_frames = info['num_frames'][0].item()
    # if Mem_every:
    #     to_memorize = [int(i) for i in np.arange(0, num_frames, step=Mem_every)]
    # elif Mem_number:
    #     to_memorize = [int(round(i)) for i in np.linspace(0, num_frames, num=Mem_number + 2)[:-1]]
    # else:
    #     raise NotImplementedError

    B, _, f, H, W = Fs.shape
    Es = torch.zeros(
        (B, 1, f, H, W)).float().cuda()  # [1,1,50,480,864][b,c,t,h,w]
    Es[:, :, 0] = Ms[:, :, 0]

    loss_video = torch.tensor(0.0).cuda()
    loss_total = torch.tensor(0.0).cuda()

    for t in range(1, num_frames):
        # memorize
        # pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
        # pre_key = pre_key.unsqueeze(2)
        # pre_value = pre_value.unsqueeze(2)
        #
        # if t - 1 == 0:  # the first frame
        #     this_keys_m, this_values_m = pre_key, pre_value
        # else:  # other frame
        #     this_keys_m = torch.cat([keys, pre_key], dim=2)
        #     this_values_m = torch.cat([values, pre_value], dim=2)

        # segment
        logits, p_m2, p_m3 = model(
            [Fs[:, :, t], Fs[:, :, t - 1], Es[:, :, t - 1]])
        em = F.softmax(logits, dim=1)[:, 1]  # B h w
        Es[:, 0, t] = em

        #  calculate loss on cuda
        if mode == 'train' or mode == 'val':
            Ms_cuda = Ms[:, 0, t].cuda()
            loss_video += _loss(logits, Ms_cuda) + 0.5 * _loss(
                p_m2, Ms_cuda) + 0.25 * _loss(p_m3, Ms_cuda)
            loss_total = loss_video

        # update key and value
        # if t - 1 in to_memorize:
        #     keys, values = this_keys_m, this_values_m

    #  calculate mIOU on cuda
    pred = torch.round(Es.float().cuda())
    if mode == 'train' or mode == 'val':
        video_mIoU = 0
        for n in range(len(Ms)):  # Nth batch
            video_mIoU = video_mIoU + get_video_mIoU(
                pred[n],
                Ms[n].cuda())  # mIOU of video(t frames) for each batch
        video_mIoU = video_mIoU / len(Ms)  # mean IoU among batch

        return loss_total / num_frames, video_mIoU

    elif mode == 'test':
        return pred, Es
Example #9
def Run_video(model,
              Fs,
              Ms,
              num_frames,
              solo_results=None,
              Mem_every=None,
              Mem_number=None,
              mode='train'):
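    """Training/eval driver that injects a per-frame SOLO prior (Sm).

    For every frame, the SOLO mask that best matches the ground truth (in
    train mode) or the rounded previous prediction (otherwise) is selected
    when its IoU is at least 0.7 and passed to the model as an extra,
    detached input.
    """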
    if Mem_every:
        to_memorize = [
            int(i) for i in np.arange(0, num_frames, step=Mem_every)
        ]
    elif Mem_number:
        to_memorize = [
            int(round(i))
            for i in np.linspace(0, num_frames, num=Mem_number + 2)[:-1]
        ]
    else:
        raise NotImplementedError

    B, _, f, H, W = Fs.shape
    Es = torch.zeros(
        (B, 1, f, H, W)).float().cuda()  # [1,1,50,480,864][b,c,t,h,w]
    Es[:, :, 0] = Ms[:, :, 0]

    loss_video = torch.tensor(0.0).cuda()
    loss_total = torch.tensor(0.0).cuda()

    for t in range(1, num_frames):
        # memorize
        pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
        pre_key = pre_key.unsqueeze(2)
        pre_value = pre_value.unsqueeze(2)

        Sm = torch.zeros_like(Es[:, :, 0])
        # process the SOLO result for this frame
        st = time.time()
        for b in range(B):
            if mode == 'train':
                gt = Ms[b, :, t]
            else:
                gt = Es[b, :, t - 1]
                gt = torch.round(gt)
            solo = solo_results[b]
            if len(solo) == 0:
                m_ = torch.zeros_like(gt)
            else:
                masks = solo[t][0]
                if masks is not None:
                    ious = []
                    for mask in masks:
                        iou = get_video_mIoU(gt, mask)
                        ious.append(iou)
                    ious = np.array(ious)
                    if np.any(ious >= 0.7):
                        idx = np.argmax(ious)
                        m_ = torch.from_numpy(masks[idx]).cuda()
                    else:
                        m_ = torch.zeros_like(gt)
                else:
                    m_ = torch.zeros_like(gt)
            if len(m_.shape) == 2:
                m_ = m_.unsqueeze(0)
            Sm[b] = m_
        ed = time.time()
        print('IoU calculation time: {:.2f}s'.format(ed - st))

        if t - 1 == 0:  # the first frame
            this_keys_m, this_values_m = pre_key, pre_value
        else:  # other frame
            this_keys_m = torch.cat([keys, pre_key], dim=2)
            this_values_m = torch.cat([values, pre_value], dim=2)

        # segment
        logits, p_m2, p_m3 = model(
            [Fs[:, :, t], this_keys_m, this_values_m,
             Sm.detach()])  # B 2 h w
        em = F.softmax(logits, dim=1)[:, 1]  # B h w
        Es[:, 0, t] = em

        #  calculate loss on cuda
        if mode == 'train' or mode == 'val':
            Ms_cuda = Ms[:, 0, t].cuda()
            loss_video += _loss(logits, Ms_cuda) + 0.5 * _loss(
                p_m2, Ms_cuda) + 0.25 * _loss(p_m3, Ms_cuda)
            loss_total = loss_video

        # update key and value
        if t - 1 in to_memorize:
            keys, values = this_keys_m, this_values_m
            # keys, values = this_keys_m.detach(), this_values_m.detach()

    #  calculate mIOU on cuda
    pred = torch.round(Es.float().cuda())
    if mode == 'train' or mode == 'val':
        video_mIoU = 0
        for n in range(len(Ms)):  # Nth batch
            video_mIoU = video_mIoU + get_video_mIoU(
                pred[n],
                Ms[n].cuda())  # mIOU of video(t frames) for each batch
        video_mIoU = video_mIoU / len(Ms)  # mean IoU among batch

        return loss_total / num_frames, video_mIoU

    elif mode == 'test':
        return pred, Es
Example #10
def Run_video(model, Fs, seg_results, instance_num):
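    """CPU variant of the greedy multi-instance loop.

    `instances` counts how often each id in result[1] appears across frames;
    the main loop mirrors the GPU Run_video above and runs until the
    candidate masks are exhausted or `instance_num` instances are produced.
    """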
    instances = {}
    for result in seg_results:
        if len(result[1]):
            for id in result[1]:
                if id == 0:
                    continue
                instances.setdefault(id, 0)
                instances[id] += 1

    instance_num = min(instance_num, len(instances))
    # `ins` was undefined here; the detected ids are the likely intent
    instances_ = np.array(list(instances.keys()))

    seg_result_idx = [i[3] for i in seg_results]

    instance_idx = 1
    b, c, T, h, w = Fs.shape
    num_frames = T  # `num_frames` was undefined in this variant; T is the frame count
    results = []

    while True:
        # Termination guards restored by analogy with the variants above:
        # stop when no candidate masks remain or enough instances were built.
        if np.all([len(i[0]) == 0 for i in seg_results]):
            break
        if instance_idx > instance_num:
            break

        start_frame_idx = np.argmax(
            [max(i[2]) if i[2] != [] else 0 for i in seg_results])
        start_frame = seg_result_idx[start_frame_idx]
        start_mask = seg_results[start_frame_idx][0][0].astype(np.uint8)
        # start_mask = cv2.resize(start_mask, (w, h))
        start_mask = torch.from_numpy(start_mask)

        Es = torch.zeros((b, 1, T, h, w)).float()
        Es[:, :, start_frame] = start_mask
        # to_memorize = [int(i) for i in np.arange(start_frame, num_frames, step=Mem_every)]
        to_memorize = [start_frame]
        for t in range(start_frame + 1, num_frames):  # frames after
            # memorize
            pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t - 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])

            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx])
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append(
                            [this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t - 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        # to_memorize = [start_frame - int(i) for i in np.arange(0, start_frame + 1, step=Mem_every)]
        to_memorize = [start_frame]
        for t in list(range(0, start_frame))[::-1]:  # frames before
            # memorize
            pre_key, pre_value = model([Fs[:, :, t + 1], Es[:, :, t + 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t + 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])
            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx])
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append(
                            [this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t + 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        for j in range(3):
            seg_results[start_frame_idx][j].pop(0)

        # pred = torch.round(Es.float())
        results.append((Es, instance_idx))

        instance_idx += 1

    return results
Example #11
def Run_video(model, Fs, seg_results, num_frames, model_name='standard'):
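    """Greedy multi-instance loop targeting Intel Extension for PyTorch:
    identical in structure to the CUDA version above, but tensors are moved
    to ipex.DEVICE instead of the GPU.
    """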
    seg_result_idx = [i[3] for i in seg_results]

    instance_idx = 1
    b, c, T, h, w = Fs.shape
    results = []
    if np.all([len(i[0]) == 0 for i in seg_results]):
        print('No segmentation results from SOLO!')
        pred = torch.zeros((b, 1, T, h, w)).float()
        return [(pred, 1)]

    while True:
        if np.all([len(i[0]) == 0 for i in seg_results]):
            print('Run video over!')
            break
        if instance_idx > MAX_NUM:
            print('Reached the maximum instance number!')
            break
        start_frame_idx = np.argmax([max(i[2]) if i[2] != [] else 0 for i in seg_results])
        start_frame = seg_result_idx[start_frame_idx]
        start_mask = seg_results[start_frame_idx][0][0].astype(np.uint8)
        # start_mask = cv2.resize(start_mask, (w, h))
        start_mask = torch.from_numpy(start_mask)

        Es = torch.zeros((b, 1, T, h, w)).float()
        Es[:, :, start_frame] = start_mask
        # to_memorize = [int(i) for i in np.arange(start_frame, num_frames, step=Mem_every)]
        to_memorize = [start_frame]
        for t in range(start_frame + 1, num_frames):  # frames after
            # memorize
            Fs = Fs.to(ipex.DEVICE)
            Es = Es.to(ipex.DEVICE)
            pre_key, pre_value = model([Fs[:, :, t - 1], Es[:, :, t - 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t - 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            if model_name == 'sp':
                Fs = Fs.to(ipex.DEVICE)
                Es = Es.to(ipex.DEVICE)
                this_keys_m = this_keys_m.to(ipex.DEVICE)
                this_values_m = this_values_m.to(ipex.DEVICE)

                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t - 1])])
            elif model_name == 'standard':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])
            else:
                raise NotImplementedError
            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx])
                        # if get_video_mIoU(mask, torch.round(Es[:, 0, t - 1])) \
                        #     > get_video_mIoU(pred, torch.round(Es[:, 0, t - 1])):
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    # for i, iou in enumerate(ious):
                    #     if iou >= IOU2:
                    #         if i in reserve:
                    #             reserve.remove(i)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append([this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t - 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        # to_memorize = [start_frame - int(i) for i in np.arange(0, start_frame + 1, step=Mem_every)]
        to_memorize = [start_frame]
        for t in list(range(0, start_frame))[::-1]:  # frames before
            # memorize
            Fs = Fs.to(ipex.DEVICE)
            Es = Es.to(ipex.DEVICE)
            pre_key, pre_value = model([Fs[:, :, t + 1], Es[:, :, t + 1]])
            pre_key = pre_key.unsqueeze(2)
            pre_value = pre_value.unsqueeze(2)

            if t + 1 == start_frame:  # the first frame
                this_keys_m, this_values_m = pre_key, pre_value
            else:  # other frame
                this_keys_m = torch.cat([keys, pre_key], dim=2)
                this_values_m = torch.cat([values, pre_value], dim=2)

            # segment
            if model_name == 'sp':
                Fs = Fs.to(ipex.DEVICE)
                Es = Es.to(ipex.DEVICE)
                this_keys_m = this_keys_m.to(ipex.DEVICE)
                this_values_m = this_values_m.to(ipex.DEVICE)
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m, torch.round(Es[:, :, t + 1])])
            elif model_name == 'standard':
                logits, _, _ = model([Fs[:, :, t], this_keys_m, this_values_m])
            else:
                raise NotImplementedError
            em = F.softmax(logits, dim=1)[:, 1]  # B h w
            Es[:, 0, t] = em

            # check solo result
            pred = torch.round(em.float())
            if t in seg_result_idx:
                idx = seg_result_idx.index(t)
                this_frame_results = seg_results[idx]
                masks = this_frame_results[0]
                ious = []
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    mask = torch.from_numpy(mask)
                    iou = get_video_mIoU(pred, mask)
                    ious.append(iou)
                if ious != []:
                    ious = np.array(ious)
                    reserve = list(range(len(ious)))
                    if sum(ious >= IOU1) >= 1:
                        same_idx = np.argmax(ious)
                        mask = torch.from_numpy(masks[same_idx])
                        # if get_video_mIoU(mask, torch.round(Es[:, 0, t + 1])) \
                        #         > get_video_mIoU(pred, torch.round(Es[:, 0, t + 1])):
                        Es[:, 0, t] = mask
                        reserve.remove(same_idx)
                        # if abs(to_memorize[-1] - t) >= TO_MEMORY_MIN_INTERVAL:
                        to_memorize.append(t)

                    # for i, iou in enumerate(ious):
                    #     if iou >= IOU2:
                    #         if i in reserve:
                    #             reserve.remove(i)

                    reserve_result = []
                    for n in range(3):
                        reserve_result.append([this_frame_results[n][i] for i in reserve])
                    reserve_result.append(this_frame_results[3])
                    seg_results[idx] = reserve_result

            # update key and value
            if t + 1 in to_memorize:
                keys, values = this_keys_m, this_values_m

        for j in range(3):
            seg_results[start_frame_idx][j].pop(0)

        # pred = torch.round(Es.float())
        results.append((Es, instance_idx))

        instance_idx += 1

    return results