def transform_tr(self, sample):
    """Training transformations for a sample dict {'image': img, 'annotation': ann}.

    Note: the mean and std are the ImageNet statistics.

    Improvement: the original duplicated the entire pipeline in both
    branches of ``no_flip``; the flip is now optionally prepended to a
    single shared tail, so the two can no longer drift apart.
    """
    ops = []
    if not self.args.no_flip:
        ops.append(tr.RandomHorizontalFlip())
    ops.extend([
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size,
                           scale_ratio=self.args.scale_ratio,
                           fill=0),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return transforms.Compose(ops)(sample)
def transform_tr(self, sample):
    """Training-time augmentation: flip, multi-scale crop, blur,
    ImageNet normalization, tensor conversion.

    Improvement: removed a large block of commented-out dead code (an
    abandoned size-conditional variant of this pipeline) that obscured
    the active implementation.
    """
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        # ImageNet mean/std
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    """Apply the training transforms (multi-scale crop, ImageNet
    normalization, tensor conversion) to *sample*."""
    pipeline = [
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    return transforms.Compose(pipeline)(sample)
def transform_tr(self, sample):
    """Randomly choose one of three geometric strategies, then blur,
    normalize with ImageNet stats, and convert to tensor.

    Half the time a fixed-size op is used (FixScaleCrop or FixedResize,
    chosen by a second coin flip); otherwise a multi-scale crop.
    """
    if random.random() > 0.5:
        # Fixed-size path: second draw picks crop vs. plain resize.
        op_cls = tr.FixScaleCrop if random.random() > 0.5 else tr.FixedResize
        geometric = op_cls(self.args.crop_size)
    else:
        # Multi-scale path.
        geometric = tr.RandomScaleCrop(
            base_size=self.args.base_size,
            crop_size=self.args.crop_size,
            fill=255,
        )
    pipeline = transforms.Compose([
        geometric,
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Training augmentation pipeline; returns the transformed sample.

    Bug fix: the original constructed the Compose object but never
    applied it to *sample* and fell off the end, implicitly returning
    ``None`` — every sibling ``transform_tr`` in this project returns
    ``composed_transforms(sample)``.
    """
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    """Training transform for the 'train' split: flip, scale-crop,
    Gaussian blur, ImageNet normalization, to-tensor."""
    steps = [
        tr.RandomHorizontalFlip(),
        # base_size / crop_size are supplied via the argparse namespace.
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    return transforms.Compose(steps)(sample)
def transform(self, sample):
    """Training-time augmentation driven by the dataset section of the
    configuration object (``self.cfg.DATASET``)."""
    base = self.cfg.DATASET.BASE_SIZE
    crop = self.cfg.DATASET.CROP_SIZE
    pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=base, crop_size=crop),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Heavy training augmentation: resize to 1024x2048, photometric
    jitter and blurs, flip, multi-scale crop, then ImageNet
    normalization and tensor conversion."""
    steps = [
        tr.FixedResize(size=(1024, 2048)),
        tr.ColorJitter(),
        tr.RandomGaussianBlur(),
        tr.RandomMotionBlur(),
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size, fill=255),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    return transforms.Compose(steps)(sample)
def transform_tr(self, sample):
    """Training transform; the final Resize_normalize_train step uses
    mean=std=0.5 (i.e. pixels mapped toward [-1, 1] rather than the
    ImageNet statistics used elsewhere)."""
    pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Resize_normalize_train(mean=(0.5, 0.5, 0.5),
                                  std=(0.5, 0.5, 0.5)),
    ])
    return pipeline(sample)
def transform_tr_part1_1(self, sample):
    """First stage of the training transform: geometric ops only.

    With ``use_small`` set, a deterministic FixScaleCrop is applied;
    otherwise a random flip followed by a multi-scale crop.
    """
    if self.args.use_small:
        geo_ops = [tr.FixScaleCrop(crop_size=self.args.crop_size)]
    else:
        geo_ops = [
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
        ]
    # Zhiwei
    return transforms.Compose(geo_ops)(sample)
def transform_tr(self, sample):
    """Training augmentation; normalization uses the dataset-specific
    mean/std stored on the instance rather than ImageNet constants."""
    pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.base_size,
                           crop_size=self.crop_size, fill=255),
        tr.RandomDarken(self.cfg, self.darken),
        # tr.RandomGaussianBlur() stays disabled:
        # TODO it is not working for the depth channel.
        tr.Normalize(mean=self.data_mean, std=self.data_std),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Standard training augmentation: random flip, multi-scale crop,
    Gaussian blur, ImageNet normalization, tensor conversion."""
    steps = [
        # Flips the given PIL image randomly with the transform's
        # default probability.
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    return transforms.Compose(steps)(sample)
def transform_tr(self, sample):
    """Image transformations for training.

    Parameters come from ``self.transform``, a dict with "method" and
    "parameters" entries (mean/std, baseSize, outSize).
    """
    cfg = self.transform
    # "method" is looked up (as in the original) even though only the
    # parameters are used here; a missing key still raises KeyError.
    method = cfg["method"]
    pars = cfg["parameters"]
    pipeline = transforms.Compose([
        tr.FixedResize(size=pars["outSize"]),
        tr.RandomRotate(degree=90),
        tr.RandomScaleCrop(baseSize=pars["baseSize"],
                           cropSize=pars["outSize"], fill=255),
        tr.Normalize(mean=pars["mean"], std=pars["std"]),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Training transform: random rotation by 15-350 degrees (drawn once
    per call), multi-scale crop, blur, ImageNet normalization,
    to-tensor. Horizontal flip is intentionally disabled."""
    angle = random.randint(15, 350)
    pipeline = transforms.Compose([
        # tr.RandomHorizontalFlip() deliberately omitted.
        tr.RandomRotate(degree=angle),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size, fill=255),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Training transform with hard-coded 513px sizing and strong color
    jitter (brightness/contrast/saturation/hue/gamma all 0.3)."""
    pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=513, crop_size=513),
        tr.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3,
                       hue=0.3, gamma=0.3),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return pipeline(sample)
def transform_tr(self, sample):
    """Training pipeline: flip, multi-scale crop, rotation (v2 variant),
    ImageNet normalization, tensor conversion.

    Improvement: removed five commented-out alternative transforms
    (blur, fixed resize, random crop, cutout, fixed-angle rotate) —
    dead code that obscured the active pipeline.
    """
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomRotate_v2(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    '''
    Transform the given training sample.

    @param sample: The given training sample.
    @return: the transformed sample.

    Bug fix: Normalize was referenced through the alias ``tf`` while
    every other transform in this pipeline (and in the sibling methods)
    comes from ``tr`` — almost certainly a typo that would raise a
    NameError or pull Normalize from the wrong module.
    '''
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size, fill=255),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    """Three-stage training augmentation.

    1. Joint ops applied to the whole sample (flip, scale-crop,
       equalize, blur, small rotation).
    2. Image-only photometric ops (color jitter in random order, random
       grayscale) applied to sample['image'] alone.
    3. ImageNet normalization and tensor conversion on the sample.
    """
    jitters = [
        transforms.RandomApply([transforms.ColorJitter(brightness=0.1)]),
        transforms.RandomApply([transforms.ColorJitter(contrast=0.1)]),
        transforms.RandomApply([transforms.ColorJitter(saturation=0.1)]),
        transforms.RandomApply([transforms.ColorJitter(hue=0.05)]),
    ]
    joint = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size, fill=255),
        tr.equalize(),
        tr.RandomGaussianBlur(),
        tr.RandomRotate(degree=7),
    ])
    image_only = transforms.Compose([
        transforms.RandomOrder(jitters),
        transforms.RandomGrayscale(p=0.3),
    ])
    finalize = transforms.Compose([
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ])
    out = joint(sample)
    out['image'] = image_only(out['image'])
    return finalize(out)
def transform_tr(self, sample):
    """
    Composed transforms for the training dataset.

    :param sample: {'image': image, 'label': label}
    :return: the transformed sample
    """
    # Color jitter runs on the image alone, before the joint image+label ops.
    jittered = transforms.ColorJitter(
        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2
    )(sample['image'])
    sample = {'image': jittered, 'label': sample['label']}
    pipeline = transforms.Compose([
        ct.RandomHorizontalFlip(),
        ct.RandomScaleCrop(base_size=self.base_size,
                           crop_size=self.crop_size),
        # ct.RandomChangeBackground() remains disabled.
        ct.RandomGaussianBlur(),
        ct.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ct.ToTensor(),
    ])
    return pipeline(sample)
def main():
    """Entry point: parse CLI options, set up fundus segmentation data
    loaders, the DeepLab generator plus boundary/mask discriminators and
    their optimizers, optionally resume from a checkpoint, then train."""
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    # NOTE(review): argparse `type=bool` is a known pitfall — any non-empty
    # string (even "False") parses as True; only the default value behaves
    # as expected. Confirm callers rely on defaults before changing.
    parser.add_argument(
        '--boundary-exist', type=bool, default=True,
        help='whether or not using boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge',
        help='folder id contain images ROIs to train or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12,
        help='batch size for training the model'
    )
    # parser.add_argument(
    #     '--group-num', type=int, default=1, help='group number for group normalization'
    # )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1,
        help='warmup epoch begin train GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1,
        help='interval epoch number to valide the model'
    )
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir', default='./fundus/', help='data root path'
    )
    parser.add_argument(
        '--out-stride', type=int, default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    args = parser.parse_args()

    args.model = 'MobileNetV2'

    # Per-run output directory: logs/<dataset>/<timestamp>.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.dataset,
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    # Fixed seeds across torch / random / numpy for reproducibility.
    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)
    import random
    import numpy as np
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    # Validation uses only a random crop plus normalization.
    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    data_train = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.dataset,
                                       split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.dataset,
                                     split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size,
                                shuffle=False, num_workers=4,
                                pin_memory=True)
    # domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train',
    #                                    transform=composed_transforms_ts)
    # domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2,
    #                                pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2, backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()
    model_bd = BoundaryDiscriminator().cuda()
    model_mask = MaskDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(
        model_gen.parameters(), lr=args.lr_gen, betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # breakpoint recovery
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        # Same filter-update-load dance for the boundary discriminator...
        pretrained_dict = checkpoint['model_bd_state_dict']
        model_dict = model_bd.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model_bd.load_state_dict(model_dict)

        # ...and for the mask discriminator.
        pretrained_dict = checkpoint['model_mask_state_dict']
        model_dict = model_mask.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model_mask.load_state_dict(model_dict)

        # Resume epoch/iteration counters and optimizer state.
        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
def main():
    """Entry point: parse CLI options, build fundus train/test loaders,
    construct the multi-domain DeepLab model (optionally resuming and
    re-initializing centroids from a checkpoint), and train."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    # Bug fix: with nargs='+' the parsed value is a LIST, but the defaults
    # were bare ints (1); `args.datasetTest[0]` below then raised
    # TypeError ('int' object is not subscriptable) whenever the option
    # was omitted. One-element list defaults keep the CLI unchanged.
    parser.add_argument(
        '--datasetTrain', nargs='+', type=int, default=[1],
        help='train folder id contain images ROIs to train range from [1,2,3,4]'
    )
    parser.add_argument(
        '--datasetTest', nargs='+', type=int, default=[1],
        help='test folder id contain images ROIs to test one of [1,2,3,4]')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num', type=int, default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=120,
                        help='max epoch')
    parser.add_argument('--stop-epoch', type=int, default=80,
                        help='stop epoch')
    parser.add_argument('--interval-validate', type=int, default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument('--lr-decrease-rate', type=float, default=0.2,
                        help='ratio multiplied to initial lr')
    parser.add_argument(
        '--lam', type=float, default=0.9, help='momentum of memory update',
    )
    parser.add_argument('--data-dir', default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride', type=int, default=16,
        help='out-stride of deeplabv3+',
    )
    args = parser.parse_args()

    # Per-run output directory: logs/test<id>/lam<value>/<timestamp>.
    now = datetime.now()
    args.out = osp.join(local_path, 'logs',
                        'test' + str(args.datasetTest[0]),
                        'lam' + str(args.lam),
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    # Persist the full configuration for the run.
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(256),
        # tr.RandomCrop(512),
        # tr.RandomRotate(),
        # tr.RandomFlip(),
        # tr.elastic_transform(),
        # tr.add_salt_pepper_noise(),
        # tr.adjust_light(),
        # tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(256), tr.Normalize_tf(), tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir, phase='train',
                                   splitid=args.datasetTrain,
                                   transform=composed_transforms_tr)
    train_loader = DataLoader(domain, batch_size=args.batch_size,
                              shuffle=True, num_workers=2, pin_memory=True)
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir, phase='test',
                                       splitid=args.datasetTest,
                                       transform=composed_transforms_ts)
    val_loader = DataLoader(domain_val, batch_size=args.batch_size,
                            shuffle=False, num_workers=2, pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2, num_domain=3, backbone='mobilenet',
                    output_stride=args.out_stride, lam=args.lam).cuda()
    print('parameter numer:', sum([p.numel() for p in model.parameters()]))

    # load weights
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        # Bug fix: the second print after re-initializing the centroids
        # also said 'Before '; it reports the post-init values.
        print('Before ', model.centroids.data)
        model.centroids.data = centroids_init(model, args.data_dir,
                                              args.datasetTrain,
                                              composed_transforms_ts)
        print('After ', model.centroids.data)
        # model.freeze_para()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr,
                             betas=(0.9, 0.99))

    trainer = Trainer.Trainer(
        cuda=cuda,
        model=model,
        lr=args.lr,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optim,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
def main():
    """Entry point: parse CLI options, build source/target-domain fundus
    loaders, construct the DeepLab generator plus boundary and
    uncertainty discriminators, optionally resume, and train."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS', type=str, default='refuge',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--datasetT', type=str, default='Drishti-GS',
                        help='refuge / Drishti-GS/ RIM-ONE_r3')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num', type=int, default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200,
                        help='max epoch')
    parser.add_argument('--stop-epoch', type=int, default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch', type=int, default=-1,
                        help='warmup epoch begin train GAN')
    parser.add_argument('--interval-validate', type=int, default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.99, help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride', type=int, default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    args = parser.parse_args()

    args.model = 'FCN8s'

    # Per-run output directory: logs/<target-dataset>/<timestamp>.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512), tr.Normalize_tf(), tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS, split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain, batch_size=args.batch_size,
                                shuffle=True, num_workers=2,
                                pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT, split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T, batch_size=args.batch_size,
                                shuffle=False, num_workers=2,
                                pin_memory=True)
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT, split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size,
                                   shuffle=False, num_workers=2,
                                   pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2, backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()
    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(model_gen.parameters(), lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(), lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(), lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_dis_state_dict']
        model_dict = model_dis.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis.load_state_dict(model_dict)

        pretrained_dict = checkpoint['model_dis2_state_dict']
        model_dict = model_dis2.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis2.load_state_dict(model_dict)

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # Bug fix: the original also did
        #     optim_adv.load_state_dict(checkpoint['optim_adv_state_dict'])
        # but no `optim_adv` optimizer is ever created in this function
        # (and the Trainer below takes none), so every --resume run died
        # with a NameError. The stray line has been removed.

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()