Esempio n. 1
0
 def transform_tr(self, sample):
     """Apply the training-time augmentation pipeline to ``sample``.

     Steps, in order: random crop (sized from ``self.par.base_size`` and
     ``self.par.crop_size``, ``fill=255``), random color jitter, random
     horizontal flip, normalization with the ImageNet mean/std, and
     conversion to tensors.
     """
     pipeline = [
         tr.RandomCrop(self.par.base_size, self.par.crop_size, fill=255),
         tr.RandomColorJitter(),
         tr.RandomHorizontalFlip(),
         tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ]
     return transforms.Compose(pipeline)(sample)
 def transform_train(self):
     """Build (without applying) the training transform pipeline.

     Returns:
         A ``transforms.Compose`` that resizes to ``self.args.input_size``,
         randomly flips horizontally, randomly rotates (``RandomRotate(15)``),
         randomly crops to ``self.args.input_size`` and converts to tensors.
     """
     size = self.args.input_size
     return transforms.Compose([
         tr.Resize(size),
         tr.RandomHorizontalFlip(),
         tr.RandomRotate(15),
         tr.RandomCrop(size),
         tr.ToTensor(),
     ])
Esempio n. 3
0
 def transform_tr(self, sample: dict):
     """Apply the training-time transform pipeline to ``sample``.

     Random crop (size from ``settings['rnd_crop_size']``), horizontal flip
     with probability 0.5, tensor conversion, normalization of the
     ``'image'`` entry using ``settings['normalize_params']``, and a squeeze
     applied to the ``'label'`` entry.
     """
     settings = self.settings
     steps = [
         ctr.RandomCrop(size=settings['rnd_crop_size']),
         ctr.RandomHorizontalFlip(p=0.5),
         ctr.ToTensor(),
         ctr.Normalize(**settings['normalize_params'], apply_to=['image']),
         ctr.Squeeze(apply_to=['label']),
     ]
     return transforms.Compose(steps)(sample)
    def transform_val(self, sample):
        """Apply the validation-time transform pipeline to ``sample``.

        Pipeline: crop to (512, 512), normalize with the ImageNet mean/std,
        convert to tensors. Random horizontal flipping was removed.
        Augmentation does not belong in evaluation, and it only made the
        validation metrics vary from run to run.
        """
        composed_transforms = transforms.Compose([
            # NOTE(review): this crop is still random, so validation is not
            # fully deterministic; replace it with a deterministic (e.g. center)
            # crop if the `tr` module provides one.
            tr.RandomCrop(crop_size=(512, 512)),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])

        return composed_transforms(sample)
Esempio n. 5
0
    def transform_tr(self, sample):
        """Apply the training pipeline to ``sample``.

        Fixed resize to ``self.args.resize``, random crop to
        ``self.args.crop_size``, normalization with the ImageNet mean/std,
        then conversion to tensors.
        """
        args = self.args
        pipeline = transforms.Compose([
            tr.FixedResize(resize=args.resize),
            tr.RandomCrop(crop_size=args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
Esempio n. 6
0
    def transform_tr(self, sample):
        """Apply the training-time augmentation pipeline to ``sample``.

        ``scaleNorm``, random scaling in (1.0, 1.4), random HSV jitter,
        random crop to ``image_h`` x ``image_w``, random flip, tensor
        conversion, and finally ``Normalize`` (on tensors, after ``ToTensor``).
        ``image_h``/``image_w`` are read from the enclosing module scope.
        """
        steps = (
            tr.scaleNorm(),
            tr.RandomScale((1.0, 1.4)),
            tr.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
            tr.RandomCrop(image_h, image_w),
            tr.RandomFlip(),
            tr.ToTensor(),
            tr.Normalize(),
        )
        return transforms.Compose(list(steps))(sample)
Esempio n. 7
0
 def transform_tr(self, sample):
     """Apply the training-time augmentation pipeline to ``sample``.

     Color jitter (brightness/contrast/saturation 0.5), horizontal flip,
     random scale, square random crop of side ``self.args.crop_size``,
     normalization with the ImageNet mean/std, then tensor conversion.
     """
     side = self.args.crop_size
     augment = transforms.Compose([
         tr.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
         tr.HorizontalFlip(),
         tr.RandomScale(),
         tr.RandomCrop(size=(side, side)),
         tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ])
     return augment(sample)
Esempio n. 8
0
    def transform_tr(self, sample):
        """Apply the configurable training pipeline to ``sample``.

        Optional random rotation (only when ``args.rotate > 0``), random scale,
        random crop to ``args.input_size``, horizontal flip, normalization
        with ``args.normal_mean``/``args.normal_std``, optional Gaussian noise
        (when ``args.noise_param`` is set, read as ``(mean, std)``), then
        tensor conversion.
        """
        opts = self.args
        steps = [tr.RandomRotate(opts.rotate)] if opts.rotate > 0 else []
        steps += [
            tr.RandomScale(rand_resize=opts.rand_resize),
            tr.RandomCrop(opts.input_size),
            tr.RandomHorizontalFlip(),
            tr.Normalize(mean=opts.normal_mean, std=opts.normal_std),
        ]
        noise = opts.noise_param
        if noise is not None:
            steps.append(tr.GaussianNoise(mean=noise[0], std=noise[1]))
        steps.append(tr.ToTensor())
        return transforms.Compose(steps)(sample)
Esempio n. 9
0
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask


if __name__ == '__main__':
    # Manual smoke test: build the Cityscapes training split with a set of
    # augmentations and iterate over it through a DataLoader.
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    # Random flip, random scale with (0.5, 0.75), random crop with size
    # (512, 1024), random rotation with parameter 5 (presumably degrees —
    # confirm against tr.RandomRotate), then tensor conversion.
    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()
    ])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    # Batch size 2, shuffled, 2 worker processes.
    dataloader = DataLoader(cityscapes_train,
                            batch_size=2,
                            shuffle=True,
                            num_workers=2)

    # NOTE(review): decode_segmap and plt are imported but not used in the
    # lines visible here; the visualization part of this loop appears to be
    # truncated.
    for ii, sample in enumerate(dataloader):
        # One iteration per sample in the batch (dim 0 of sample["image"]).
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
Esempio n. 10
0
    def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False):
        """Train ``self.model`` jointly on YouTube-VOS and DAVIS 2017.

        Each pass of the outer loop runs one epoch of the YouTube-VOS loader,
        then one epoch of the DAVIS loader. Training stops as soon as
        ``cfg.TRAIN_TOTAL_STEPS`` steps have been taken, even mid-epoch.

        Every step does the following:
        - sets the learning rate via ``self._adjust_lr``;
        - runs the reference, previous and current frames through the model;
        - computes the loss on the current-frame labels and takes an SGD step;
        - logs scalars and images to ``tblogger``.

        A checkpoint is written via ``self.save_network`` every 5000 steps
        (never at step 0).

        Args:
            damage_initial_previous_frame_mask: if True, pass the previous-frame
                labels through ``damage_masks`` before feeding them to the
                model. If that call raises, the untouched labels are used.
            lossfunc: ``'cross_entropy'`` or ``'bce'``.
            model_resume: if True, load ``save_step_60000.pth`` from
                ``self.save_res_dir`` and continue from step 60000.

        Raises:
            ValueError: if ``lossfunc`` is not ``'cross_entropy'`` or ``'bce'``.
            SystemExit: if ``loss / batch_size`` exceeds 10000. The offending
                predictions and labels are first dumped to ``.npy`` files.
        """
        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR,
                              momentum=cfg.TRAIN_MOMENTUM,
                              weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        # Fail fast on a bad loss name. Previously this only printed a warning,
        # and the loop later crashed with NameError on the undefined `criterion`.
        if lossfunc == 'bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,
                                                cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc == 'cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,
                                               cfg.TRAIN_HARD_MINING_STEP)
        else:
            raise ValueError(
                'unsupported loss function {!r}. Please choose from '
                '[cross_entropy,bce]'.format(lossfunc))

        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale(),
            tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        composed_transforms_ytb = transforms.Compose([
            tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
            tr.RandomScale([0.5, 1, 1.25]),
            tr.RandomCrop((800, 800)),
            tr.Resize(cfg.DATA_RESCALE),
            tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT, transform=composed_transforms_ytb)
        trainloader_davis = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                       shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        trainloader_ytb = DataLoader(ytb_train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                     shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        # Within each pass of the outer loop: one YouTube-VOS epoch, then one DAVIS epoch.
        trainloader = [trainloader_ytb, trainloader_davis]
        print('dataset processing finished.')

        max_itr = cfg.TRAIN_TOTAL_STEPS
        step = 0

        if model_resume:
            # NOTE(review): the optimizer above was built from the pre-resume
            # parameters. This assumes load_network loads the weights into the
            # same module object it returns — confirm.
            saved_model_ = os.path.join(self.save_res_dir, 'save_step_60000.pth')
            saved_model_ = torch.load(saved_model_)
            self.model = self.load_network(self.model, saved_model_)
            step = 60000
            print('resume from step {}'.format(step))

        # Per-channel statistics used to undo input normalization for display.
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
        log_interval = 1  # print/log to tensorboard every `log_interval` steps

        while step < max_itr:
            for train_dataloader in trainloader:
                for sample in train_dataloader:
                    # Stop exactly at the step budget instead of finishing the
                    # current epochs past it.
                    if step >= max_itr:
                        break
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                    img1s = sample['img1']
                    img2s = sample['img2']
                    ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                    label1s = sample['label1']
                    label2s = sample['label2']
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    bs, _, h, w = img2s.size()
                    # Reference, previous and current frames are stacked along
                    # dim 0 into a single model input.
                    inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                    if damage_initial_previous_frame_mask:
                        try:
                            label1s = damage_masks(label1s)
                        except Exception:
                            # Best effort: keep the undamaged previous-frame mask.
                            print('damage_error')

                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        label1s = label1s.cuda()
                        label2s = label2s.cuda()

                    tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                         seq_names=seq_names, gt_ids=obj_nums,
                                         k_nearest_neighbors=cfg.KNNS)

                    # Current-frame label and object count for each sequence.
                    label_and_obj_dic = {}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_] = (label2s[i], obj_nums[i])

                    label_dic = {}
                    for seq_ in tmp_dic:
                        # Upsample the logits to the input resolution in place.
                        tmp_dic[seq_] = nn.functional.interpolate(
                            tmp_dic[seq_], size=(h, w), mode='bilinear', align_corners=True)

                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        if lossfunc == 'bce':
                            # One binary target per object id 1..obj_num, kept on
                            # the label's device (it previously followed
                            # torch.cuda.is_available() and could mismatch).
                            obj_ids = torch.from_numpy(np.arange(1, obj_num + 1)).int()
                            obj_ids = obj_ids.to(label_tmp.device)
                            # assumes label_tmp is (1, h, w) — TODO confirm
                            label_tmp = label_tmp.permute(1, 2, 0)
                            label = (label_tmp.float() == obj_ids.float())
                            label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                            label_dic[seq_] = label.float()
                        else:
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step)
                    loss = loss / bs
                    loss_value = loss.item()

                    if loss_value > 10000:
                        # Divergence guard: dump predictions and labels for
                        # offline inspection, then stop the run.
                        print(tmp_dic)
                        for k, v in tmp_dic.items():
                            np.save(k + '.npy', v.cpu().detach().numpy())
                            np.save('lab' + k + '.npy', label_dic[k].cpu().detach().numpy())
                        raise SystemExit

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss.update(loss_value, bs)
                    if step % log_interval == 0:
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(
                            step, now_lr, running_loss.val, running_loss.avg))

                        # First sample of the batch, de-normalized for display.
                        show_ref_img = ref_imgs.cpu().numpy()[0] * sigma + mean
                        show_img1 = img1s.cpu().numpy()[0] * sigma + mean
                        show_img2 = img2s.cpu().numpy()[0] * sigma + mean

                        show_gt = label2s.cpu()[0].squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2, 0, 1))

                        # The logits in tmp_dic were already resized to (h, w)
                        # above, so no second interpolation is needed.
                        show_preds = tmp_dic[seq_names[0]].cpu().squeeze(0)
                        if lossfunc == 'bce':
                            fg_masks = torch.sigmoid(show_preds) > 0.5
                            show_preds_s = torch.zeros((h, w))
                            # Paint object i+1 wherever its mask fires; later
                            # objects overwrite earlier ones on overlap.
                            for i in range(fg_masks.size(0)):
                                show_preds_s[fg_masks[i]] = i + 1
                        else:
                            show_preds_s = torch.argmax(show_preds, dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                        pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                        tblogger.add_scalar('loss', running_loss.avg, step)
                        tblogger.add_scalar('pix_acc', pix_acc, step)
                        tblogger.add_scalar('now_lr', now_lr, step)
                        tblogger.add_image('Reference image', show_ref_img, step)
                        tblogger.add_image('Previous frame image', show_img1, step)
                        tblogger.add_image('Current frame image', show_img2, step)
                        tblogger.add_image('Groud Truth', show_gtf, step)
                        tblogger.add_image('Predict label', show_preds_sf, step)

                    if step % 5000 == 0 and step != 0:
                        self.save_network(self.model, step)

                    step += 1
                if step >= max_itr:
                    break
Esempio n. 11
0
def str2bool(value):
    """Parse a command-line boolean such as ``'true'``, ``'False'``, ``'1'`` or ``'no'``.

    With ``type=bool``, argparse calls ``bool('False')``, which is True for
    every non-empty string. This parser is used for boolean flags instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Parse CLI options, build data/models/optimizers for fundus segmentation, and train.

    Creates a timestamped run directory under ``here/logs/<dataset>/`` and
    dumps the parsed arguments to ``config.yaml`` there. Seeds torch, random
    and numpy with 2020. Optionally restores models, optimizers and counters
    from ``--resume``, then hands everything to ``Trainer.Trainer``.
    """
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    # Boolean flags use str2bool: with type=bool, "--flag False" would be True.
    parser.add_argument(
        '--boundary-exist', type=str2bool, default=True, help='whether or not using boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge', help='folder id contain images ROIs to train or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12, help='batch size for training the model'
    )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1, help='warmup epoch begin train GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1, help='interval epoch number to valide the model'
    )
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2, help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir',
        default='./fundus/',
        help='data root path'
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=str2bool,
        default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()
    args.model = 'MobileNetV2'

    now = datetime.now()
    args.out = osp.join(here, 'logs', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)

    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(vars(args), f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    import random
    import numpy as np

    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    data_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # 2. model
    # NOTE(review): models are moved to CUDA unconditionally, regardless of `cuda`.
    model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                        sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda()

    model_bd = BoundaryDiscriminator().cuda()
    model_mask = MaskDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(
        model_gen.parameters(),
        lr=args.lr_gen,
        betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    def load_matching(module, state_dict):
        # Copy only the checkpoint entries whose keys the module has; keys
        # missing from the checkpoint keep their current values.
        own_state = module.state_dict()
        own_state.update({k: v for k, v in state_dict.items() if k in own_state})
        module.load_state_dict(own_state)

    # breakpoint recovery
    if args.resume:
        checkpoint = torch.load(args.resume)
        load_matching(model_gen, checkpoint['model_state_dict'])
        load_matching(model_bd, checkpoint['model_bd_state_dict'])
        load_matching(model_mask, checkpoint['model_mask_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Esempio n. 12
0
def str2bool(value):
    """Parse a command-line boolean such as ``'true'``, ``'False'``, ``'1'`` or ``'no'``.

    With ``type=bool``, argparse calls ``bool('False')``, which is True for
    every non-empty string. This parser is used for boolean flags instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Parse CLI options, build source/target fundus loaders and models, and train.

    Creates a timestamped run directory under ``here/logs/<datasetT>/`` and
    dumps the parsed arguments to ``config.yaml`` there. Seeds torch with
    1337. Optionally restores models, optimizers and counters from
    ``--resume``, then hands everything to ``Trainer.Trainer``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS',
                        type=str,
                        default='refuge',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--datasetT',
                        type=str,
                        default='Drishti-GS',
                        help='refuge / Drishti-GS/ RIM-ONE_r3')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch',
                        type=int,
                        default=-1,
                        help='warmup epoch begin train GAN')

    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr-gen',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-dis',
        type=float,
        default=2.5e-5,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate',
        type=float,
        default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay',
        type=float,
        default=0.0005,
        help='weight decay',
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.99,
        help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    # NOTE(review): --pretrained-model is parsed but never used in this function.
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    # Boolean flags use str2bool: with type=bool, "--sync-bn False" would be True.
    parser.add_argument(
        '--sync-bn',
        type=str2bool,
        default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()

    args.model = 'FCN8s'

    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))

    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(vars(args), f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS,
                                   split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT,
                                     split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=2,
                                pin_memory=True)
    # NOTE(review): validation reads the target domain's 'train' split — confirm intended.
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT,
                                       split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=2,
                                   pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2,
                        backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer

    optim_gen = torch.optim.Adam(model_gen.parameters(),
                                 lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(),
                                lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(),
                                 lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    def load_matching(module, state_dict):
        # Copy only the checkpoint entries whose keys the module has; keys
        # missing from the checkpoint keep their current values.
        own_state = module.state_dict()
        own_state.update({k: v for k, v in state_dict.items() if k in own_state})
        module.load_state_dict(own_state)

    if args.resume:
        checkpoint = torch.load(args.resume)
        load_matching(model_gen, checkpoint['model_state_dict'])
        load_matching(model_dis, checkpoint['model_dis_state_dict'])
        load_matching(model_dis2, checkpoint['model_dis2_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # No adversarial optimizer exists in this script. The former
        # `optim_adv.load_state_dict(...)` line raised NameError on every resume.

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Esempio n. 13
0
                 zero_pad=args.zero_pad_crop,
                 adaptive_relax=args.adaptive_relax,
                 prefix=''),
 tr.Resize(resize_elems=['image', 'gt', 'thin', 'void_pixels'],
           min_size=args.min_size,
           max_size=args.max_size),
 tr.ComputeImageGradient(elem='image'),
 tr.ExtremePoints(sigma=10, pert=5, elem='gt'),
 tr.GaussianTransform(tr_elems=['extreme_points'],
                      mask_elem='gt',
                      sigma=10,
                      tr_name='points'),
 tr.RandomCrop(
     num_thin=args.num_thin_samples,
     num_non_thin=args.num_non_thin_samples,
     crop_size=args.roi_size,
     prefix='crop_',
     thin_elem='thin',
     crop_elems=['image', 'gt', 'points', 'void_pixels', 'image_grad']),
 tr.MatchROIs(crop_elem='gt', resolution=args.lr_size),
 tr.FixedResizePoints(
     resolutions={'extreme_points': (args.lr_size, args.lr_size)},
     mask_elem='gt',
     prefix='lr_'),
 tr.FixedResize(resolutions={
     'image': (args.lr_size, args.lr_size),
     'gt': (args.lr_size, args.lr_size),
     'void_pixels': (args.lr_size, args.lr_size)
 },
                prefix='lr_'),
 tr.GaussianTransform(tr_elems=['lr_extreme_points'],
Esempio n. 14
0
def _parse_args(argv=None):
    """Parse the command-line options for Fundus segmentation training.

    Args:
        argv: Argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # nargs='+' yields a list whenever the flag is given, so the defaults must
    # be lists as well; an int default breaks `args.datasetTest[0]` in main()
    # and hands a non-list split id to the dataset.
    parser.add_argument(
        '--datasetTrain',
        nargs='+',
        type=int,
        default=[1],
        help='train folder id contain images ROIs to train range from [1,2,3,4]'
    )
    parser.add_argument(
        '--datasetTest',
        nargs='+',
        type=int,
        default=[1],
        help='test folder id contain images ROIs to test one of [1,2,3,4]')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=120, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=80,
                        help='stop epoch')
    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument('--lr-decrease-rate',
                        type=float,
                        default=0.2,
                        help='ratio multiplied to initial lr')
    parser.add_argument(
        '--lam',
        type=float,
        default=0.9,
        help='momentum of memory update',
    )
    parser.add_argument('--data-dir',
                        default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    return parser.parse_args(argv)


def _load_matching_weights(model, checkpoint_path):
    """Load into ``model`` the checkpoint entries whose keys it already has.

    The checkpoint is expected to be a dict with a ``'model_state_dict'``
    entry. Checkpoint keys unknown to the model are dropped; model entries
    absent from the checkpoint keep their current values.
    """
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
    # load_state_dict then copies each value onto the model's own device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    pretrained_dict = checkpoint['model_state_dict']
    model_dict = model.state_dict()
    # Keep only the keys that exist in the current model, then overlay them on
    # the model's own state so load_state_dict sees a complete mapping.
    model_dict.update(
        {k: v for k, v in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)


def main(argv=None):
    """Train a 2-class DeepLab (MobileNet backbone) model on Fundus data.

    Creates a timestamped run directory under ``local_path/logs`` and writes
    the parsed options to ``config.yaml`` there. It then builds the train and
    test loaders and the model. When ``--resume`` is given, matching weights
    are loaded and ``model.centroids`` is re-initialised via
    ``centroids_init``. Finally it runs ``Trainer.Trainer``.

    Args:
        argv: Optional argument list; ``None`` parses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    now = datetime.now()
    args.out = osp.join(local_path, 'logs', 'test' + str(args.datasetTest[0]),
                        'lam' + str(args.lam),
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before CUDA is first initialised to restrict visible GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(256),
        # tr.RandomCrop(512),
        # tr.RandomRotate(),
        # tr.RandomFlip(),
        # tr.elastic_transform(),
        # tr.add_salt_pepper_noise(),
        # tr.adjust_light(),
        # tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(256),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   phase='train',
                                   splitid=args.datasetTrain,
                                   transform=composed_transforms_tr)
    train_loader = DataLoader(domain,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=cuda)

    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       phase='test',
                                       splitid=args.datasetTest,
                                       transform=composed_transforms_ts)
    val_loader = DataLoader(domain_val,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=cuda)

    # 2. model
    model = DeepLab(num_classes=2,
                    num_domain=3,
                    backbone='mobilenet',
                    output_stride=args.out_stride,
                    lam=args.lam)
    if cuda:
        model = model.cuda()
    print('parameter numer:', sum(p.numel() for p in model.parameters()))

    # load weights
    if args.resume:
        _load_matching_weights(model, args.resume)

        print('Before ', model.centroids.data)
        model.centroids.data = centroids_init(model, args.data_dir,
                                              args.datasetTrain,
                                              composed_transforms_ts)
        print('After ', model.centroids.data)
        # model.freeze_para()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))

    trainer = Trainer.Trainer(
        cuda=cuda,
        model=model,
        lr=args.lr,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optim,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()