def _model(model_name):
    """Instantiate the STM variant registered under ``model_name``.

    The variant's module is imported lazily, so only the selected
    implementation has to be importable.

    Args:
        model_name: one of 'motion', 'aspp', 'enhanced', 'enhanced_motion',
            'standard', 'varysize', 'sp'.

    Returns:
        A freshly constructed ``STM()`` from the selected module.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    import importlib

    # name -> module that defines the matching ``STM`` class
    modules = {
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'standard': 'STM.models.model',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
    }
    try:
        module_name = modules[model_name]
    except (KeyError, TypeError):  # TypeError: unhashable model_name
        raise ValueError('unknown model_name {!r}; expected one of {}'.format(
            model_name, sorted(modules))) from None

    return importlib.import_module(module_name).STM()
def main(data_root, model_path, palette):
    """Run STM on the TIANCHI test split and save masks and score maps.

    Args:
        data_root: dataset root passed to ``TIANCHI``.
        model_path: checkpoint file; either a bare state dict or a dict that
            stores it under ``'state_dict'``.
        palette: palette applied to every saved mask via ``putpalette``.

    For each test sequence, ``pred`` is saved as
    ``TMP_PATH/<seq_name>/<frame>.png``. The score ``Es`` (multiplied by 255)
    is saved as ``SCORE_PATH/<seq_name>/<frame>.jpg``. Frame names come from
    ``VIDEO_FRAMES[video_name]``.
    """
    # Model and version
    MODEL = 'STM'
    print(MODEL, ': Testing on TIANCHI...')

    if torch.cuda.is_available():
        print('using Cuda devices, num:', torch.cuda.device_count())

    Testset = TIANCHI(data_root, imset='test.txt', single_object=True)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=True)

    model = nn.DataParallel(STM())
    if torch.cuda.is_available():
        model.cuda()
    model.eval()  # turn-off BN

    print('Loading weights:', model_path)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path)
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    code_name = 'tianchi'
    print('Start Testing:', code_name)

    for Fs, Ms, info in Testloader:
        seq_name = info['name'][0]
        ori_shape = info['ori_shape']
        num_frames = info['num_frames'][0].item()
        # Video id = text before the first '_' (the whole name if there is no
        # '_'). Assigned unconditionally so it can never be unbound or left
        # over from the previous sequence.
        video_name = seq_name.split('_')[0]

        print('[{}]: num_frames: {}'.format(seq_name, num_frames))

        pred, Es = Run_video(model,
                             Fs,
                             Ms,
                             num_frames,
                             Mem_every=5,
                             Mem_number=None,
                             mode='test')

        # Save results for quantitative eval ######################
        test_path = os.path.join(TMP_PATH, seq_name)
        score_path = os.path.join(SCORE_PATH, seq_name)
        os.makedirs(test_path, exist_ok=True)
        os.makedirs(score_path, exist_ok=True)

        frame_names = VIDEO_FRAMES[video_name]
        for f in range(num_frames):
            img_E = Image.fromarray(pred[0, 0,
                                         f].cpu().numpy().astype(np.uint8))
            img_E.putpalette(palette)
            img_E = img_E.resize(ori_shape[::-1])
            img_E.save(
                os.path.join(test_path, '{}.png'.format(frame_names[f])))

            score = Es[0, 0, f].cpu().numpy() * 255
            score = score.astype(np.uint8)
            score = cv2.resize(score, tuple(ori_shape[::-1]))
            cv2.imwrite(
                os.path.join(score_path, '{}.jpg'.format(frame_names[f])),
                score)
Esempio n. 3
0
def _save_vos_masks(pred, num_frames, frame_names, out_dir, ori_shape, palette):
    """Save ``pred[0, 0, f]`` for f < num_frames as palettised PNGs in ``out_dir``."""
    os.makedirs(out_dir, exist_ok=True)
    for f in range(num_frames):
        img_E = Image.fromarray(pred[0, 0, f].cpu().numpy().astype(np.uint8))
        img_E.putpalette(palette)
        img_E = img_E.resize(ori_shape[::-1])
        img_E.save(os.path.join(out_dir, '{}.png'.format(frame_names[f])))


def vos_infer(data_root, model_path, palette):
    """Run STM on the TIANCHI test split and save palettised masks.

    Args:
        data_root: dataset root passed to ``TIANCHI``.
        model_path: checkpoint file; either a bare state dict or a dict that
            stores it under ``'state_dict'``.
        palette: palette applied to every saved mask via ``putpalette``.

    The loader yields one of two batch layouts:

    * ``(Fs, Ms, info)``: a single propagation over ``num_frames`` frames.
      When ``info['mode'] != 0`` the frame names are taken in reverse order.
    * ``(Fs_p, Fs_r, Ms, info)``: propagation starts at
      ``info['start_index']``. ``Fs_p`` is run over frames
      ``start_index..0`` (reversed) and ``Fs_r`` over frames
      ``start_index..end``. Both runs start from the same ``Ms``.

    Masks are written to ``TMP_PATH/<seq_name>/<frame>.png``. Batches with
    any other length are skipped.
    """
    # Model and version
    MODEL = 'STM'
    print(MODEL, ': Testing on TIANCHI...')

    if torch.cuda.is_available():
        print('using Cuda devices, num:', torch.cuda.device_count())

    Testset = TIANCHI(data_root, imset='test.txt', single_object=True)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

    model = nn.DataParallel(STM())
    if torch.cuda.is_available():
        model.cuda()
    model.eval()  # turn-off BN

    print('Loading weights:', model_path)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path)
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    code_name = 'tianchi'
    print('Start Testing:', code_name)

    for V in Testloader:
        if len(V) == 3:
            Fs, Ms, info = V
            seq_name = info['name'][0]
            ori_shape = info['ori_shape']
            num_frames = info['num_frames'][0].item()
            mode = info['mode']
            # Video id = text before the first '_' (whole name if no '_');
            # always assigned so it is never unbound or stale.
            video_name = seq_name.split('_')[0]
            if mode == 0:
                frame_list = VIDEO_FRAMES[video_name]
            else:
                frame_list = VIDEO_FRAMES[video_name][::-1]

            print('[{}]: num_frames: {}'.format(seq_name, num_frames))

            pred, _ = Run_video(model, Fs, Ms, num_frames, Mem_every=5, Mem_number=None, mode='test')

            _save_vos_masks(pred, num_frames, frame_list,
                            os.path.join(TMP_PATH, seq_name), ori_shape, palette)

        elif len(V) == 4:
            print('Start at middle frame!')
            Fs_p, Fs_r, Ms, info = V
            seq_name = info['name'][0]
            ori_shape = info['ori_shape']
            num_frames = info['num_frames'][0].item()
            start_index = info['start_index']
            video_name = seq_name.split('_')[0]

            # dim 2 of Fs_p / Fs_r is the number of frames in each run
            _, _, prev_frame_num, _, _ = Fs_p.shape
            _, _, rear_frame_num, _, _ = Fs_r.shape
            # frames start_index, start_index-1, ..., 0
            prev_frame_list = VIDEO_FRAMES[video_name][:start_index + 1][::-1]
            # frames start_index, ..., last
            rear_frame_list = VIDEO_FRAMES[video_name][start_index:]

            print('[{}]: num_frames: {}'.format(seq_name, num_frames))

            pred, _ = Run_video(model, Fs_p, Ms, prev_frame_num, Mem_every=5, Mem_number=None, mode='test')
            pred_r, _ = Run_video(model, Fs_r, Ms, rear_frame_num, Mem_every=5, Mem_number=None, mode='test')

            # The start frame appears in both lists; the rear run is saved
            # last, so its mask is the one kept on disk.
            test_path = os.path.join(TMP_PATH, seq_name)
            _save_vos_masks(pred, prev_frame_num, prev_frame_list, test_path, ori_shape, palette)
            _save_vos_masks(pred_r, rear_frame_num, rear_frame_list, test_path, ori_shape, palette)
def vos_inference():
    """Run STM over the TIANCHI fusai test set, writing one mask track per instance.

    Reads its configuration from module-level names: ``DATA_ROOT``,
    ``TARGET_SHAPE``, ``MODEL_PATH``, ``TMP_PATH``, ``PALETTE`` and
    ``VIDEO_FRAMES``. For every sequence, initial masks are produced by
    ``mask_inference(video_name, target_shape)``. Each ``(pred, instance)``
    pair returned by ``Run_video`` is saved as palettised PNGs in
    ``TMP_PATH/<seq_name>_<instance>/<frame>.png``.
    """
    model_tag = 'STM'
    print(model_tag, ': Testing on TIANCHI...')

    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        print('using Cuda devices, num:', torch.cuda.device_count())

    test_set = TIANCHI_FUSAI(DATA_ROOT, imset='test.txt', target_size=TARGET_SHAPE)
    print('Total test videos: {}'.format(len(test_set)))
    loader = data.DataLoader(test_set, batch_size=1, shuffle=False,
                             num_workers=0, pin_memory=True)

    net = nn.DataParallel(STM())
    if cuda_ok:
        net.cuda()
    net.eval()  # turn-off BN

    print('Loading weights:', MODEL_PATH)
    checkpoint = torch.load(MODEL_PATH)
    weights = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys() else checkpoint
    net.load_state_dict(weights)

    print('Start Testing:', 'Tianchi fusai')

    for frames, meta in tqdm.tqdm(loader):
        seq_name = meta['name'][0]
        ori_shape = meta['ori_shape']
        raw_target = meta['target_shape']
        target_shape = (raw_target[0].cpu().numpy()[0],
                        raw_target[1].cpu().numpy()[0])
        num_frames = meta['num_frames'][0].item()
        # Text before the first '_' -- or the whole name when there is none.
        video_name = seq_name.split('_')[0]
        init_masks = mask_inference(video_name, target_shape)

        frame_names = VIDEO_FRAMES[video_name]

        print('[{}]: num_frames: {}'.format(seq_name, num_frames))

        tracks = Run_video(net, frames, init_masks, num_frames, Mem_every=5)

        for pred, instance in tracks:
            out_dir = os.path.join(TMP_PATH, seq_name + '_{}'.format(instance))
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

            for idx in range(num_frames):
                mask_img = Image.fromarray(pred[0, 0, idx].cpu().numpy().astype(np.uint8))
                mask_img.putpalette(PALETTE)
                mask_img = mask_img.resize(ori_shape[::-1])
                mask_img.save(os.path.join(out_dir, '{}.png'.format(frame_names[idx])))
Esempio n. 5
0
    # val_dataset = DAVIS(DAVIS_ROOT, phase='val', imset='tianchi_val_cf.txt', resolution='480p',
    #                     separate_instance=True, only_single=False, target_size=(864, 480))
    val_dataset = TIANCHI(DAVIS_ROOT,
                          phase='val',
                          imset='tianchi_val_cf.txt',
                          separate_instance=True,
                          target_size=(864, 480),
                          same_frames=False)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 pin_memory=True)

    model = nn.DataParallel(STM())

    if torch.cuda.is_available():
        model.cuda()

    # load weights.pth
    if args.load_from and not args.resume_from:
        print('load pretrained from:', args.load_from)
        model.load_state_dict(torch.load(args.load_from), strict=False)

    if args.mode == "val":
        loss_val, miou_val = validate(args, val_loader, model)
        log.logger.info('val loss:{}, val miou:{}'.format(loss_val, miou_val))

    elif args.mode == "train":
        # set training para
Esempio n. 6
0
def _model(model_name):
    """Instantiate the STM variant registered under ``model_name``.

    The variant's module is imported lazily, so only the selected
    implementation has to be importable.

    Args:
        model_name: one of 'motion', 'aspp', 'enhanced', 'standard',
            'enhanced_motion', 'varysize', 'sp', 'hkf'.

    Returns:
        A freshly constructed ``STM()``. For 'enhanced', the model is put in
        eval mode, except its ``KV_Q`` sub-module, which is switched back to
        train mode.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    import importlib

    # name -> module that defines the matching ``STM`` class
    modules = {
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'standard': 'STM.models.model',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
        'hkf': 'STM.model_hkf',
    }
    try:
        module_name = modules[model_name]
    except (KeyError, TypeError):  # TypeError: unhashable model_name
        raise ValueError('unknown model_name {!r}; expected one of {}'.format(
            model_name, sorted(modules))) from None

    model = importlib.import_module(module_name).STM()
    if model_name == 'enhanced':
        model.eval()
        model.KV_Q.train()
    return model
Esempio n. 7
0
def init_stm_model(model_name, model_path):
    """Build an STM variant, load a checkpoint into it and move it to ``ipex.DEVICE``.

    Args:
        model_name: one of 'motion', 'aspp', 'enhanced', 'enhanced_motion',
            'standard', 'varysize', 'sp'.
        model_path: checkpoint file, loaded on CPU. It is either a bare state
            dict or a dict storing one under ``'state_dict'``. A leading
            ``'module.'`` on parameter names is stripped before loading.

    Returns:
        The model in eval mode, moved to ``ipex.DEVICE``.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    import importlib

    # name -> module that defines the matching ``STM`` class
    variants = {
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'standard': 'STM.models.model',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
    }
    try:
        module_name = variants[model_name]
    except (KeyError, TypeError):  # TypeError: unhashable model_name
        raise ValueError('unknown model_name {!r}; expected one of {}'.format(
            model_name, sorted(variants))) from None
    model = importlib.import_module(module_name).STM()

    print('Loading weights:', model_path)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Presumably these checkpoints were saved from an nn.DataParallel wrapper,
    # whose keys start with 'module.'. Strip only that leading prefix:
    # str.replace would also mangle keys containing 'module.' elsewhere
    # (e.g. 'x.submodule.w' -> 'x.subw').
    prefix = 'module.'
    stripped = {}
    for key, value in state_dict.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        stripped.setdefault(key, value)  # first key wins on a name clash

    model.load_state_dict(stripped)
    model.eval()
    model.to(ipex.DEVICE)

    return model