def transform_train(self):
    temp = []
    temp.append(tr.Resize(self.args.input_size))
    temp.append(tr.RandomHorizontalFlip())
    temp.append(tr.RandomRotate(15))
    temp.append(tr.RandomCrop(self.args.input_size))
    temp.append(tr.ToTensor())
    composed_transforms = transforms.Compose(temp)
    return composed_transforms
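Throughout these examples, `tr` refers to a project-specific custom_transforms module whose transforms act jointly on a {'image', 'label'} dict of PIL images, unlike torchvision transforms, which operate on a single image. A minimal sketch of one such dict transform, assuming that convention (illustrative only, not any project's actual code):

import random
from PIL import Image

class RandomHorizontalFlip:
    """Flip image and label together so the mask stays aligned."""
    def __call__(self, sample):
        if random.random() < 0.5:
            sample = {'image': sample['image'].transpose(Image.FLIP_LEFT_RIGHT),
                      'label': sample['label'].transpose(Image.FLIP_LEFT_RIGHT)}
        return sample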
Example #2
    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.RandomRotate(20),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)
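Note that Normalize appears before ToTensor here, the reverse of the usual torchvision order; presumably these custom dict transforms normalize the PIL/NumPy data first and let ToTensor perform the final tensor conversion.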
Example #3
    def transform_tr(self, sample):
        """Image transformations for training"""
        targs = self.transform
        method = targs["method"]
        pars = targs["parameters"]
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=pars["outSize"]),
            tr.RandomRotate(degree=90),
            tr.RandomScaleCrop(baseSize=pars["baseSize"], cropSize=pars["outSize"], fill=255),
            tr.Normalize(mean=pars["mean"], std=pars["std"]),
            tr.ToTensor()])

        return composed_transforms(sample)
Example #4
    def transform_train(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomVerticalFlip(),
            # tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            # tr.FixedResize(size=self.args.crop_size),
            tr.RandomRotate(),
            tr.RandomGammaTransform(),
            tr.RandomGaussianBlur(),
            tr.RandomNoise(),
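            # four-value mean/std: the input presumably has four channels (e.g. RGB plus an extra band)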
            tr.Normalize(mean=(0.544650, 0.352033, 0.384602, 0.352311), std=(0.249456, 0.241652, 0.228824, 0.227583)),
            tr.ToTensor()])

        return composed_transforms(sample)
Example #5
    def transform_tr(self, sample):
        temp = []
        if self.args.rotate > 0:
            temp.append(tr.RandomRotate(self.args.rotate))
        temp.append(tr.RandomScale(rand_resize=self.args.rand_resize))
        temp.append(tr.RandomCrop(self.args.input_size))
        temp.append(tr.RandomHorizontalFlip())
        temp.append(
            tr.Normalize(mean=self.args.normal_mean, std=self.args.normal_std))
        if self.args.noise_param is not None:
            temp.append(
                tr.GaussianNoise(mean=self.args.noise_param[0],
                                 std=self.args.noise_param[1]))
        temp.append(tr.ToTensor())
        composed_transforms = transforms.Compose(temp)

        return composed_transforms(sample)
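This pipeline is driven entirely by self.args; a hypothetical namespace that would satisfy it (field names taken from the code above, values invented for illustration):

from argparse import Namespace

args = Namespace(
    rotate=10,                               # degrees; 0 disables RandomRotate
    rand_resize=(0.5, 2.0),                  # scale range for RandomScale
    input_size=512,                          # RandomCrop size
    normal_mean=(0.485, 0.456, 0.406),
    normal_std=(0.229, 0.224, 0.225),
    noise_param=(0.0, 0.01),                 # (mean, std); None disables GaussianNoise
)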
Example #6
    def transform_tr(self, sample):

        color_transforms = [
            transforms.RandomApply([transforms.ColorJitter(brightness=0.1)]),  # brightness
            transforms.RandomApply([transforms.ColorJitter(contrast=0.1)]),    # contrast
            transforms.RandomApply([transforms.ColorJitter(saturation=0.1)]),  # saturation
            transforms.RandomApply([transforms.ColorJitter(hue=0.05)])         # hue
        ]

        joint_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size,
                               fill=255),
            tr.equalize(),
            tr.RandomGaussianBlur(),
            tr.RandomRotate(degree=7)
        ])

        image_transforms = transforms.Compose([
            transforms.RandomOrder(color_transforms),
            transforms.RandomGrayscale(p=0.3)
        ])

        normalize_transforms = transforms.Compose([
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor()
        ])

        tmp_sample = joint_transforms(sample)
        tmp_sample['image'] = image_transforms(tmp_sample['image'])
        tmp_sample = normalize_transforms(tmp_sample)

        return tmp_sample
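Splitting the pipeline into three stages is deliberate: the geometric joint_transforms must see image and label together to keep them aligned, the photometric image_transforms touch only the image, and normalization runs last over the whole sample.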
Example #7
    net.cuda()
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))

if resume_epoch != nEpochs:
    # Logging into Tensorboard
    log_dir = os.path.join(
        save_dir, 'models',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())

    optimizer = optim.Adam(net.parameters(), lr=p['lr'], weight_decay=p['wd'])
    p['optimizer'] = str(optimizer)

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized(512),
        tr.RandomRotate(15),
        tr.RandomHorizontalFlip(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose([
        tr.FixedResize(size=(512, 512)),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()
    ])

    # voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
    # voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
    ROOT = 'dataset/ORIGA'
    voc_train = ImageFolder(root_path=ROOT, datasets='ORIGA')
Example #8
            mask[mask == _validc] = self.class_map[_validc]
        return mask
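Only the tail of the label-encoding method survives above; a plausible reconstruction under the usual Cityscapes convention (the attribute names void_classes, valid_classes, and ignore_index are assumed; only the last three lines appear in the original):

    def encode_segmap(self, mask):
        # Map unwanted classes to the ignore index, then remap valid ids
        # to contiguous train ids via class_map.
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask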


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()
    ])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    dataloader = DataLoader(cityscapes_train,
                            batch_size=2,
                            shuffle=True,
                            num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
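            # (truncated in the original; a typical continuation, assumed,
            # mirrors Example #11 below)
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            plt.figure()
            plt.subplot(211); plt.imshow(img_tmp)
            plt.subplot(212); plt.imshow(segmap)
            plt.show()
        break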
Example #9
def main():
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    parser.add_argument(
        '--boundary-exist', type=bool, default=True, help='whether to use the boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge', help='folder id containing image ROIs for training or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12, help='batch size for training the model'
    )
    # parser.add_argument(
    #     '--group-num', type=int, default=1, help='group number for group normalization'
    # )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1, help='epoch at which to begin training the GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1, help='interval (in epochs) at which to validate the model'
    )
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2, help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir',
        default='./fundus/',
        help='data root path'
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=bool,
        default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )
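    # NOTE: `type=bool` above is a well-known argparse pitfall: bool('False')
    # is True, so any non-empty value enables the flag. A common fix (a sketch,
    # not part of the original script) is a string-to-bool converter:
    #
    #     def str2bool(v):
    #         if isinstance(v, bool):
    #             return v
    #         if v.lower() in ('yes', 'true', 't', '1'):
    #             return True
    #         if v.lower() in ('no', 'false', 'f', '0'):
    #             return False
    #         raise argparse.ArgumentTypeError('boolean value expected')
    #
    #     parser.add_argument('--sync-bn', type=str2bool, default=False)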

    args = parser.parse_args()
    args.model = 'MobileNetV2'

    now = datetime.now()
    args.out = osp.join(here, 'logs', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)

    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)
    
    import random
    import numpy as np
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
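    # Note: RandomCrop in a validation pipeline makes evaluation nondeterministic;
    # a fixed or center crop would be the more common choice here.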

    data_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train',
    #                                    transform=composed_transforms_ts)
    # domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2,
    #                                pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                        sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda()

    model_bd = BoundaryDiscriminator().cuda()
    model_mask = MaskDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(
        model_gen.parameters(),
        lr=args.lr_gen,
        betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # breakpoint recovery
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_bd_state_dict']
        model_dict = model_bd.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model_bd.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_mask_state_dict']
        model_dict = model_mask.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model_mask.load_state_dict(model_dict)

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Example #10
File: train.py  Project: zlannnn/BEAL
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS',
                        type=str,
                        default='refuge',
                        help='source domain folder id containing image ROIs')
    parser.add_argument('--datasetT',
                        type=str,
                        default='Drishti-GS',
                        help='target dataset: refuge / Drishti-GS / RIM-ONE_r3')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch',
                        type=int,
                        default=-1,
                        help='epoch at which to begin training the GAN')

    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval (in epochs) at which to validate the model')
    parser.add_argument(
        '--lr-gen',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-dis',
        type=float,
        default=2.5e-5,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate',
        type=float,
        default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay',
        type=float,
        default=0.0005,
        help='weight decay',
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.99,
        help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=bool,
        default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()

    args.model = 'FCN8s'

    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))

    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS,
                                   split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT,
                                     split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=2,
                                pin_memory=True)
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT,
                                       split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=2,
                                   pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2,
                        backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer

    optim_gen = torch.optim.Adam(model_gen.parameters(),
                                 lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(),
                                lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(),
                                 lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_dis_state_dict']
        model_dict = model_dis.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_dis2_state_dict']
        model_dict = model_dis2.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis2.load_state_dict(model_dict)

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # optim_adv.load_state_dict(checkpoint['optim_adv_state_dict'])  # optim_adv is never created in this snippet; left disabled to avoid a NameError

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Example #11
		sample = {'image': _img, 'label': _target}
		if self.transform is not None: 
			sample = self.transform(sample)
		return sample 

	def _make_img_gt_point_pair(self, index):
		# Read image and target
		_img = Image.open(self.images[index]).convert('RGB')
		_target = Image.open(self.categories[index])
		return _img, _target

	def __str__(self):
		return 'VOC2012(split=' + str(self.split) + ')'

if __name__ == '__main__':
	composed_transforms_tr = transforms.Compose([
		tr.RandomHorizontalFlip(),
		tr.RandomSized(512),
		tr.RandomRotate(15),
		tr.ToTensor()
	])

	voc_train = PascalVOC(split='train', transform=composed_transforms_tr)
	dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=2)

	for ii, sample in tqdm(enumerate(dataloader)):
		for jj in tqdm(range(sample["image"].size()[0])):
			img = sample['image'].numpy()
			gt = sample['label'].numpy()
			tmp = np.array(gt[jj]).astype(np.uint8)
			tmp = np.squeeze(tmp, axis=0)
			segmap = decode_segmap(tmp, dataset='pascal')
			img_tmp = np.transpose(img[jj], axes=[1,2,0]).astype(np.uint8)
			plt.figure()
			plt.title('display')
			plt.subplot(211)