def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False):
    ################### optimizer setup ###################
    self.model.train()
    running_loss = AverageMeter()
    # Alternative: optimize only the feature_extracter / semantic_embedding / dynamic_seghead sub-modules.
    optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR,
                          momentum=cfg.TRAIN_MOMENTUM, weight_decay=cfg.TRAIN_WEIGHT_DECAY)

    ################### data pipeline ###################
    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    composed_transforms_ytb = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale([0.5, 1, 1.25]),
                                                  tr.RandomCrop((800, 800)),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
    print('dataset processing...')
    # train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT, transform=composed_transforms_ytb)
    trainloader_davis = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                   shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
    trainloader_ytb = DataLoader(ytb_train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                 shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
    trainloader = [trainloader_ytb, trainloader_davis]
    print('dataset processing finished.')

    ################### loss ###################
    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    elif lossfunc == 'cross_entropy':
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_60000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 60000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        # Alternate between the YouTube-VOS and DAVIS loaders.
        for train_dataloader in trainloader:
            # sample['meta'] = {'seq_name': seq_name, 'frame_num': frame_num, 'obj_num': obj_num}
            for ii, sample in enumerate(train_dataloader):
                now_lr = self._adjust_lr(optimizer, step, max_itr)
                ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                img1s = sample['img1']
                img2s = sample['img2']
                ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                label1s = sample['label1']
                label2s = sample['label2']
                seq_names = sample['meta']['seq_name']
                obj_nums = sample['meta']['obj_num']
                bs, _, h, w = img2s.size()
                inputs = torch.cat((ref_imgs, img1s, img2s), 0)

                # Perturb the previous-frame mask so the model does not rely on a perfect prior.
                if damage_initial_previous_frame_mask:
                    try:
                        label1s = damage_masks(label1s)
                    except Exception:
                        print('damage_error')

                if self.use_gpu:
                    inputs = inputs.cuda()
                    ref_scribble_labels = ref_scribble_labels.cuda()
                    label1s = label1s.cuda()
                    label2s = label2s.cuda()

                tmp_dic = self.model(inputs, ref_scribble_labels, label1s, seq_names=seq_names,
                                     gt_ids=obj_nums, k_nearest_neighbors=cfg.KNNS)

                # Build per-sequence ground-truth labels matching the chosen loss.
                label_and_obj_dic = {}
                label_dic = {}
                for i, seq_ in enumerate(seq_names):
                    label_and_obj_dic[seq_] = (label2s[i], obj_nums[i])
                for seq_ in tmp_dic.keys():
                    tmp_pred_logits = tmp_dic[seq_]
                    tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                                mode='bilinear', align_corners=True)
                    tmp_dic[seq_] = tmp_pred_logits
                    label_tmp, obj_num = label_and_obj_dic[seq_]
                    obj_ids = torch.from_numpy(np.arange(1, obj_num + 1)).int()
                    if torch.cuda.is_available():
                        obj_ids = obj_ids.cuda()
                    if lossfunc == 'bce':
                        # One binary channel per object id.
                        label_tmp = label_tmp.permute(1, 2, 0)
                        label = (label_tmp.float() == obj_ids.float())
                        label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                        label_dic[seq_] = label.float()
                    elif lossfunc == 'cross_entropy':
                        label_dic[seq_] = label_tmp.long()

                loss = criterion(tmp_dic, label_dic, step)
                loss = loss / bs

                # Dump predictions and labels for offline inspection if the loss explodes.
                if loss.item() > 10000:
                    print(tmp_dic)
                    for k, v in tmp_dic.items():
                        np.save(k + '.npy', v.cpu().detach().numpy())
                        np.save('lab' + k + '.npy', label_dic[k].cpu().detach().numpy())
                    exit()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss.update(loss.item(), bs)

                # Logging / TensorBoard visualization (every step).
                if step % 1 == 0:
                    print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                          running_loss.val, running_loss.avg))
                    show_ref_img = ref_imgs.cpu().numpy()[0]
                    show_img1 = img1s.cpu().numpy()[0]
                    show_img2 = img2s.cpu().numpy()[0]
                    # Undo ImageNet normalization for display.
                    mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                    sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                    show_ref_img = show_ref_img * sigma + mean
                    show_img1 = show_img1 * sigma + mean
                    show_img2 = show_img2 * sigma + mean
                    show_gt = label2s.cpu()[0].squeeze(0).numpy()
                    show_gtf = label2colormap(show_gt).transpose((2, 0, 1))

                    show_preds = tmp_dic[seq_names[0]].cpu()
                    show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    show_preds = show_preds.squeeze(0)
                    if lossfunc == 'bce':
                        show_preds = (torch.sigmoid(show_preds) > 0.5)
                        show_preds_s = torch.zeros((h, w))
                        for i in range(show_preds.size(0)):
                            show_preds_s[show_preds[i]] = i + 1
                    elif lossfunc == 'cross_entropy':
                        show_preds_s = torch.argmax(show_preds, dim=0)
                    show_preds_s = show_preds_s.numpy()
                    show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))
                    pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                    tblogger.add_scalar('loss', running_loss.avg, step)
                    tblogger.add_scalar('pix_acc', pix_acc, step)
                    tblogger.add_scalar('now_lr', now_lr, step)
                    tblogger.add_image('Reference image', show_ref_img, step)
                    tblogger.add_image('Previous frame image', show_img1, step)
                    tblogger.add_image('Current frame image', show_img2, step)
                    tblogger.add_image('Ground Truth', show_gtf, step)
                    tblogger.add_image('Predict label', show_preds_sf, step)

                if step % 5000 == 0 and step != 0:
                    self.save_network(self.model, step)
                step += 1
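
# The trainer above relies on `self._adjust_lr(optimizer, step, max_itr)` to set the
# learning rate each step. The helper below is only a hedged sketch of the common
# polynomial ("poly") decay that such helpers typically implement; the name, the
# default power of 0.9, and the use of cfg.TRAIN_LR as the base rate are assumptions
# for illustration, not the repository's authoritative implementation.
def _adjust_lr_poly_sketch(self, optimizer, itr, max_itr, power=0.9):
    """Hypothetical example: decay the LR as base_lr * (1 - itr / max_itr) ** power."""
    now_lr = cfg.TRAIN_LR * (1 - float(itr) / (max_itr + 1)) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = now_lr
    return now_lr
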
def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy',
          model_resume=False, eval_total=False, init_prev=False):
    ################### optimizer setup ###################
    interactor = interactive_robot.InteractiveScribblesRobot()
    self.model.train()
    running_loss = AverageMeter()
    # Only the interaction segmentation head is optimized in this stage.
    optimizer = optim.SGD(self.model.inter_seghead.parameters(), lr=cfg.TRAIN_LR,
                          momentum=cfg.TRAIN_MOMENTUM, weight_decay=cfg.TRAIN_WEIGHT_DECAY)

    ################### data pipeline ###################
    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP), 10),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    print('dataset processing...')
    train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    train_list = train_dataset.seqs
    print('dataset processing finished.')

    ################### loss ###################
    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    elif lossfunc == 'cross_entropy':
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    round_ = 3           # number of simulated interaction rounds
    epoch_per_round = 30
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_75000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 75000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        if step > 100001:
            break
        for r in range(round_):
            if r == 0:
                # First round: fresh global map and freshly initialized reference-frame scribbles.
                print('start new')
                global_map_tmp_dic = {}
                train_dataset.transform = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                              tr.RandomScale(),
                                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                                              tr.Resize(cfg.DATA_RESCALE),
                                                              tr.ToTensor()])
                train_dataset.init_ref_frame_dic()
            trainloader = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                     sampler=RandomIdentitySampler(train_dataset.sample_list),
                                     shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
            print('round:{} start'.format(r))
            for epoch in range(epoch_per_round):
                for ii, sample in enumerate(trainloader):
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                    ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    ref_frame_nums = sample['meta']['ref_frame_num']
                    ref_frame_gts = sample['ref_frame_gt']
                    bs, _, h, w = ref_imgs.size()

                    if self.use_gpu:
                        inputs = ref_imgs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        ref_frame_gts = ref_frame_gts.cuda()
                    else:
                        inputs = ref_imgs

                    # The backbone is frozen here; only extract features.
                    with torch.no_grad():
                        self.model.feature_extracter.eval()
                        self.model.semantic_embedding.eval()
                        ref_frame_embedding = self.model.extract_feature(inputs)

                    if r == 0:
                        first_inter = True
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=None,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums, first_inter=first_inter)
                    else:
                        first_inter = False
                        prev_round_label = sample['prev_round_label']
                        prev_round_label = prev_round_label.cuda()
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=prev_round_label,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums, first_inter=first_inter)

                    # Build per-sequence ground-truth labels matching the chosen loss.
                    label_and_obj_dic = {}
                    label_dic = {}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_] = (ref_frame_gts[i], obj_nums[i])
                    for seq_ in tmp_dic.keys():
                        tmp_pred_logits = tmp_dic[seq_]
                        tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                                    mode='bilinear', align_corners=True)
                        tmp_dic[seq_] = tmp_pred_logits
                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        obj_ids = torch.from_numpy(np.arange(0, obj_num + 1)).int()
                        if torch.cuda.is_available():
                            obj_ids = obj_ids.cuda()
                        if lossfunc == 'bce':
                            # One binary channel per object id (background included).
                            label_tmp = label_tmp.permute(1, 2, 0)
                            label = (label_tmp.float() == obj_ids.float())
                            label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                            label_dic[seq_] = label.float()
                        elif lossfunc == 'cross_entropy':
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step)
                    loss = loss / bs

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    running_loss.update(loss.item(), bs)

                    # Logging / TensorBoard visualization.
                    if step % 50 == 0:
                        torch.cuda.empty_cache()
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                              running_loss.val, running_loss.avg))
                        show_ref_img = ref_imgs.cpu().numpy()[0]
                        # Undo ImageNet normalization for display.
                        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                        show_ref_img = show_ref_img * sigma + mean
                        show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2, 0, 1))
                        show_scrbble = ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                        show_scrbble = label2colormap(show_scrbble).transpose((2, 0, 1))
                        if r != 0:
                            show_prev_round_label = prev_round_label.cpu()[0].squeeze(0).numpy()
                        else:
                            show_prev_round_label = np.zeros_like(show_gt)
                        show_prev_round_label = label2colormap(show_prev_round_label).transpose((2, 0, 1))

                        show_preds = tmp_dic[seq_names[0]].cpu()
                        show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                               mode='bilinear', align_corners=True)
                        show_preds = show_preds.squeeze(0)
                        if lossfunc == 'bce':
                            show_preds = show_preds[1:]
                            show_preds = (torch.sigmoid(show_preds) > 0.5)
                            marker = torch.argmax(show_preds, dim=0)
                            show_preds_s = torch.zeros((h, w))
                            for i in range(show_preds.size(0)):
                                tmp_mask = (marker == i) & (show_preds[i] > 0.5)
                                show_preds_s[tmp_mask] = i + 1
                        elif lossfunc == 'cross_entropy':
                            show_preds_s = torch.argmax(show_preds, dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))
                        pix_acc = np.sum(show_preds_s == show_gt) / (h * w)
                        if cfg.TRAIN_TBLOG:
                            tblogger.add_scalar('loss', running_loss.avg, step)
                            tblogger.add_scalar('pix_acc', pix_acc, step)
                            tblogger.add_scalar('now_lr', now_lr, step)
                            tblogger.add_image('Reference image', show_ref_img, step)
                            tblogger.add_image('Ground Truth', show_gtf, step)
                            tblogger.add_image('Predict label', show_preds_sf, step)
                            tblogger.add_image('Scribble', show_scrbble, step)
                            tblogger.add_image('prev_round_label', show_prev_round_label, step)

                    if step % 20000 == 0 and step != 0:
                        self.save_network(self.model, step)
                    step += 1

            print('trainset evaluating...')
            print('*' * 100)
            if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                # Run the propagation branch on the train set and let the scribble robot
                # generate next-round scribbles from the *predicted* masks.
                if r != round_ - 1:
                    if r == 0:
                        prev_round_label_dic = {}
                    self.model.eval()
                    with torch.no_grad():
                        round_scribble = {}
                        frame_num_dic = {}
                        train_dataset.transform = transforms.Compose([tr.Resize(cfg.DATA_RESCALE), tr.ToTensor()])
                        trainloader = DataLoader(train_dataset, batch_size=1,
                                                 sampler=RandomIdentitySampler(train_dataset.sample_list),
                                                 shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                            if r == 0:
                                ref_scribble_labels = self.rough_ROI(ref_scribble_labels)
                            print(seq_names[0])

                            # Lightly perturb the previous-frame masks before propagation.
                            label1s_tocat = torch.Tensor()
                            for i in range(bs):
                                l = label1s[i].unsqueeze(0)
                                l = mask_damager(l, 0.0)
                                l = torch.from_numpy(l).unsqueeze(0).unsqueeze(0)
                                label1s_tocat = torch.cat((label1s_tocat, l.float()), 0)
                            label1s = label1s_tocat

                            if self.use_gpu:
                                inputs = inputs.cuda()
                                ref_scribble_labels = ref_scribble_labels.cuda()
                                label1s = label1s.cuda()
                            tmp_dic, global_map_tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                                                     seq_names=seq_names, gt_ids=obj_nums,
                                                                     k_nearest_neighbors=cfg.KNNS,
                                                                     global_map_tmp_dic=global_map_tmp_dic,
                                                                     frame_num=frame_nums)
                            pred_label = tmp_dic[seq_names[0]].detach().cpu()
                            pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                                   mode='bilinear', align_corners=True)
                            pred_label = torch.argmax(pred_label, dim=1)
                            pred_label = pred_label.unsqueeze(0)
                            try:
                                pred_label = damage_masks(pred_label)
                            except Exception:
                                pass
                            pred_label = pred_label.squeeze(0)
                            # The scribble robot compares the prediction against ground truth to draw new scribbles.
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0], pred_label.numpy(),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            # Resize the predicted mask back to the original frame resolution.
                            pred_label = pred_label.unsqueeze(0)
                            img_ww = Image.open(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_names[0], '00000.jpg'))
                            img_ww = np.array(img_ww)
                            or_h, or_w = img_ww.shape[:2]
                            pred_label = torch.nn.functional.interpolate(pred_label.float(), (or_h, or_w), mode='nearest')
                            prev_round_label_dic[seq_names[0]] = pred_label.squeeze(0)
                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    self.model.train()
                    print('updating ref frame and label')
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')
            else:
                # Cheaper variant: simulate the next round from a damaged ground-truth mask
                # instead of running the propagation branch.
                if r != round_ - 1:
                    round_scribble = {}
                    if r == 0:
                        prev_round_label_dic = {}
                    frame_num_dic = {}
                    train_dataset.transform = tr.ToTensor()
                    trainloader = DataLoader(train_dataset, batch_size=1,
                                             sampler=RandomIdentitySampler(train_dataset.sample_list),
                                             shuffle=False, num_workers=1, pin_memory=True)
                    self.model.eval()
                    with torch.no_grad():
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            print(seq_names[0])
                            label2s_ = mask_damager(label2s, 0.1)
                            # (debug) the original and damaged label2s can be saved as palette PNGs here for inspection.
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0],
                                                                               np.expand_dims(label2s_, axis=0),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            label2s__ = torch.from_numpy(label2s_)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            prev_round_label_dic[seq_names[0]] = label2s__
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    print('updating ref frame and label')
                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    self.model.train()
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')
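
# `AverageMeter` (used above as `running_loss`) is expected to come from the repository's
# utility code. The class below is only a minimal sketch matching how it is used here
# (`update(value, n)`, `.val`, `.avg`); it is a hedged reference implementation, not
# necessarily the repository's own definition.
class AverageMeterSketch(object):
    """Tracks the latest value and the running (weighted) average of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value, e.g. the current step's loss
        self.sum = 0.0   # weighted sum of all values seen so far
        self.count = 0   # total weight (e.g. number of samples)
        self.avg = 0.0   # running average = sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count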