Exemplo n.º 1
0
    def transform_tr(self, sample):
        """Apply the training-time augmentation pipeline to ``sample``.

        The transforms run in this order: ``scaleNorm``,
        ``RandomScale((1.0, 1.4))``, ``RandomHSV``, ``RandomCrop`` to
        ``(image_h, image_w)``, ``RandomFlip``, ``ToTensor`` and ``Normalize``.
        ``image_h`` / ``image_w`` are read from module scope.

        Returns whatever the composed pipeline returns for ``sample``.
        """
        augmentations = [
            tr.scaleNorm(),
            tr.RandomScale((1.0, 1.4)),
            tr.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
            tr.RandomCrop(image_h, image_w),
            tr.RandomFlip(),
            tr.ToTensor(),
            tr.Normalize(),
        ]
        pipeline = transforms.Compose(augmentations)
        return pipeline(sample)
Exemplo n.º 2
0
 def transform_tr(self, sample):
     """Apply the training-time augmentation pipeline to ``sample``.

     Pipeline order: colour jitter, horizontal flip, random scale, a square
     random crop of side ``self.args.crop_size``, ImageNet mean/std
     normalisation, then tensor conversion.
     """
     crop_side = self.args.crop_size
     steps = [
         tr.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
         # Gaussian blur (tr.RandomGaussianBlur) is disabled.
         tr.HorizontalFlip(),
         tr.RandomScale(),
         tr.RandomCrop(size=(crop_side, crop_side)),
         # Normalize runs before ToTensor, so it presumably operates on the
         # pre-tensor (PIL/numpy) sample -- TODO confirm in custom transforms.
         tr.Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ]
     return transforms.Compose(steps)(sample)
Exemplo n.º 3
0
    def transform_tr(self, sample):
        """Build and apply the training augmentation pipeline for ``sample``.

        Optional steps are driven by ``self.args``:
        * ``rotate > 0`` prepends a ``RandomRotate(rotate)``;
        * ``noise_param`` (when not ``None``) inserts ``GaussianNoise`` with
          ``mean=noise_param[0]`` and ``std=noise_param[1]`` after
          normalisation.
        Scale, crop, horizontal flip, normalisation and ``ToTensor`` always run.
        """
        args = self.args
        pipeline = [tr.RandomRotate(args.rotate)] if args.rotate > 0 else []
        pipeline += [
            tr.RandomScale(rand_resize=args.rand_resize),
            tr.RandomCrop(args.input_size),
            tr.RandomHorizontalFlip(),
            tr.Normalize(mean=args.normal_mean, std=args.normal_std),
        ]
        if args.noise_param is not None:
            noise_mean, noise_std = args.noise_param[0], args.noise_param[1]
            pipeline.append(tr.GaussianNoise(mean=noise_mean, std=noise_std))
        pipeline.append(tr.ToTensor())

        return transforms.Compose(pipeline)(sample)
Exemplo n.º 4
0
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()
    ])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    dataloader = DataLoader(cityscapes_train,
                            batch_size=2,
                            shuffle=True,
                            num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
Exemplo n.º 5
0
    def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy',
              model_resume=False, resume_step=60000):
        """Train ``self.model`` alternately on YouTube-VOS and DAVIS-2017.

        Each outer pass iterates the YouTube-VOS loader, then the DAVIS loader.
        Training stops as soon as ``cfg.TRAIN_TOTAL_STEPS`` optimizer steps have
        been taken. A checkpoint is saved every 5000 steps and once more at the
        end of training.

        Args:
            damage_initial_previous_frame_mask: if True, perturb the
                previous-frame labels (``label1``) with ``damage_masks`` before
                feeding them to the model. Failures are logged and the
                undamaged labels are used instead.
            lossfunc: ``'cross_entropy'`` or ``'bce'``. Selects the criterion and
                the label encoding passed to it.
            model_resume: if True, load ``save_step_<resume_step>.pth`` from
                ``self.save_res_dir`` and continue counting from ``resume_step``.
            resume_step: checkpoint step to resume from. The default of 60000
                is the value that used to be hard-coded.

        Raises:
            ValueError: if ``lossfunc`` is not supported.
            RuntimeError: if the loss is non-finite or exceeds 10000. The
                offending logits and labels are first dumped to ``.npy`` files
                in the working directory.
        """
        # Fail fast, before building datasets; previously an unknown value only
        # printed a message and later crashed with NameError on `criterion`.
        if lossfunc not in ('bce', 'cross_entropy'):
            raise ValueError(
                'unsupported loss function {!r}. Please choose from '
                '[cross_entropy,bce]'.format(lossfunc))

        self.model.train()
        running_loss = AverageMeter()
        # NOTE(review): the optimizer is built before a possible load_network()
        # below; this assumes load_network updates the same module in place.
        optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR,
                              momentum=cfg.TRAIN_MOMENTUM,
                              weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale(),
            tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        composed_transforms_ytb = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale([0.5, 1, 1.25]),
            tr.RandomCrop((800, 800)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT, transform=composed_transforms_ytb)
        trainloader_davis = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                       shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        trainloader_ytb = DataLoader(ytb_train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                     shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        # Each pass visits YouTube-VOS first, then DAVIS.
        trainloaders = [trainloader_ytb, trainloader_davis]
        print('dataset processing finished.')

        if lossfunc == 'bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
        else:
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)

        max_itr = cfg.TRAIN_TOTAL_STEPS
        step = 0
        if model_resume:
            checkpoint_path = os.path.join(self.save_res_dir, 'save_step_{}.pth'.format(resume_step))
            checkpoint = torch.load(checkpoint_path)
            self.model = self.load_network(self.model, checkpoint)
            step = resume_step
            print('resume from step {}'.format(step))
        start_step = step

        # Per-channel statistics used to undo input normalisation for display.
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        while step < max_itr:
            for train_dataloader in trainloaders:
                for sample in train_dataloader:
                    # Stop exactly at max_itr instead of finishing the pass;
                    # the lr schedule presumably assumes step <= max_itr.
                    if step >= max_itr:
                        break
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']  # presumably batch_size * 3 * h * w
                    img1s = sample['img1']
                    img2s = sample['img2']
                    ref_scribble_labels = sample['ref_scribble_label']  # presumably batch_size * 1 * h * w
                    label1s = sample['label1']
                    label2s = sample['label2']
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    bs, _, h, w = img2s.size()
                    # Reference, previous and current frames are stacked along dim 0.
                    inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                    if damage_initial_previous_frame_mask:
                        try:
                            label1s = damage_masks(label1s)
                        except Exception as err:
                            # Best effort: keep the undamaged labels.
                            print('damage_error: {}'.format(err))

                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        label1s = label1s.cuda()
                        label2s = label2s.cuda()

                    tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                         seq_names=seq_names, gt_ids=obj_nums,
                                         k_nearest_neighbors=cfg.KNNS)

                    # Map each sequence name to its current-frame label and object count.
                    label_and_obj_dic = {seq_: (label2s[i], obj_nums[i])
                                         for i, seq_ in enumerate(seq_names)}
                    label_dic = {}
                    for seq_, pred_logits in list(tmp_dic.items()):
                        # Upsample logits to the input frame size before the loss.
                        tmp_dic[seq_] = nn.functional.interpolate(
                            pred_logits, size=(h, w), mode='bilinear', align_corners=True)
                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        if lossfunc == 'bce':
                            obj_ids = torch.from_numpy(np.arange(1, obj_num + 1)).int()
                            # Follow the label's device rather than CUDA availability,
                            # so CPU training on a CUDA machine does not mix devices.
                            obj_ids = obj_ids.to(label_tmp.device)
                            label = (label_tmp.permute(1, 2, 0).float() == obj_ids.float())
                            label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                            label_dic[seq_] = label.float()
                        else:
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step)
                    loss = loss / bs
                    loss_value = loss.item()

                    # Divergence guard: dump the offending tensors and abort with
                    # a non-zero exit (the old exit() reported success, and NaN
                    # losses slipped through).
                    if not np.isfinite(loss_value) or loss_value > 10000:
                        print(tmp_dic)
                        for k, v in tmp_dic.items():
                            np.save(k + '.npy', v.detach().cpu().numpy())
                            np.save('lab' + k + '.npy', label_dic[k].detach().cpu().numpy())
                        raise RuntimeError(
                            'loss diverged at step {}: {}'.format(step, loss_value))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss.update(loss_value, bs)
                    print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(
                        step, now_lr, running_loss.val, running_loss.avg))

                    # ---- TensorBoard logging for the first sample of the batch ----
                    show_ref_img = ref_imgs.cpu().numpy()[0] * sigma + mean
                    show_img1 = img1s.cpu().numpy()[0] * sigma + mean
                    show_img2 = img2s.cpu().numpy()[0] * sigma + mean

                    show_gt = label2s.cpu()[0].squeeze(0).numpy()
                    show_gtf = label2colormap(show_gt).transpose((2, 0, 1))

                    # Logits in tmp_dic are already at (h, w); no re-interpolation needed.
                    show_preds = tmp_dic[seq_names[0]].detach().cpu().squeeze(0)
                    if lossfunc == 'bce':
                        foreground = torch.sigmoid(show_preds) > 0.5
                        show_preds_s = torch.zeros((h, w))
                        for i in range(foreground.size(0)):
                            show_preds_s[foreground[i]] = i + 1
                    else:
                        show_preds_s = torch.argmax(show_preds, dim=0)
                    show_preds_s = show_preds_s.numpy()
                    show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                    pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                    tblogger.add_scalar('loss', running_loss.avg, step)
                    tblogger.add_scalar('pix_acc', pix_acc, step)
                    tblogger.add_scalar('now_lr', now_lr, step)
                    tblogger.add_image('Reference image', show_ref_img, step)
                    tblogger.add_image('Previous frame image', show_img1, step)
                    tblogger.add_image('Current frame image', show_img2, step)
                    tblogger.add_image('Groud Truth', show_gtf, step)
                    tblogger.add_image('Predict label', show_preds_sf, step)

                    if step % 5000 == 0 and step != 0:
                        self.save_network(self.model, step)

                    step += 1
                if step >= max_itr:
                    break

        # Training now stops exactly at max_itr, so persist the final weights
        # (the periodic save above never fires for the last step count).
        if step > start_step:
            self.save_network(self.model, step)