Beispiel #1
0
    def transform_tr(self, sample):
        """Transformations for training images.

        sample: {image: img, annotation: ann}

        Note: the mean and std are the ImageNet statistics.
        """
        # The two original branches differed only in whether the horizontal
        # flip was present, so build the shared tail once and prepend the
        # flip only when flipping is enabled.
        ops = [
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size,
                               scale_ratio=self.args.scale_ratio,
                               fill=0),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor()
        ]
        if not self.args.no_flip:
            ops.insert(0, tr.RandomHorizontalFlip())
        composed_transforms = transforms.Compose(ops)
        return composed_transforms(sample)
Beispiel #2
0
 def transform_tr(self, sample):
     """Apply the training augmentation pipeline to *sample*.

     Normalization uses the ImageNet mean/std.
     """
     # Removed a large slab of commented-out alternative pipelines
     # (size-conditional variants) that duplicated the live code below.
     composed_transforms = transforms.Compose([
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=self.args.base_size,
                            crop_size=self.args.crop_size),
         tr.RandomGaussianBlur(),
         tr.Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)),
         tr.ToTensor()
     ])
     return composed_transforms(sample)
Beispiel #3
0
    def transform_tr(self, sample):
        """Scale-crop, normalize (ImageNet stats) and tensorize a sample."""
        pipeline = transforms.Compose([
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
Beispiel #4
0
    def transform_tr(self, sample):
        """Apply one of two randomly chosen augmentation pipelines.

        With probability 0.5 a fixed-size head is used (itself a coin flip
        between FixScaleCrop and FixedResize); otherwise a random scale-crop
        head with fill=255. Both share the same blur/normalize/tensor tail.
        """
        # Keep the exact order of random.random() calls from the original.
        if random.random() > 0.5:
            crop_cls = tr.FixScaleCrop if random.random() > 0.5 else tr.FixedResize
            head = [crop_cls(self.args.crop_size)]
        else:
            head = [
                tr.RandomScaleCrop(
                    base_size=self.args.base_size,
                    crop_size=self.args.crop_size,
                    fill=255,
                )
            ]
        pipeline = transforms.Compose(head + [
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
 def transform_tr(self, sample):
     """Training augmentations; returns the transformed sample.

     Bug fix: the original built the pipeline but never applied it to
     *sample* and never returned anything (implicitly returning None).
     """
     composed_transforms = transforms.Compose([
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=self.args.base_size,
                            crop_size=self.args.crop_size),
         tr.RandomGaussianBlur(),
         tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         tr.ToTensor()
     ])
     return composed_transforms(sample)
Beispiel #6
0
    def transform_tr(self, sample):
        """Augmentation pipeline for training samples (split == 'train').

        base_size / crop_size come from the argparse configuration.
        """
        pipeline = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
Beispiel #7
0
    def transform(self, sample):
        """Training augmentation driven by the cfg.DATASET size settings."""
        steps = [
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.cfg.DATASET.BASE_SIZE,
                               crop_size=self.cfg.DATASET.CROP_SIZE),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ]
        return transforms.Compose(steps)(sample)
Beispiel #8
0
 def transform_tr(self, sample):
     """Augmentation pipeline: resize to 1024x2048, jitter/blur, crop, norm."""
     steps = [
         tr.FixedResize(size=(1024, 2048)),
         tr.ColorJitter(),
         tr.RandomGaussianBlur(),
         tr.RandomMotionBlur(),
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=self.args.base_size,
                            crop_size=self.args.crop_size,
                            fill=255),
         tr.Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ]
     return transforms.Compose(steps)(sample)
Beispiel #9
0
    def transform_tr(self, sample):
        """Flip, scale-crop, blur, then combined resize+normalize (0.5 stats)."""
        steps = [
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Resize_normalize_train(mean=(0.5, 0.5, 0.5),
                                      std=(0.5, 0.5, 0.5)),
        ]
        return transforms.Compose(steps)(sample)
Beispiel #10
0
    def transform_tr_part1_1(self, sample):
        """First-stage geometric transforms; branch on args.use_small."""
        if self.args.use_small:
            ops = [tr.FixScaleCrop(crop_size=self.args.crop_size)]
        else:
            ops = [
                tr.RandomHorizontalFlip(),
                tr.RandomScaleCrop(base_size=self.args.base_size,
                                   crop_size=self.args.crop_size),
            ]
        return transforms.Compose(ops)(sample)
    def transform_tr(self, sample):
        """Flip, scale-crop (fill=255), darken, then dataset-specific norm."""
        pipeline = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.base_size,
                               crop_size=self.crop_size,
                               fill=255),
            tr.RandomDarken(self.cfg, self.darken),
            # tr.RandomGaussianBlur() intentionally omitted —
            # TODO: not working for the depth channel.
            tr.Normalize(mean=self.data_mean, std=self.data_std),
            tr.ToTensor(),
        ])
        return pipeline(sample)
Beispiel #12
0
    def transform_tr(self, sample):
        """Standard training augmentations with ImageNet normalization."""
        pipeline = transforms.Compose([
            # Mirrors the given PIL image randomly with a given probability.
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
    def transform_tr(self, sample):
        """Image transformations for training, configured via self.transform."""
        targs = self.transform
        # 'method' is not used below; the lookup is kept so a missing key
        # still raises exactly as before.
        method = targs["method"]
        pars = targs["parameters"]
        pipeline = transforms.Compose([
            tr.FixedResize(size=pars["outSize"]),
            tr.RandomRotate(degree=90),
            tr.RandomScaleCrop(baseSize=pars["baseSize"],
                               cropSize=pars["outSize"],
                               fill=255),
            tr.Normalize(mean=pars["mean"], std=pars["std"]),
            tr.ToTensor(),
        ])
        return pipeline(sample)
    def transform_tr(self, sample):
        """Rotation-heavy training pipeline (horizontal flip disabled)."""
        # A fresh random angle in [15, 350] is drawn on every call.
        angle = random.randint(15, 350)
        pipeline = transforms.Compose([
            tr.RandomRotate(degree=angle),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size,
                               fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
Beispiel #15
0
 def transform_tr(self, sample):
     """Fixed 513x513 scale-crop with color jitter and ImageNet norm."""
     pipeline = transforms.Compose([
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=513, crop_size=513),
         tr.ColorJitter(brightness=0.3,
                        contrast=0.3,
                        saturation=0.3,
                        hue=0.3,
                        gamma=0.3),
         tr.Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ])
     return pipeline(sample)
    def transform_tr(self, sample):
        """Flip + scale-crop + RandomRotate_v2 + ImageNet normalization."""
        pipeline = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            tr.RandomRotate_v2(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])
        return pipeline(sample)
    def transform_tr(self, sample):
        '''
        Transform the given training sample.

        @param sample: The given training sample.
        '''
        # Fix: Normalize was referenced via the `tf.` alias, inconsistent
        # with every other transform in this pipeline (and the sibling
        # methods), which all use the `tr` module alias.
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size,
                               fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor()
        ])

        return composed_transforms(sample)
    def transform_tr(self, sample):
        """Three-stage training augmentation.

        1. joint geometric transforms on the whole sample,
        2. color transforms on the 'image' entry only,
        3. normalization + tensor conversion on the whole sample.
        """
        jitter = [
            transforms.RandomApply([transforms.ColorJitter(brightness=0.1)]),
            transforms.RandomApply([transforms.ColorJitter(contrast=0.1)]),
            transforms.RandomApply([transforms.ColorJitter(saturation=0.1)]),
            transforms.RandomApply([transforms.ColorJitter(hue=0.05)]),
        ]

        geometry = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size,
                               fill=255),
            tr.equalize(),
            tr.RandomGaussianBlur(),
            tr.RandomRotate(degree=7),
        ])

        image_only = transforms.Compose([
            transforms.RandomOrder(jitter),
            transforms.RandomGrayscale(p=0.3),
        ])

        finalize = transforms.Compose([
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ])

        out = geometry(sample)
        out['image'] = image_only(out['image'])
        return finalize(out)
Beispiel #19
0
    def transform_tr(self, sample):
        """Composed transforms for the training dataset.

        :param sample: {'image': image, 'label': label}
        :return: the augmented sample
        """
        # Color jitter is applied to the image alone, before the joint
        # (image + label) transforms below.
        jittered = transforms.ColorJitter(brightness=0.5,
                                          contrast=0.5,
                                          saturation=0.5,
                                          hue=0.2)(sample['image'])
        sample = {'image': jittered, 'label': sample['label']}
        pipeline = transforms.Compose([
            ct.RandomHorizontalFlip(),
            ct.RandomScaleCrop(base_size=self.base_size,
                               crop_size=self.crop_size),
            ct.RandomGaussianBlur(),
            ct.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            ct.ToTensor(),
        ])
        return pipeline(sample)
Beispiel #20
0
def main():
    """Entry point: parse CLI args, build data/model/optimizers, and train.

    Side effects: creates a timestamped log directory under ./logs/<dataset>/,
    writes the parsed args to config.yaml there, sets CUDA_VISIBLE_DEVICES,
    and seeds torch / random / numpy with 2020 for reproducibility.
    """
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    # NOTE(review): argparse `type=bool` is a known pitfall — any non-empty
    # string (including "False") parses as True, so only the default is
    # reliable here. Same applies to --sync-bn and --freeze-bn below.
    parser.add_argument(
        '--boundary-exist', type=bool, default=True, help='whether or not using boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge', help='folder id contain images ROIs to train or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12, help='batch size for training the model'
    )
    # parser.add_argument(
    #     '--group-num', type=int, default=1, help='group number for group normalization'
    # )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1, help='warmup epoch begin train GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1, help='interval epoch number to valide the model'
    )
    # Separate learning rates for the generator (Adam) and the two
    # discriminators (SGD) built below.
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2, help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir',
        default='./fundus/',
        help='data root path'
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=bool,
        default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()
    # Model name is fixed, not configurable from the CLI.
    args.model = 'MobileNetV2'

    now = datetime.now()
    # Timestamped run directory: logs/<dataset>/<YYYYmmdd_HHMMSS.ffffff>.
    # os.makedirs (no exist_ok) intentionally fails if the path exists.
    args.out = osp.join(here, 'logs', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)

    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    # Fixed seeds (torch / CUDA / random / numpy) for reproducibility.
    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)

    import random
    import numpy as np
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    # Validation uses a plain crop + normalization, no augmentation.
    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    data_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train',
    #                                    transform=composed_transforms_ts)
    # domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2,
    #                                pin_memory=True)

    # 2. model
    # Generator (segmentation net) plus two discriminators for the
    # boundary and mask branches.
    model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                        sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda()

    model_bd = BoundaryDiscriminator().cuda()
    model_mask = MaskDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(
        model_gen.parameters(),
        lr=args.lr_gen,
        betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # breakpoint recovery
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        # Same filter-update-load dance for both discriminators.
        pretrained_dict = checkpoint['model_bd_state_dict']
        model_dict = model_bd.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model_bd.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_mask_state_dict']
        model_dict = model_mask.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model_mask.load_state_dict(model_dict)

        # Resume bookkeeping and optimizer states from the checkpoint.
        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Beispiel #21
0
def main():
    """Entry point: parse CLI args, build fundus dataloaders, DeepLab model
    and trainer, then train.

    Side effects: creates a timestamped log directory, writes config.yaml,
    sets CUDA_VISIBLE_DEVICES, and seeds the CUDA RNG with 1337.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # Fix: with nargs='+' an *unsupplied* default must already be a list.
    # The original default of the bare int 1 made args.datasetTest[0] (used
    # for the log path below) and iteration over args.datasetTrain crash
    # whenever the flag was omitted.
    parser.add_argument(
        '--datasetTrain',
        nargs='+',
        type=int,
        default=[1],
        help='train folder id contain images ROIs to train range from [1,2,3,4]'
    )
    parser.add_argument(
        '--datasetTest',
        nargs='+',
        type=int,
        default=[1],
        help='test folder id contain images ROIs to test one of [1,2,3,4]')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=120, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=80,
                        help='stop epoch')
    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument('--lr-decrease-rate',
                        type=float,
                        default=0.2,
                        help='ratio multiplied to initial lr')
    parser.add_argument(
        '--lam',
        type=float,
        default=0.9,
        help='momentum of memory update',
    )
    parser.add_argument('--data-dir',
                        default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    args = parser.parse_args()

    now = datetime.now()
    # Run directory encodes the test split and lambda, plus a timestamp.
    args.out = osp.join(local_path, 'logs', 'test' + str(args.datasetTest[0]),
                        'lam' + str(args.lam),
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    # Persist the full configuration for this run.
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(256),
        # tr.RandomCrop(512),
        # tr.RandomRotate(),
        # tr.RandomFlip(),
        # tr.elastic_transform(),
        # tr.add_salt_pepper_noise(),
        # tr.adjust_light(),
        # tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(256),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   phase='train',
                                   splitid=args.datasetTrain,
                                   transform=composed_transforms_tr)
    train_loader = DataLoader(domain,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       phase='test',
                                       splitid=args.datasetTest,
                                       transform=composed_transforms_ts)
    val_loader = DataLoader(domain_val,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2,
                    num_domain=3,
                    backbone='mobilenet',
                    output_stride=args.out_stride,
                    lam=args.lam).cuda()
    print('parameter numer:', sum([p.numel() for p in model.parameters()]))

    # load weights
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        print('Before ', model.centroids.data)
        model.centroids.data = centroids_init(model, args.data_dir,
                                              args.datasetTrain,
                                              composed_transforms_ts)
        # Fix: this line runs after the centroid update but was labeled
        # 'Before ' — a copy-paste error that made the log misleading.
        print('After ', model.centroids.data)
        # model.freeze_para()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))

    trainer = Trainer.Trainer(
        cuda=cuda,
        model=model,
        lr=args.lr,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optim,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Beispiel #22
0
def main():
    """Entry point: parse CLI args, build data/models/optimizers, train.

    Sets up a DeepLab (mobilenet backbone) generator plus a boundary
    discriminator and an uncertainty discriminator for domain adaptation
    from ``--datasetS`` to ``--datasetT`` on fundus segmentation data,
    optionally resumes from ``--resume``, then hands everything to
    ``Trainer.Trainer`` and runs training.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` is a known pitfall: bool("False") is
        # True because any non-empty string is truthy, so boolean flags
        # could never be turned off from the command line.  Parse the
        # string explicitly instead.
        if isinstance(value, bool):
            return value
        return str(value).lower() in ('true', '1', 'yes', 'y')

    def _load_matching_weights(module, pretrained_dict):
        # Restore only checkpoint entries whose keys exist in ``module``
        # so partially matching checkpoints can still be loaded.
        model_dict = module.state_dict()
        model_dict.update(
            {k: v for k, v in pretrained_dict.items() if k in model_dict})
        module.load_state_dict(model_dict)

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS',
                        type=str,
                        default='refuge',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--datasetT',
                        type=str,
                        default='Drishti-GS',
                        help='refuge / Drishti-GS/ RIM-ONE_r3')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch',
                        type=int,
                        default=-1,
                        help='warmup epoch begin train GAN')

    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr-gen',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-dis',
        type=float,
        default=2.5e-5,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate',
        type=float,
        default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay',
        type=float,
        default=0.0005,
        help='weight decay',
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.99,
        help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    # BUG FIX: these two flags used ``type=bool``, which treats any
    # non-empty string (including "False") as True; _str2bool parses the
    # value correctly while keeping the same defaults.
    parser.add_argument(
        '--sync-bn',
        type=_str2bool,
        default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=_str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()

    args.model = 'FCN8s'

    # Timestamped output directory: logs/<target dataset>/<timestamp>;
    # the parsed configuration is dumped there for reproducibility.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))

    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    # Fixed seeds for reproducible runs.
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    # Training transforms: heavy augmentation (geometric + photometric).
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    # Evaluation transforms: crop + normalize only.
    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS,
                                   split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT,
                                     split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=2,
                                pin_memory=True)
    # NOTE(review): validation reads the *train* split of the target
    # dataset (with eval transforms) -- confirm this is intended and not
    # a copy-paste of split='train'.
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT,
                                       split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=2,
                                   pin_memory=True)

    # 2. model: generator + two discriminators (boundary / uncertainty)
    model_gen = DeepLab(num_classes=2,
                        backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer: Adam for the generator, SGD for both discriminators
    optim_gen = torch.optim.Adam(model_gen.parameters(),
                                 lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(),
                                lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(),
                                 lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        _load_matching_weights(model_gen, checkpoint['model_state_dict'])
        _load_matching_weights(model_dis, checkpoint['model_dis_state_dict'])
        _load_matching_weights(model_dis2, checkpoint['model_dis2_state_dict'])

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # BUG FIX: the original also called
        # ``optim_adv.load_state_dict(checkpoint['optim_adv_state_dict'])``
        # but no ``optim_adv`` optimizer is ever created in this function,
        # so every resume attempt raised NameError.  The dead call is
        # removed; no adversarial optimizer is passed to the Trainer below.

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()