def _model(model_name):
    """Return a freshly constructed STM network for the variant *model_name*.

    Each variant name maps to a module exposing an ``STM`` class; only the
    selected module is imported, so the other variants' modules need not be
    importable.

    Args:
        model_name: One of ``'motion'``, ``'aspp'``, ``'enhanced'``,
            ``'enhanced_motion'``, ``'standard'``, ``'varysize'``, ``'sp'``.

    Returns:
        A new ``STM()`` instance from the matching module.

    Raises:
        ValueError: If *model_name* is not a known variant.
    """
    import importlib

    variant_modules = {
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'standard': 'STM.models.model',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
    }
    try:
        module_path = variant_modules[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also ended in ValueError.
        raise ValueError('Unknown STM model name: {!r}'.format(model_name)) from None
    return importlib.import_module(module_path).STM()
def main(data_root, model_path, palette):
    """Run STM on the TIANCHI test split and save masks plus score maps.

    For each test sequence, the predicted label map of every frame is written
    as a palettised PNG to ``TMP_PATH/<seq_name>/``. The matching ``Es`` map,
    scaled by 255 and cast to ``uint8``, is written as a JPEG to
    ``SCORE_PATH/<seq_name>/``. Output files are named after
    ``VIDEO_FRAMES[video_name]``.

    Args:
        data_root: Dataset root passed to ``TIANCHI``.
        model_path: Checkpoint path; either a bare state dict or a dict
            holding one under ``'state_dict'``.
        palette: Palette applied to every PNG via ``Image.putpalette``.
    """
    # Model and version
    MODEL = 'STM'
    print(MODEL, ': Testing on TIANCHI...')
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print('using Cuda devices, num:', torch.cuda.device_count())

    Testset = TIANCHI(data_root, imset='test.txt', single_object=True)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False,
                                 num_workers=0, pin_memory=True)

    model = nn.DataParallel(STM())
    if use_cuda:
        model.cuda()
    model.eval()  # turn-off BN

    print('Loading weights:', model_path)
    # Deserialize onto the CPU so GPU-saved checkpoints also load on CPU-only
    # hosts; load_state_dict then copies the tensors onto the model's device.
    checkpoint = torch.load(model_path, map_location='cpu')
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    code_name = 'tianchi'
    print('Start Testing:', code_name)
    for Fs, Ms, info in Testloader:
        seq_name = info['name'][0]
        ori_shape = info['ori_shape']
        num_frames = info['num_frames'][0].item()
        # Sequence names may carry a '_<suffix>'. Frame lists are keyed by the
        # part before the first underscore, which is the whole name when there
        # is none. Previously video_name was left unset (or stale) in that case.
        video_name = seq_name.split('_')[0]
        frame_names = VIDEO_FRAMES[video_name]
        print('[{}]: num_frames: {}'.format(seq_name, num_frames))

        pred, Es = Run_video(model, Fs, Ms, num_frames, Mem_every=5,
                             Mem_number=None, mode='test')

        # Save results for quantitative eval
        test_path = os.path.join(TMP_PATH, seq_name)
        score_path = os.path.join(SCORE_PATH, seq_name)
        os.makedirs(test_path, exist_ok=True)
        os.makedirs(score_path, exist_ok=True)
        for f in range(num_frames):
            img_E = Image.fromarray(pred[0, 0, f].cpu().numpy().astype(np.uint8))
            img_E.putpalette(palette)
            # ori_shape is reversed before resizing; presumably it is
            # (height, width) while PIL/cv2 expect (width, height). TODO confirm.
            img_E = img_E.resize(ori_shape[::-1])
            img_E.save(os.path.join(test_path, '{}.png'.format(frame_names[f])))

            score = (Es[0, 0, f].cpu().numpy() * 255).astype(np.uint8)
            score = cv2.resize(score, tuple(ori_shape[::-1]))
            cv2.imwrite(os.path.join(score_path, '{}.jpg'.format(frame_names[f])), score)
def _save_label_maps(pred, num_frames, frame_names, out_dir, ori_shape, palette):
    """Write the first *num_frames* predicted label maps in *pred* as palettised PNGs.

    Frame ``f`` is taken from ``pred[0, 0, f]``, cast to ``uint8``, given
    *palette*, resized to ``ori_shape[::-1]`` and saved as
    ``out_dir/<frame_names[f]>.png``. *out_dir* is created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    for f in range(num_frames):
        img_E = Image.fromarray(pred[0, 0, f].cpu().numpy().astype(np.uint8))
        img_E.putpalette(palette)
        # Presumably ori_shape is (height, width) and PIL wants (width, height) — confirm.
        img_E = img_E.resize(ori_shape[::-1])
        img_E.save(os.path.join(out_dir, '{}.png'.format(frame_names[f])))


def vos_infer(data_root, model_path, palette):
    """Run STM on the TIANCHI test split and save per-frame palettised masks.

    The loader yields one of two shapes of sample:

    * ``(Fs, Ms, info)``: the sequence is propagated in one pass. When
      ``info['mode'] == 0`` frames are named in ``VIDEO_FRAMES`` order;
      otherwise the order is reversed.
    * ``(Fs_p, Fs_r, Ms, info)``: propagation starts at
      ``info['start_index']``. ``Fs_p`` is run and saved against the frames up
      to that index in reverse order, and ``Fs_r`` against the frames from that
      index onward.

    Masks are written to ``TMP_PATH/<seq_name>/<frame>.png``. Samples of any
    other length are skipped.

    Args:
        data_root: Dataset root passed to ``TIANCHI``.
        model_path: Checkpoint path; either a bare state dict or a dict
            holding one under ``'state_dict'``.
        palette: Palette applied to every PNG via ``Image.putpalette``.
    """
    # Model and version
    MODEL = 'STM'
    print(MODEL, ': Testing on TIANCHI...')
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print('using Cuda devices, num:', torch.cuda.device_count())

    Testset = TIANCHI(data_root, imset='test.txt', single_object=True)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False,
                                 num_workers=0, pin_memory=True)

    model = nn.DataParallel(STM())
    if use_cuda:
        model.cuda()
    model.eval()  # turn-off BN

    print('Loading weights:', model_path)
    # Deserialize onto the CPU so GPU-saved checkpoints also load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location='cpu')
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    code_name = 'tianchi'
    print('Start Testing:', code_name)
    for V in Testloader:
        if len(V) == 3:
            Fs, Ms, info = V
            seq_name = info['name'][0]
            ori_shape = info['ori_shape']
            num_frames = info['num_frames'][0].item()
            mode = info['mode']
            # Frame lists are keyed by the part of seq_name before the first
            # '_' (the whole name when there is none); previously unset then.
            video_name = seq_name.split('_')[0]
            if mode == 0:
                frame_list = VIDEO_FRAMES[video_name]
            else:
                frame_list = VIDEO_FRAMES[video_name][::-1]
            print('[{}]: num_frames: {}'.format(seq_name, num_frames))

            pred, _ = Run_video(model, Fs, Ms, num_frames, Mem_every=5,
                                Mem_number=None, mode='test')

            # Save results for quantitative eval
            test_path = os.path.join(TMP_PATH, seq_name)
            _save_label_maps(pred, num_frames, frame_list, test_path, ori_shape, palette)
        elif len(V) == 4:
            print('Start at middle frame!')
            Fs_p, Fs_r, Ms, info = V
            seq_name = info['name'][0]
            ori_shape = info['ori_shape']
            num_frames = info['num_frames'][0].item()
            # int() accepts a plain int as well as the one-element tensor the
            # DataLoader may produce when collating a batch of size 1.
            start_index = int(info['start_index'])
            video_name = seq_name.split('_')[0]
            _, _, prev_frame_num, _, _ = Fs_p.shape
            _, _, rear_frame_num, _, _ = Fs_r.shape
            # Both halves include the start frame: the "prev" half runs backwards from it.
            prev_frame_list = VIDEO_FRAMES[video_name][:start_index + 1][::-1]
            rear_frame_list = VIDEO_FRAMES[video_name][start_index:]
            print('[{}]: num_frames: {}'.format(seq_name, num_frames))

            pred, _ = Run_video(model, Fs_p, Ms, prev_frame_num, Mem_every=5,
                                Mem_number=None, mode='test')
            pred_r, _ = Run_video(model, Fs_r, Ms, rear_frame_num, Mem_every=5,
                                  Mem_number=None, mode='test')

            # Save results for quantitative eval
            test_path = os.path.join(TMP_PATH, seq_name)
            _save_label_maps(pred, prev_frame_num, prev_frame_list, test_path, ori_shape, palette)
            _save_label_maps(pred_r, rear_frame_num, rear_frame_list, test_path, ori_shape, palette)
def vos_inference():
    """Run STM over the TIANCHI_FUSAI test split and save one mask folder per instance.

    Configuration comes from module-level names: ``DATA_ROOT``,
    ``TARGET_SHAPE``, ``MODEL_PATH``, ``PALETTE``, ``TMP_PATH`` and
    ``VIDEO_FRAMES``.

    For each sequence:

    * ``mask_inference(video_name, target_shape)`` is computed and passed to
      ``Run_video``, presumably as the initial/reference masks (confirm).
    * ``Run_video`` yields ``(pred, instance)`` pairs.
    * Each instance's masks are written as palettised PNGs to
      ``TMP_PATH/<seq_name>_<instance>/<frame>.png``.
    """
    # Model and version
    MODEL = 'STM'
    print(MODEL, ': Testing on TIANCHI...')
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print('using Cuda devices, num:', torch.cuda.device_count())

    Testset = TIANCHI_FUSAI(DATA_ROOT, imset='test.txt', target_size=TARGET_SHAPE)
    print('Total test videos: {}'.format(len(Testset)))
    Testloader = data.DataLoader(Testset, batch_size=1, shuffle=False,
                                 num_workers=0, pin_memory=True)

    model = nn.DataParallel(STM())
    if use_cuda:
        model.cuda()
    model.eval()  # turn-off BN

    print('Loading weights:', MODEL_PATH)
    # Deserialize onto the CPU so GPU-saved checkpoints also load on CPU-only
    # hosts; load_state_dict then copies the tensors onto the model's device.
    checkpoint = torch.load(MODEL_PATH, map_location='cpu')
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    code_name = 'Tianchi fusai'
    print('Start Testing:', code_name)
    progressbar = tqdm.tqdm(Testloader)
    for Fs, info in progressbar:
        seq_name = info['name'][0]
        ori_shape = info['ori_shape']
        target_shape = info['target_shape']
        # Each dimension arrives batched (batch_size=1); keep element 0 of each.
        target_shape = (target_shape[0].cpu().numpy()[0], target_shape[1].cpu().numpy()[0])
        num_frames = info['num_frames'][0].item()
        # Part before the first '_', or the whole name when there is no '_'.
        video_name = seq_name.split('_')[0]
        seg_results = mask_inference(video_name, target_shape)
        frame_list = VIDEO_FRAMES[video_name]
        print('[{}]: num_frames: {}'.format(seq_name, num_frames))

        results = Run_video(model, Fs, seg_results, num_frames, Mem_every=5)
        for pred, instance in results:
            test_path = os.path.join(TMP_PATH, seq_name + '_{}'.format(instance))
            os.makedirs(test_path, exist_ok=True)
            for f in range(num_frames):
                img_E = Image.fromarray(pred[0, 0, f].cpu().numpy().astype(np.uint8))
                img_E.putpalette(PALETTE)
                img_E = img_E.resize(ori_shape[::-1])
                img_E.save(os.path.join(test_path, '{}.png'.format(frame_list[f])))
# val_dataset = DAVIS(DAVIS_ROOT, phase='val', imset='tianchi_val_cf.txt', resolution='480p', # separate_instance=True, only_single=False, target_size=(864, 480)) val_dataset = TIANCHI(DAVIS_ROOT, phase='val', imset='tianchi_val_cf.txt', separate_instance=True, target_size=(864, 480), same_frames=False) val_loader = data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) model = nn.DataParallel(STM()) if torch.cuda.is_available(): model.cuda() # load weights.pth if args.load_from and not args.resume_from: print('load pretrained from:', args.load_from) model.load_state_dict(torch.load(args.load_from), strict=False) if args.mode == "val": loss_val, miou_val = validate(args, val_loader, model) log.logger.info('val loss:{}, val miou:{}'.format(loss_val, miou_val)) elif args.mode == "train": # set training para
def _model(model_name):
    """Return a freshly constructed STM network for the variant *model_name*.

    Each variant name maps to a module exposing an ``STM`` class; only the
    selected module is imported. For ``'enhanced'`` the network is put in eval
    mode and then only its ``KV_Q`` sub-module is switched back to train mode.
    Mode-dependent layers (e.g. BatchNorm, Dropout) therefore behave as in
    eval everywhere except inside ``KV_Q``.

    Args:
        model_name: One of ``'motion'``, ``'aspp'``, ``'enhanced'``,
            ``'standard'``, ``'enhanced_motion'``, ``'varysize'``, ``'sp'``,
            ``'hkf'``.

    Returns:
        A new ``STM()`` instance from the matching module.

    Raises:
        ValueError: If *model_name* is not a known variant (previously this
            surfaced as ``UnboundLocalError``).
    """
    import importlib

    variant_modules = {
        # 'motion' previously imported STM without instantiating it, so
        # `return model` raised UnboundLocalError.
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'standard': 'STM.models.model',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
        'hkf': 'STM.model_hkf',
    }
    try:
        module_path = variant_modules[model_name]
    except (KeyError, TypeError):
        raise ValueError('Unknown STM model name: {!r}'.format(model_name)) from None

    model = importlib.import_module(module_path).STM()
    if model_name == 'enhanced':
        model.eval()
        model.KV_Q.train()
    return model
def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from each key.

    Checkpoints saved from an ``nn.DataParallel`` wrapper prefix every key with
    ``module.``. Only that leading prefix is stripped, so names that merely
    contain ``module.`` elsewhere (e.g. ``submodule.weight``) are kept intact.
    If two keys collide after stripping, the first one wins.
    """
    prefix = 'module.'
    stripped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped.setdefault(key, value)
    return stripped


def init_stm_model(model_name, model_path):
    """Build STM variant *model_name*, load *model_path* into it and move it to ``ipex.DEVICE``.

    Args:
        model_name: One of ``'motion'``, ``'aspp'``, ``'enhanced'``,
            ``'enhanced_motion'``, ``'standard'``, ``'varysize'``, ``'sp'``.
        model_path: Checkpoint path; either a bare state dict or a dict holding
            one under ``'state_dict'``. Keys may carry a DataParallel
            ``module.`` prefix.

    Returns:
        The model in eval mode, after ``model.to(ipex.DEVICE)``.

    Raises:
        ValueError: If *model_name* is not a known variant.
    """
    import importlib

    variant_modules = {
        'motion': 'STM.models.model_fusai',
        'aspp': 'STM.models.model_fusai_aspp',
        'enhanced': 'STM.models.model_enhanced',
        'enhanced_motion': 'STM.models.model_enhanced_motion',
        'standard': 'STM.models.model',
        'varysize': 'STM.models.model_enhanced_varysize',
        'sp': 'STM.models.model_fusai_spatial_prior',
    }
    try:
        module_path = variant_modules[model_name]
    except (KeyError, TypeError):
        raise ValueError('Unknown STM model name: {!r}'.format(model_name)) from None
    model = importlib.import_module(module_path).STM()

    print('Loading weights:', model_path)
    # Load onto the CPU first; the model is moved to its target device below.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    if 'state_dict' in checkpoint.keys():
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(_strip_data_parallel_prefix(state_dict))
    model.eval()  # turn-off BN
    model.to(ipex.DEVICE)
    return model