def __init__(self, use_gpu=True, time_budget=None, save_result_dir=cfg.SAVE_RESULT_DIR, pretrained=True, interactive_test=False, freeze_bn=False):
    """Assemble the interactive-VOS model: DeepLab backbone + IntVOS head.

    Restores backbone weights from ``cfg.PRETRAINED_MODEL`` (when
    ``pretrained`` is true) and the full interactive model from the
    100k-step checkpoint under ``cfg.SAVE_VOS_RESULT_DIR``, then moves
    the model to GPU when ``use_gpu`` is set.

    Args:
        use_gpu: move the assembled model to CUDA when True.
        time_budget: stored on the instance; not used in this constructor.
        save_result_dir: directory recorded as ``self.save_res_dir``.
        pretrained: load backbone weights from cfg.PRETRAINED_MODEL.
        interactive_test: accepted for interface compatibility; unused here.
        freeze_bn: forwarded to both the backbone and the IntVOS head.
    """
    self.save_res_dir = save_result_dir
    self.time_budget = time_budget
    self.use_gpu = use_gpu

    # ResNet-based DeepLab serves as the shared feature extractor.
    self.feature_extracter = DeepLab(backbone='resnet', freeze_bn=freeze_bn)
    if pretrained:
        # Checkpoint stores weights under the 'state_dict' key.
        backbone_state = torch.load(cfg.PRETRAINED_MODEL)['state_dict']
        self.load_network(self.feature_extracter, backbone_state)
        print('load pretrained model successfully.')

    # Interactive VOS head built on top of the backbone.
    self.model = IntVOS(cfg, self.feature_extracter, freeze_bn=freeze_bn)
    print(self.model)

    # Restore the full interactive model from the 100k-step checkpoint.
    vos_state = torch.load(
        os.path.join(cfg.SAVE_VOS_RESULT_DIR, 'save_step_100000.pth'))
    self.load_network(self.model, vos_state)
    if use_gpu:
        self.model = self.model.cuda()
def __init__(self, use_gpu=True, time_budget=None, save_result_dir='./saved_model_total_ytb_davis', pretrained=True, interactive_test=False, freeze_bn=False):
    """Assemble the interactive-VOS model (backbone + IntVOS head) without a VOS checkpoint.

    Optionally restores backbone weights from ``cfg.PRETRAINED_MODEL``
    and moves the assembled model to GPU when ``use_gpu`` is set.

    Args:
        use_gpu: move the model to CUDA when True.
        time_budget: stored on the instance; not used in this constructor.
        save_result_dir: directory recorded as ``self.save_res_dir``.
        pretrained: load backbone weights from cfg.PRETRAINED_MODEL.
        interactive_test: accepted for interface compatibility; unused here.
        freeze_bn: forwarded to both the backbone and the IntVOS head.
    """
    self.save_res_dir = save_result_dir
    self.time_budget = time_budget
    self.use_gpu = use_gpu

    # ResNet-based DeepLab serves as the shared feature extractor.
    self.feature_extracter = DeepLab(backbone='resnet', freeze_bn=freeze_bn)
    if pretrained:
        # Checkpoint stores weights under the 'state_dict' key.
        backbone_state = torch.load(cfg.PRETRAINED_MODEL)['state_dict']
        self.load_network(self.feature_extracter, backbone_state)
        print('load pretrained model successfully.')

    # Interactive VOS head built on top of the backbone.
    self.model = IntVOS(cfg, self.feature_extracter, freeze_bn=freeze_bn)
    if use_gpu:
        self.model = self.model.cuda()
def get_model(args):
    """Build the segmentation network selected by ``args.network_name``.

    For ``"FPN"`` with ``args.weight_type == "moco_v2"``, self-supervised
    MoCo-v2 weights are remapped into the encoder's state dict (the MoCo
    projection head is discarded). For ``"deeplab"``, a DeepLab model is
    returned as-is.

    Args:
        args: namespace with at least ``network_name``; for the FPN/moco_v2
            path also ``weight_type`` and ``n_layers``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.network_name`` is not recognized (previously
            this fell through to an unbound-variable ``NameError``).
        FileNotFoundError: if the MoCo-v2 checkpoint is missing.
    """
    if args.network_name == "FPN":
        model = FPNSeg(args)
        if args.weight_type == "moco_v2":
            # MoCo-v2 weights are only provided for ResNet-50.
            assert args.n_layers == 50, args.n_layers
            # Path to moco_v2 weights, relative to the scripts dir.
            # (The original code wrapped this in a try/except that reloaded
            # the exact same path — a dead fallback; a missing file still
            # raises FileNotFoundError either way.)
            self_sup_weights = torch.load(
                "../networks/backbones/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]
            model_state_dict = model.encoder.state_dict()
            # Drop the MoCo projection-head parameters; only the encoder
            # trunk is transferred.
            for k in list(self_sup_weights.keys()):
                if k.replace("module.encoder_q.", '') in ["fc.0.weight", "fc.0.bias", "fc.2.weight", "fc.2.bias"]:
                    self_sup_weights.pop(k)
            for name, param in self_sup_weights.items():
                # Remap MoCo naming ("module.encoder_q.*") onto this
                # encoder's naming ("base.*" / "base.prefix.*").
                name = name.replace("encoder_q.", '').replace("module", 'base')
                if name.replace("base.", '') in ["conv1.weight", "bn1.weight", "bn1.bias",
                                                 "bn1.running_mean", "bn1.running_var",
                                                 "bn1.num_batches_tracked"]:
                    name = name.replace("base", "base.prefix")
                if name not in model_state_dict:
                    print(f"{name} is not applied!")
                    continue
                if isinstance(param, torch.nn.Parameter):
                    param = param.data
                # Copy in place so buffers (running stats) transfer too.
                model_state_dict[name].copy_(param)
            model.encoder.load_state_dict(model_state_dict)
            print("moco_v2 weights are loaded successfully.")
    elif args.network_name == "deeplab":
        model = DeepLab(args)
    else:
        # Fail loudly instead of hitting NameError on `return model`.
        raise ValueError(f"unknown network_name: {args.network_name}")
    return model
print('[%s] Loading data' % (datetime.datetime.now())) # augmenters train_xtransform_us, train_ytransform_us, test_xtransform_us, test_ytransform_us = get_augmenters_2d(augment_noise=(args.augment_noise==1)) train_xtransform, train_ytransform, test_xtransform, test_ytransform = get_augmenters_2d(augment_noise=(args.augment_noise==1)) # load data train = StronglyLabeledVolumeDataset(args.data_train, args.labels_train, input_shape, transform=train_xtransform, target_transform=train_ytransform, preprocess=args.preprocess) test = StronglyLabeledVolumeDataset(args.data_test, args.labels_test, input_shape, transform=test_xtransform, target_transform=test_ytransform, preprocess=args.preprocess) train_loader = DataLoader(train, batch_size=args.train_batch_size) test_loader = DataLoader(test, batch_size=args.test_batch_size) """ Setup optimization for finetuning """ print('[%s] Setting up optimization for finetuning' % (datetime.datetime.now())) # load best checkpoint net = DeepLab() """ Train the network """ print('[%s] Training network' % (datetime.datetime.now())) net.train_net(train_loader=train_loader, test_loader=test_loader, loss_fn=loss_fn_seg, lr=args.lr, step_size=args.step_size, gamma=args.gamma, epochs=args.epochs, test_freq=args.test_freq, print_stats=args.print_stats, log_dir=args.log_dir) """ Validate the trained network """ print('[%s] Validating the trained network' % (datetime.datetime.now())) test_data = test.data
def main():
    """Interactive DAVIS-2017 evaluation loop.

    Connects to a DavisInteractiveSession, and for each round of user
    scribbles: segments the annotated frame with the interaction head
    (``int_seghead``), then propagates masks forward and backward through
    the sequence with ``before_seghead_process``, submitting the full mask
    stack to the session. Frame embeddings are computed once (on the first
    scribble round) and cached in ``embedding_memory`` for later rounds.
    Predictions are optionally saved as palettized PNGs per interaction.
    """
    total_frame_num_dic = {}
    # ---- Dataset bookkeeping: sequence names and per-sequence frame counts ----
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
        seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
        seqs.extend(seqs_tmp)
    for seq_name in seqs:
        images = np.sort(os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name.strip())))
        total_frame_num_dic[seq_name] = len(images)
    # NOTE(review): 'v_a_l' produces the file name 'v_a_l_instances.txt' —
    # confirm this matches the actual file on disk (possibly meant 'val').
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'v_a_l' + '_instances.txt')
    seq_dict = json.load(open(_seq_list_file, 'r'))

    # ---- Evaluation configuration ----
    is_save_image = True  # dump per-frame PNG predictions per interaction
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'

    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    # ---- Model construction and checkpoint restore ----
    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')
    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)
    print('model loading finished!')
    model.eval()
    # NOTE(review): seen_seq is never read or written below — confirm dead code.
    seen_seq = {}
    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                     davis_root=cfg.DATA_ROOT,
                                     subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time,
                                     metric_to_optimize='J') as sess:
            # One iteration per interaction round served by the session.
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                # The round is anchored at the first frame that received scribbles.
                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []
                # Per-sequence state is (re)initialized on the first scribble
                # round; later rounds rely on these variables persisting
                # across loop iterations (first_scribble must be True once
                # per sequence before the else-branch runs).
                if first_scribble:
                    anno_frame_list = []
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    # Last entry of the instances record = number of objects.
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(split='val', root=cfg.DATA_ROOT,
                                                               transform=tr.ToTensor(),
                                                               seq_name=sequence)
                else:
                    n_interaction += 1

                # ---- Reference image process ----
                scr_f = start_annotated_frame
                anno_frame_list.append(start_annotated_frame)
                print(start_annotated_frame)
                # Zero-pad the frame index to the 5-digit DAVIS file naming.
                scr_f = str(scr_f)
                while len(scr_f) != 5:
                    scr_f = '0' + scr_f
                ref_img = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p', sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img)
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)
                #ref_img = tr.ToTensor()(ref_img)
                #ref_img = ref_img.unsqueeze(0)
                # Rasterize scribbles into per-frame label maps and pick the
                # map of the annotated frame.
                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = {'ref_img': ref_img, 'scribble_label': scribble_label}
                sample = tr.ToTensor()(sample)
                ref_img = sample['ref_img']
                scribble_label = sample['scribble_label']
                ref_img = ref_img.unsqueeze(0)
                scribble_label = scribble_label.unsqueeze(0)
                ref_img = ref_img.cuda()
                scribble_label = scribble_label.cuda()

                # Save a palettized visualization of the scribble labels.
                ref_scribble_to_show = scribble_label.cpu().squeeze().numpy()
                im_ = Image.fromarray(ref_scribble_to_show.astype('uint8')).convert('P')
                im_.putpalette(_palette)
                ref_img_name = scr_f
                if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))):
                    os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction)))
                im_.save(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction),
                                      'inter_' + ref_img_name + '.png'))

                #ref_frame_embedding = model.extract_feature(ref_img)
                # Embeddings are computed once (first round) and cached for
                # all subsequent interaction rounds.
                if first_scribble:
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w))
                    embedding_memory = embedding_memory.cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                else:
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)

                # Previous-round mask for the annotated frame (None on the
                # first round; otherwise re-read from the saved PNG).
                if first_scribble:
                    prev_label = None
                else:
                    prev_label = os.path.join(cfg.RESULT_ROOT, sequence,
                                              'interactive' + str(n_interaction - 1), scr_f + '.png')
                    prev_label = Image.open(prev_label)
                    prev_label = np.array(prev_label, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_label})
                    prev_label = prev_label['label'].unsqueeze(0)
                    prev_label = prev_label.cuda()

                #tmp_dic, eval_global_map_tmp_dic= model.before_seghead_process(ref_frame_embedding,ref_frame_embedding,
                #        ref_frame_embedding,scribble_label,prev_label,
                #        normalize_nearest_neighbor_distances=True,use_local_map=False,
                #        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),k_nearest_neighbors=cfg.KNNS,
                #        global_map_tmp_dic=eval_global_map_tmp_dic,frame_num=[start_annotated_frame],dynamic_seghead=model.dynamic_seghead)
                # Segment the annotated frame with the interaction head.
                tmp_dic, local_map_dics = model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                            ref_scribble_label=scribble_label,
                                                            prev_round_label=prev_label,
                                                            global_map_tmp_dic=eval_global_map_tmp_dic,
                                                            local_map_dics=local_map_dics,
                                                            interaction_num=n_interaction,
                                                            seq_names=[sequence],
                                                            gt_ids=torch.Tensor([obj_nums]),
                                                            frame_num=[start_annotated_frame],
                                                            first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                # Upsample logits to image resolution, then argmax to labels.
                pred_label = nn.functional.interpolate(pred_label, size=(h_, w_), mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())

                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                    im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))):
                        os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction)))
                    im.save(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction),
                                         imgname + '.png'))

                # ---- Forward propagation: annotated frame + 1 .. end ----
                # Keep a copy of the annotated-frame prediction so the
                # backward pass can restart from the same state.
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    print('evaluating sequence:{} frame:{}'.format(sequence, ii))
                    sample = eval_data_manager.get_image(ii)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    # current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[ii] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[ii]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    # Propagate the mask one frame forward using reference +
                    # previous-frame embeddings and the local/global maps.
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding,
                        current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]), k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                        interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                        frame_num=[ii], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding

                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(ii)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction),
                                             imgname + '.png'))

                # ---- Backward propagation: annotated frame - 1 .. 0 ----
                # Restart from the annotated-frame prediction and walk back.
                prev_label = ref_prev_label
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    print('evaluating sequence:{} frame:{}'.format(sequence, current_frame_num))
                    sample = eval_data_manager.get_image(current_frame_num)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    #current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[current_frame_num] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[current_frame_num]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding,
                        current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]), k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                        interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                        frame_num=[current_frame_num], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks_reverse.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding

                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(current_frame_num)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction),
                                             imgname + '.png'))

                # Reassemble frame order: reversed backward masks followed by
                # the forward masks, then submit the whole sequence.
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = torch.cat(pred_masks_reverse, 0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))

            # Final per-session report and global summary JSON.
            report = sess.get_report()
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
net = UNet(1, args.num_classes).cuda() if args.is_pretrain: print('loading pretrain') state_dict = torch.load('/home/viplab/data/unet_carvana_scale1_epoch5.pth') model_dict = net.state_dict() pretrained_state = { k:v for k,v in state_dict.items() \ if k in model_dict and v.size() == model_dict[k].size() } model_dict.update(pretrained_state) net.load_state_dict(model_dict) elif args.model == 'denseunet': if args.is_pretrain: net = DenseUNet(args.num_classes, pretrained_encoder_uri='https://download.pytorch.org/models/densenet121-a639ec97.pth').cuda() else: net = DenseUNet(args.num_classes).cuda() elif args.model == 'deeplab': net = DeepLab(sync_bn=False, num_classes=args.num_classes).cuda() elif args.model == 'deeplab_xception': net = DeepLab(sync_bn=False, num_classes=args.num_classes, backbone='xception', output_stride=8).cuda() elif args.model == 'deeplab_resnest': MODEL_CFG = MODEL_CFG.copy() MODEL_CFG.update({'num_classes':args.num_classes}) if args.norm == 'groupnorm': MODEL_CFG.update({'norm_cfg': {'type': 'groupnorm', 'opts': {}}}) net = Deeplabv3Plus(MODEL_CFG, mode='TRAIN').cuda() if args.is_pretrain: # resume state_dict = torch.load(args.is_pretrain) print('loading', args.is_pretrain) if 'model' in state_dict.keys(): state_dict = state_dict['model'] new_state_dict = {} for k in state_dict.keys():