Exemplo n.º 1
0
 def transform_val(self, sample):
     """Run the validation preprocessing pipeline on ``sample``.

     The pipeline resizes to ``self.args.test_size``, normalizes with the
     mean/std (and BGR/gray flags) configured on ``self.args``, and converts
     the result with ``tr.ToTensor``.

     Returns whatever the composed transform returns for ``sample``.
     """
     resize_step = tr.Resize(self.args.test_size, shrink=self.args.shrink)
     normalize_step = tr.Normalize(
         mean=self.args.normal_mean,
         std=self.args.normal_std,
         bgr_mode=self.args.bgr_mode,
         gray_mode=self.args.gray_mode,
     )
     pipeline = transforms.Compose([resize_step, normalize_step, tr.ToTensor()])
     return pipeline(sample)
 def transform_validation(self):
     """Build (but do not apply) the validation transform.

     The returned ``transforms.Compose`` resizes to ``self.args.input_size``
     and then converts with ``tr.ToTensor``. No random cropping is applied.
     """
     return transforms.Compose([
         tr.Resize(self.args.input_size),
         tr.ToTensor(),
     ])
Exemplo n.º 3
0
    def transform_val(self, sample):
        """Run the validation preprocessing pipeline on ``sample``.

        Steps, in order: padded resize to ``self.args.test_size``,
        normalization with ``self.args.normal_mean`` / ``normal_std``, and
        conversion with ``tr.ToTensor``.

        Returns whatever the composed transform returns for ``sample``.
        """
        resize_step = tr.Resize(
            self.args.test_size,
            shrink=self.args.shrink,
            pad=True,
        )
        normalize_step = tr.Normalize(
            mean=self.args.normal_mean,
            std=self.args.normal_std,
        )
        pipeline = transforms.Compose([resize_step, normalize_step, tr.ToTensor()])
        return pipeline(sample)
 def transform_train(self):
     """Build (but do not apply) the training augmentation pipeline.

     Order of steps: resize to ``self.args.input_size``, random horizontal
     flip, random rotation (argument 15), random crop to
     ``self.args.input_size``, then ``tr.ToTensor``.
     """
     return transforms.Compose([
         tr.Resize(self.args.input_size),
         tr.RandomHorizontalFlip(),
         tr.RandomRotate(15),
         tr.RandomCrop(self.args.input_size),
         tr.ToTensor(),
     ])
Exemplo n.º 5
0
    def train(self,damage_initial_previous_frame_mask=True,lossfunc='cross_entropy',model_resume=False,resume_step=60000):
        """Run the SGD training loop over the YouTube-VOS and DAVIS 2017 loaders.

        Each outer pass iterates the YouTube-VOS loader and then the DAVIS
        loader. Every batch performs one optimizer step and logs scalars and
        images to ``tblogger``; every 5000 steps the model is stored through
        ``self.save_network``.

        Args:
            damage_initial_previous_frame_mask: When true, ``label1`` is passed
                through ``damage_masks`` before being fed to the model. If
                that helper fails, the error is reported and the unmodified
                labels are used.
            lossfunc: ``'cross_entropy'`` or ``'bce'``; selects the criterion
                and how the target labels are encoded.
            model_resume: When true, load ``save_step_<resume_step>.pth`` from
                ``self.save_res_dir`` and continue counting from that step.
            resume_step: Step number of the checkpoint to resume from. The
                default (60000) is the value that used to be hard-coded.

        Raises:
            ValueError: If ``lossfunc`` is not a supported name.
        """
        # Fail fast. Previously an unknown name was only printed and the loop
        # crashed later with a NameError on the never-assigned criterion.
        if lossfunc not in ('bce', 'cross_entropy'):
            raise ValueError(
                'unsupported loss function {!r}. Please choose from [cross_entropy,bce]'.format(lossfunc))

        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM,weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        # Separate augmentation pipelines for the two datasets.
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale(),
            tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        composed_transforms_ytb = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale([0.5,1,1.25]),
            tr.RandomCrop((800,800)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT,transform=composed_transforms_ytb)
        trainloader_davis = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        trainloader_ytb = DataLoader(ytb_train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        # YouTube-VOS is iterated first, then DAVIS, on every outer pass.
        trainloader=[trainloader_ytb,trainloader_davis]
        print('dataset processing finished.')

        if lossfunc=='bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        else:
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)

        max_itr = cfg.TRAIN_TOTAL_STEPS

        step=0

        if model_resume:
            saved_model_=os.path.join(self.save_res_dir,'save_step_{}.pth'.format(resume_step))

            saved_model_ = torch.load(saved_model_)
            self.model=self.load_network(self.model,saved_model_)
            step=resume_step
            print('resume from step {}'.format(step))

        # NOTE(review): the step limit is only checked between full passes over
        # both loaders, so training can run past TRAIN_TOTAL_STEPS until the
        # current pass ends. Left unchanged here; confirm whether intended.
        while step<cfg.TRAIN_TOTAL_STEPS:

            for train_dataloader in trainloader:
                # sample['meta']={'seq_name':seqname,'frame_num':frame_num,'obj_num':obj_num}
                for ii, sample in enumerate(train_dataloader):
                    now_lr=self._adjust_lr(optimizer,step,max_itr)
                    ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                    img1s = sample['img1']
                    img2s = sample['img2']
                    ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                    label1s = sample['label1']
                    label2s = sample['label2']
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    bs,_,h,w = img2s.size()
                    # Reference, previous and current frames are stacked along the batch axis.
                    inputs = torch.cat((ref_imgs,img1s,img2s),0)
                    if damage_initial_previous_frame_mask:
                        try:
                            label1s = damage_masks(label1s)
                        except Exception as err:
                            # Best effort: keep the unmodified labels. Catching
                            # Exception (not a bare except) lets Ctrl-C and
                            # SystemExit propagate.
                            print('damage_error: {}'.format(err))

                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels=ref_scribble_labels.cuda()
                        label1s = label1s.cuda()
                        label2s = label2s.cuda()

                    tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS)

                    label_and_obj_dic={}
                    label_dic={}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_]=(label2s[i],obj_nums[i])
                    for seq_ in tmp_dic.keys():
                        # Bring the logits to the input resolution before computing the loss.
                        tmp_pred_logits = tmp_dic[seq_]
                        tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits,size=(h,w),mode = 'bilinear',align_corners=True)
                        tmp_dic[seq_]=tmp_pred_logits

                        label_tmp,obj_num = label_and_obj_dic[seq_]
                        obj_ids = np.arange(1,obj_num+1)
                        obj_ids = torch.from_numpy(obj_ids)
                        obj_ids = obj_ids.int()
                        # Follow the label's device instead of probing CUDA
                        # availability, so a CPU run on a CUDA-capable machine
                        # does not mix devices in the comparison below.
                        obj_ids = obj_ids.to(label_tmp.device)
                        if lossfunc == 'bce':
                            # One binary mask per object id, compared by broadcasting.
                            label_tmp = label_tmp.permute(1,2,0)
                            label = (label_tmp.float()==obj_ids.float())
                            label = label.unsqueeze(-1).permute(3,2,0,1)
                            label_dic[seq_]=label.float()
                        elif lossfunc =='cross_entropy':
                            label_dic[seq_]=label_tmp.long()

                    loss = criterion(tmp_dic,label_dic,step)
                    loss =loss/bs

                    # Divergence guard: dump logits and labels to .npy files
                    # for inspection and stop the process.
                    if loss.item()>10000:
                        print(tmp_dic)
                        for k,v in tmp_dic.items():
                            v = v.cpu()
                            v = v.detach().numpy()
                            np.save(k+'.npy',v)
                            l=label_dic[k]
                            l=l.cpu().detach().numpy()
                            np.save('lab'+k+'.npy',l)
                        exit()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss.update(loss.item(),bs)
                    # Logging interval of 1: this branch runs on every step.
                    if step%1==0:
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step,now_lr ,running_loss.val,running_loss.avg))

                        show_ref_img = ref_imgs.cpu().numpy()[0]
                        show_img1 = img1s.cpu().numpy()[0]
                        show_img2 = img2s.cpu().numpy()[0]

                        # Per-channel constants used to de-normalize the images
                        # for display (presumably the inverse of the dataset
                        # normalization - TODO confirm).
                        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

                        show_ref_img = show_ref_img*sigma+mean
                        show_img1 = show_img1*sigma+mean
                        show_img2 = show_img2*sigma+mean

                        show_gt = label2s.cpu()[0]

                        show_gt = show_gt.squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2,0,1))

                        # Only the first sequence of the batch is visualized.
                        show_preds = tmp_dic[seq_names[0]].cpu()
                        show_preds=nn.functional.interpolate(show_preds,size=(h,w),mode = 'bilinear',align_corners=True)
                        show_preds = show_preds.squeeze(0)
                        if lossfunc=='bce':
                            show_preds = (torch.sigmoid(show_preds)>0.5)
                            show_preds_s = torch.zeros((h,w))
                            # Later channels overwrite earlier ones where masks overlap.
                            for i in range(show_preds.size(0)):
                                show_preds_s[show_preds[i]]=i+1
                        elif lossfunc=='cross_entropy':
                            show_preds_s = torch.argmax(show_preds,dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2,0,1))

                        pix_acc = np.sum(show_preds_s==show_gt)/(h*w)

                        tblogger.add_scalar('loss', running_loss.avg, step)
                        tblogger.add_scalar('pix_acc', pix_acc, step)
                        tblogger.add_scalar('now_lr', now_lr, step)
                        tblogger.add_image('Reference image', show_ref_img, step)
                        tblogger.add_image('Previous frame image', show_img1, step)
                        tblogger.add_image('Current frame image', show_img2, step)
                        tblogger.add_image('Groud Truth', show_gtf, step)
                        tblogger.add_image('Predict label', show_preds_sf, step)

                    # Periodic checkpoint (skipped at step 0).
                    if step%5000==0 and step!=0:
                        self.save_network(self.model,step)

                    step+=1
Exemplo n.º 6
0
    def get_img_size(self):
        """Return the first two shape entries of the first listed image.

        The image at ``self.img_list[0]`` (relative to ``self.db_root_dir``)
        is loaded with ``cv2.imread`` and ``list(img.shape[:2])`` is returned.

        Raises:
            IOError: If OpenCV could not read the image.
        """
        img_path = os.path.join(self.db_root_dir, self.img_list[0])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # surface the offending path rather than an AttributeError.
            raise IOError('Could not read image: {}'.format(img_path))

        return list(img.shape[:2])


if __name__ == '__main__':
    # Small manual smoke test: iterate the dataset and print image shapes.
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    # Bound to its own name so the imported ``transforms`` module is not
    # shadowed.
    # NOTE(review): this pipeline is built but the dataset below only receives
    # tr.ToTensor() - confirm which transform is intended.
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.Resize(scales=[0.5, 0.8, 1]),
        tr.ToTensor()
    ])

    dataset = OnlineDataset(train=True, transform=tr.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=1)

    for i, data in enumerate(dataloader):
        img = data['image']
        print(img.shape)
        # A stray undefined name (``bb``) used to sit here and raised a
        # NameError on the first iteration; it has been removed.
        plt.figure()
Exemplo n.º 7
0
    print('Loading from snapshot: {}'.format(args.weights))
    net.load_state_dict(
        torch.load(args.weights, map_location=lambda storage, loc: storage))
    net.to(device)
    net.eval()

    # Setup data transformations
    composed_transforms = [
        tr.IdentityTransform(tr_elems=['gt'], prefix='ori_'),
        tr.CropFromMask(crop_elems=['image', 'gt'],
                        relax=cfg['relax_crop'],
                        zero_pad=cfg['zero_pad_crop'],
                        adaptive_relax=cfg['adaptive_relax'],
                        prefix=''),
        tr.Resize(resize_elems=['image', 'gt', 'void_pixels'],
                  min_size=cfg['min_size'],
                  max_size=cfg['max_size']),
        tr.ComputeImageGradient(elem='image'),
        tr.ExtremePoints(sigma=10, pert=0, elem='gt'),
        tr.GaussianTransform(tr_elems=['extreme_points'],
                             mask_elem='gt',
                             sigma=10,
                             tr_name='points'),
        tr.FixedResizePoints(
            resolutions={'extreme_points': (cfg['lr_size'], cfg['lr_size'])},
            mask_elem='gt',
            prefix='lr_'),
        tr.FixedResize(resolutions={
            'image': (cfg['lr_size'], cfg['lr_size']),
            'gt': (cfg['lr_size'], cfg['lr_size']),
            'void_pixels': (cfg['lr_size'], cfg['lr_size'])
Exemplo n.º 8
0
        return data, gt

    def get_img_size(self):
        """Return the first two shape entries of the first listed image.

        The image at ``self.img_list[0]`` (relative to ``self.db_root_dir``)
        is loaded with ``cv2.imread`` and ``list(img.shape[:2])`` is returned.

        Raises:
            IOError: If OpenCV could not read the image.
        """
        img_path = os.path.join(self.db_root_dir, self.img_list[0])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # surface the offending path rather than an AttributeError.
            raise IOError('Could not read image: {}'.format(img_path))

        return list(img.shape[:2])


if __name__ == '__main__':
    # Small manual smoke test: show the first samples with their mask overlaid.
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    # Bound to its own name so the imported ``transforms`` module is not
    # shadowed.
    # NOTE(review): this pipeline is built but the dataset below only receives
    # tr.ToTensor() - confirm which transform is intended.
    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    dataset = OnlineDataset(train=True, transform=tr.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        img = data['image']
        print(img.shape)
        # A stray undefined name (``bb``) used to sit here and raised a
        # NameError on the first iteration, so nothing was ever plotted.
        plt.figure()
        plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
        # Stop after the sample with index 10.
        if i == 10:
            break

    plt.show(block=True)
Exemplo n.º 9
0
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    p['optimizer'] = str(optimizer)

    # Setup data transformations
    composed_transforms = [
        tr.RandomHorizontalFlip(),
        tr.CropFromMask(crop_elems=['image', 'gt', 'thin', 'void_pixels'],
                        relax=args.relax_crop,
                        zero_pad=args.zero_pad_crop,
                        adaptive_relax=args.adaptive_relax,
                        prefix=''),
        tr.Resize(resize_elems=['image', 'gt', 'thin', 'void_pixels'],
                  min_size=args.min_size,
                  max_size=args.max_size),
        tr.ComputeImageGradient(elem='image'),
        tr.ExtremePoints(sigma=10, pert=5, elem='gt'),
        tr.GaussianTransform(tr_elems=['extreme_points'],
                             mask_elem='gt',
                             sigma=10,
                             tr_name='points'),
        tr.RandomCrop(
            num_thin=args.num_thin_samples,
            num_non_thin=args.num_non_thin_samples,
            crop_size=args.roi_size,
            prefix='crop_',
            thin_elem='thin',
            crop_elems=['image', 'gt', 'points', 'void_pixels', 'image_grad']),
        tr.MatchROIs(crop_elem='gt', resolution=args.lr_size),