def centroids_init(model, data_dir, datasetTrain, composed_transforms):
    """Compute initial per-domain feature centroids over the training data.

    For each source domain id in ``datasetTrain``, runs the model in feature
    extraction mode over that domain's training split, accumulates the feature
    maps, and averages them into one centroid per domain.  The spatial
    dimensions are then collapsed to a per-channel mean that is broadcast back
    to the full map size.

    Args:
        model: segmentation network exposing ``model(x, extract_feature=True)``.
        data_dir: dataset root path passed to ``DL.FundusSegmentation``.
        datasetTrain: iterable of source-domain split ids (expected length 3).
        composed_transforms: torchvision transform pipeline for the loader.

    Returns:
        Tensor of size [3, 304, 64, 64]: each domain's spatially-constant
        mean feature map (contiguous, on GPU).
    """
    # 3 = number of source domains; 304 feature channels on a 64x64 map.
    centroids = torch.zeros(3, 304, 64, 64).cuda()
    model.eval()
    # Initial centroids are computed on training data only, without gradients.
    with torch.set_grad_enabled(False):
        # Traverse each training source domain.
        for count, index in enumerate(datasetTrain):
            domain = DL.FundusSegmentation(base_dir=data_dir, phase='train',
                                           splitid=[index],
                                           transform=composed_transforms)
            dataloder = DataLoader(domain, batch_size=1, shuffle=True,
                                   num_workers=2, pin_memory=True)
            for sample in tqdm(dataloder):
                sample = sample[0]
                inputs = sample['image'].cuda()
                features = model(inputs, extract_feature=True)
                # Sum the features of every image from the same domain.
                centroids[count:count + 1] += features
            # Average the summed features by the number of batches
            # (batch_size=1, so this is the image count).
            centroids[count] /= float(len(dataloder))
    # Collapse the spatial dims to the per-channel mean for each domain.
    ave = torch.mean(torch.mean(centroids, 3, True), 2, True)  # size [3, 304, 1, 1]
    # Broadcast the mean back to the full [3, 304, 64, 64] map.
    return ave.expand_as(centroids).contiguous()
def main():
    """Evaluate a trained boundary-aware DeepLab (MobileNetV2) model on a
    fundus test split and log optic-cup / optic-disc Dice scores.

    Loads a checkpoint, runs inference over the test set, saves entropy maps,
    predicted masks and boundary maps, then appends the averaged Dice scores
    and elapsed time to ``test_log.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-file', type=str,
        default='./logs/train2/20181202_160326.365442/checkpoint_9.pth.tar',
        help='Model path')
    parser.add_argument('--dataset', type=str, default='Drishti-GS',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('-g', '--gpu', type=int, default=0)
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--out-stride', type=int, default=16, help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--save-root-ent', type=str, default='./results/ent/',
        help='path to save ent',
    )
    parser.add_argument(
        '--save-root-mask', type=str, default='./results/mask/',
        help='path to save mask',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=True, help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    parser.add_argument('--test-prediction-save-path', type=str,
                        default='./results/baseline/',
                        help='Path root for test image and mask')
    args = parser.parse_args()

    # Must be set before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file

    # 1. dataset
    composed_transforms_test = transforms.Compose(
        [tr.Normalize_tf(), tr.ToTensor()])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir,
                                    dataset=args.dataset, split='test',
                                    transform=composed_transforms_test)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False,
                             num_workers=1)

    # 2. model
    model = DeepLab(num_classes=2, backbone='mobilenet',
                    output_stride=args.out_stride, sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn).cuda()
    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    checkpoint = torch.load(model_file)
    try:
        # Partial load: keep only checkpoint weights whose names exist in the
        # current model, so checkpoints from slightly different architectures
        # still load.  (The previous code referenced the undefined names
        # `model_data` / `model_gen`, so it always raised and silently fell
        # through to the strict load below.)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
    except Exception:
        # Fall back to a strict, whole-state-dict load.
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    print('==> Evaluating with %s' % (args.dataset))

    val_cup_dice = 0.0
    val_disc_dice = 0.0
    timestamp_start = \
        datetime.now(pytz.timezone('Asia/Hong_Kong'))

    for batch_idx, (sample) in tqdm.tqdm(enumerate(test_loader),
                                         total=len(test_loader),
                                         ncols=80, leave=False):
        data = sample['image']
        target = sample['map']
        img_name = sample['img_name']
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        prediction, boundary = model(data)
        # Upsample network outputs (and the input, for visualization) back to
        # the label's spatial size.
        prediction = torch.nn.functional.interpolate(
            prediction, size=(target.size()[2], target.size()[3]),
            mode="bilinear")
        boundary = torch.nn.functional.interpolate(
            boundary, size=(target.size()[2], target.size()[3]),
            mode="bilinear")
        data = torch.nn.functional.interpolate(
            data, size=(target.size()[2], target.size()[3]), mode="bilinear")
        prediction = torch.sigmoid(prediction)
        boundary = torch.sigmoid(boundary)
        # Save entropy, mask and boundary visualizations.
        draw_ent(prediction.data.cpu()[0].numpy(),
                 os.path.join(args.save_root_ent, args.dataset), img_name[0])
        draw_mask(prediction.data.cpu()[0].numpy(),
                  os.path.join(args.save_root_mask, args.dataset),
                  img_name[0])
        draw_boundary(boundary.data.cpu()[0].numpy(),
                      os.path.join(args.save_root_mask, args.dataset),
                      img_name[0])

        prediction = postprocessing(prediction.data.cpu()[0],
                                    dataset=args.dataset)
        target_numpy = target.data.cpu()
        # Channel 0 = optic cup, channel 1 = optic disc.
        cup_dice = dice_coefficient_numpy(prediction[0, ...],
                                          target_numpy[0, 0, ...])
        disc_dice = dice_coefficient_numpy(prediction[1, ...],
                                           target_numpy[0, 1, ...])

        val_cup_dice += cup_dice
        val_disc_dice += disc_dice

        imgs = data.data.cpu()
        for img, lt, lp in zip(imgs, target_numpy, [prediction]):
            img, lt = untransform(img, lt)
            save_per_img(img.numpy().transpose(1, 2, 0),
                         os.path.join(args.test_prediction_save_path,
                                      args.dataset),
                         img_name[0], lp, mask_path=None, ext="bmp")

    # Average the per-image scores over the test set (batch_size=1).
    val_cup_dice /= len(test_loader)
    val_disc_dice /= len(test_loader)

    print('''\n==>val_cup_dice : {0}'''.format(val_cup_dice))
    print('''\n==>val_disc_dice : {0}'''.format(val_disc_dice))
    with open(osp.join(args.test_prediction_save_path, 'test_log.csv'),
              'a') as f:
        elapsed_time = (datetime.now(pytz.timezone('Asia/Hong_Kong')) -
                        timestamp_start).total_seconds()
        log = [[args.model_file] + ['cup dice coefficence: '] + \
               [val_cup_dice] + ['disc dice coefficence: '] + \
               [val_disc_dice] + [elapsed_time]]
        log = map(str, log)
        f.write(','.join(log) + '\n')
def main():
    """Evaluate a domain-generalization DeepLab model on one target fundus
    domain.

    Reports per-image and averaged Dice, 95% Hausdorff distance (HD95) and
    average surface distance (ASD) for the optic cup (OC) and optic disc (OD),
    saves prediction overlays, and appends a summary row to a per-test CSV log.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-file', type=str,
        default='./logs/test1/20190506_221021.177567/checkpoint_200.pth.tar',
        help='Model path')
    parser.add_argument('--datasetTest', type=list, default=[1],
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--dataset', type=str, default='test',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('-g', '--gpu', type=int, default=0)
    parser.add_argument('--data-dir', default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='out-stride of deeplabv3+',)
    parser.add_argument('--sync-bn', type=bool, default=False,
                        help='sync-bn in deeplabv3+')
    parser.add_argument('--freeze-bn', type=bool, default=False,
                        help='freeze batch normalization of deeplabv3+')
    parser.add_argument(
        '--movingbn', type=bool, default=False,
        help='moving batch normalization of deeplabv3+ in the test phase',)
    parser.add_argument('--test-prediction-save-path', type=str,
                        default='./results/rebuttle-0401/',
                        help='Path root for test image and mask')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file
    # Results go under <save-path>/test<id>/<checkpoint-dir-name>/.
    output_path = os.path.join(args.test_prediction_save_path,
                               'test' + str(args.datasetTest[0]),
                               args.model_file.split('/')[-2])

    # 1. dataset
    composed_transforms_test = transforms.Compose([
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir, phase='test',
                                    splitid=args.datasetTest,
                                    transform=composed_transforms_test,
                                    state='prediction')
    batch_size = 12
    test_loader = DataLoader(db_test, batch_size=batch_size, shuffle=False,
                             num_workers=1)

    # 2. model
    model = DeepLab(num_classes=2, backbone='mobilenet',
                    output_stride=args.out_stride, sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn).cuda()
    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    checkpoint = torch.load(model_file)
    pretrained_dict = checkpoint['model_state_dict']
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    if args.movingbn:
        # Keep BN in train mode so running stats adapt to the target domain.
        model.train()
    else:
        model.eval()

    val_cup_dice = 0.0
    val_disc_dice = 0.0
    total_hd_OC = 0.0
    total_hd_OD = 0.0
    total_asd_OC = 0.0
    total_asd_OD = 0.0
    timestamp_start = datetime.now(pytz.timezone('Asia/Hong_Kong'))
    total_num = 0
    OC = []
    OD = []
    for batch_idx, (sample) in tqdm.tqdm(enumerate(test_loader),
                                         total=len(test_loader),
                                         ncols=80, leave=False):
        data = sample['image']
        target = sample['label']
        img_name = sample['img_name']
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        prediction, dc, sel, _ = model(data)
        prediction = torch.nn.functional.interpolate(
            prediction, size=(target.size()[2], target.size()[3]),
            mode="bilinear")
        data = torch.nn.functional.interpolate(
            data, size=(target.size()[2], target.size()[3]), mode="bilinear")

        target_numpy = target.data.cpu()
        imgs = data.data.cpu()
        hd_OC = 100
        asd_OC = 100
        hd_OD = 100
        asd_OD = 100
        for i in range(prediction.shape[0]):
            prediction_post = postprocessing(prediction[i],
                                             dataset=args.dataset)
            cup_dice, disc_dice = dice_coeff_2label(prediction_post, target[i])
            OC.append(cup_dice)
            OD.append(disc_dice)
            # Channel 0 = optic cup.  Empty predictions would make
            # hd95/asd crash, so fall back to the penalty value 100.
            if np.sum(prediction_post[0, ...]) < 1e-4:
                hd_OC = 100
                asd_OC = 100
            else:
                hd_OC = binary.hd95(
                    np.asarray(prediction_post[0, ...], dtype=bool),
                    np.asarray(target_numpy[i, 0, ...], dtype=bool))
                asd_OC = binary.asd(
                    np.asarray(prediction_post[0, ...], dtype=bool),
                    np.asarray(target_numpy[i, 0, ...], dtype=bool))
            # Channel 1 = optic disc.  BUG FIX: this guard previously
            # re-checked channel 0 (the cup), so an empty disc prediction
            # crashed binary.hd95/asd.
            if np.sum(prediction_post[1, ...]) < 1e-4:
                hd_OD = 100
                asd_OD = 100
            else:
                hd_OD = binary.hd95(
                    np.asarray(prediction_post[1, ...], dtype=bool),
                    np.asarray(target_numpy[i, 1, ...], dtype=bool))
                asd_OD = binary.asd(
                    np.asarray(prediction_post[1, ...], dtype=bool),
                    np.asarray(target_numpy[i, 1, ...], dtype=bool))
            val_cup_dice += cup_dice
            val_disc_dice += disc_dice
            total_hd_OC += hd_OC
            total_hd_OD += hd_OD
            total_asd_OC += asd_OC
            total_asd_OD += asd_OD
            total_num += 1
            for img, lt, lp in zip([imgs[i]], [target_numpy[i]],
                                   [prediction_post]):
                img, lt = utils.untransform(img, lt)
                save_per_img(img.numpy().transpose(1, 2, 0),
                             output_path,
                             img_name[i],
                             lp, lt, mask_path=None, ext="bmp")

    print('OC:', OC)
    print('OD:', OD)
    # Dump the per-image Dice pairs for offline analysis.
    import csv
    with open('Dice_results.csv', 'a+') as result_file:
        wr = csv.writer(result_file, dialect='excel')
        for index in range(len(OC)):
            wr.writerow([OC[index], OD[index]])

    # Average all metrics over the number of evaluated images.
    val_cup_dice /= total_num
    val_disc_dice /= total_num
    total_hd_OC /= total_num
    total_asd_OC /= total_num
    total_hd_OD /= total_num
    total_asd_OD /= total_num

    print('''\n==>val_cup_dice : {0}'''.format(val_cup_dice))
    print('''\n==>val_disc_dice : {0}'''.format(val_disc_dice))
    print('''\n==>average_hd_OC : {0}'''.format(total_hd_OC))
    print('''\n==>average_hd_OD : {0}'''.format(total_hd_OD))
    print('''\n==>ave_asd_OC : {0}'''.format(total_asd_OC))
    print('''\n==>average_asd_OD : {0}'''.format(total_asd_OD))
    with open(osp.join(output_path, '../test' + str(args.datasetTest[0]) +
                       '_log.csv'), 'a') as f:
        elapsed_time = (
            datetime.now(pytz.timezone('Asia/Hong_Kong')) -
            timestamp_start).total_seconds()
        log = [['batch-size: '] + [batch_size] + [args.model_file] +
               ['cup dice coefficence: '] + \
               [val_cup_dice] + ['disc dice coefficence: '] + \
               [val_disc_dice] + ['average_hd_OC: '] + \
               [total_hd_OC] + ['average_hd_OD: '] + \
               [total_hd_OD] + ['ave_asd_OC: '] + \
               [total_asd_OC] + ['average_asd_OD: '] + \
               [total_asd_OD] + [elapsed_time]]
        log = map(str, log)
        f.write(','.join(log) + '\n')
def main():
    """Evaluate a trained boundary-aware DeepLab model on a fundus test set,
    save prediction overlays, and compute Dice plus the MAE of the
    cup-to-disc ratio (CDR) against ground-truth masks.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-file', type=str,
                        default='./logs/refuge_weights.tar', help='Model path')
    parser.add_argument(
        '--dataset', type=str, default='Drishti-GS',
        help='test folder id contain images ROIs to test'
    )
    parser.add_argument('-g', '--gpu', type=int, default=0)
    parser.add_argument(
        '--resize', type=int, default=800, help='image resize')
    parser.add_argument(
        '--data-dir', default='./fundus/', help='data root path'
    )
    parser.add_argument(
        '--mask-dir', required=True,
        default='./fundus/Drishti-GS/test/mask', help='mask image path'
    )
    parser.add_argument(
        '--out-stride', type=int, default=16, help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--save-root-ent', type=str, default='./results/ent/',
        help='path to save ent',
    )
    parser.add_argument(
        '--save-root-mask', type=str, default='./results/mask/',
        help='path to save mask',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=False, help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    parser.add_argument('--test-prediction-save-path', type=str,
                        default='./results/baseline/',
                        help='Path root for test image and mask')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file

    # 1. dataset
    composed_transforms_test = transforms.Compose([
        tr.Scale(args.resize),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir,
                                    dataset=args.dataset, split='test',
                                    transform=composed_transforms_test)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False,
                             num_workers=4, pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2, backbone='mobilenet',
                    output_stride=args.out_stride, sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn).cuda()
    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    checkpoint = torch.load(model_file)
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as exc:
        # Narrowed from a bare `except:` (which would also swallow
        # KeyboardInterrupt/SystemExit); chain the cause so the real
        # load failure remains visible in the traceback.
        raise FileNotFoundError('No checkpoint file exist...') from exc
    model.eval()
    print('==> Evaluating with %s' % args.dataset)

    test_cup_dice = 0.0
    test_disc_dice = 0.0
    timestamp_start = \
        datetime.now(pytz.timezone('Asia/Hong_Kong'))

    with torch.no_grad():
        for batch_idx, (sample) in tqdm.tqdm(enumerate(test_loader),
                                             total=len(test_loader),
                                             ncols=80, leave=False):
            data = sample['image']
            target = sample['map']
            img_name = sample['img_name']
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            # `boundary` is produced but not evaluated in this script.
            prediction, boundary = model(data)
            # Upsample prediction (and input, for visualization) to the
            # label's spatial size.
            prediction = torch.nn.functional.interpolate(
                prediction, size=(target.size()[2], target.size()[3]),
                mode="bilinear")
            data = torch.nn.functional.interpolate(
                data, size=(target.size()[2], target.size()[3]),
                mode="bilinear")
            cup_dice, disc_dice = dice_coeff_2label(prediction, target)
            test_cup_dice += cup_dice
            test_disc_dice += disc_dice

            prediction, ROI_mask = postprocessing(
                torch.sigmoid(prediction).data.cpu()[0], dataset=args.dataset)
            imgs = data.data.cpu()
            target_numpy = target.cpu().numpy()
            for img, lt, lp in zip(imgs, target_numpy, [prediction]):
                img, lt = untransform(img, lt)
                save_per_img(img.numpy().transpose(1, 2, 0),
                             os.path.join(args.test_prediction_save_path,
                                          args.dataset),
                             img_name[0], lp, lt, ROI_mask)

    # Average per-image Dice over the test set (batch_size=1).
    test_cup_dice /= len(test_loader)
    test_disc_dice /= len(test_loader)
    print("test_cup_dice = ", test_cup_dice)
    print("test_disc_dice = ", test_disc_dice)

    # submit script
    _, _, mae_cdr = evaluate_segmentation_results(
        osp.join(args.test_prediction_save_path, args.dataset, 'pred_mask'),
        args.mask_dir, output_path="./", export_table=True)

    with open(osp.join(args.test_prediction_save_path, 'test_log.csv'),
              'a') as f:
        elapsed_time = (
            datetime.now(pytz.timezone('Asia/Hong_Kong')) -
            timestamp_start).total_seconds()
        log = [[args.model_file] + ['cup dice: '] + \
               [test_cup_dice] + ['disc dice: '] + \
               [test_disc_dice] + ['cdr: '] + \
               [mae_cdr] + [elapsed_time]]
        log = map(str, log)
        f.write(','.join(log) + '\n')
def main():
    """Train a boundary-aware DeepLab (MobileNetV2) segmentation model with
    boundary and mask discriminators on a single fundus dataset.

    Parses hyperparameters, fixes all RNG seeds, builds the train/val loaders,
    constructs the generator and the two discriminators with their optimizers,
    optionally resumes from a checkpoint, and hands everything to
    ``Trainer.Trainer``.
    """
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    parser.add_argument(
        '--boundary-exist', type=bool, default=True,
        help='whether or not using boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge',
        help='folder id contain images ROIs to train or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12,
        help='batch size for training the model'
    )
    # parser.add_argument(
    #     '--group-num', type=int, default=1, help='group number for group normalization'
    # )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1,
        help='warmup epoch begin train GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1,
        help='interval epoch number to valide the model'
    )
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir', default='./fundus/', help='data root path'
    )
    parser.add_argument(
        '--out-stride', type=int, default=16, help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=False, help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    args = parser.parse_args()

    args.model = 'MobileNetV2'

    # Timestamped output directory: logs/<dataset>/<YYYYmmdd_HHMMSS.ffffff>.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.dataset,
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before CUDA is queried/initialized below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    # Fix all RNG seeds (torch CPU/GPU, python, numpy) for reproducibility.
    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)
    import random
    import numpy as np
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    # Heavy augmentation for training; only a random crop for validation.
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    data_train = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.dataset, split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.dataset, split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size,
                                shuffle=False, num_workers=4,
                                pin_memory=True)

    # domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train',
    #                                    transform=composed_transforms_ts)
    # domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2,
    #                                pin_memory=True)

    # 2. model
    # Generator (segmentation network) plus two adversarial discriminators.
    model_gen = DeepLab(num_classes=2, backbone='mobilenet',
                        output_stride=args.out_stride, sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_bd = BoundaryDiscriminator().cuda()
    model_mask = MaskDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    # Adam for the generator; SGD with momentum for both discriminators.
    optim_gen = torch.optim.Adam(
        model_gen.parameters(),
        lr=args.lr_gen,
        betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # breakpoint recovery
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        # Same filtered partial-load for the boundary discriminator.
        pretrained_dict = checkpoint['model_bd_state_dict']
        model_dict = model_bd.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model_bd.load_state_dict(model_dict)

        # ...and for the mask discriminator.
        pretrained_dict = checkpoint['model_mask_state_dict']
        model_dict = model_mask.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model_mask.load_state_dict(model_dict)

        # Resume the epoch/iteration counters and all optimizer states.
        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
def main():
    """Train a DeepLab generator with boundary and uncertainty discriminators
    for source-to-target domain adaptation on fundus images.

    Builds source/target/validation loaders, the generator and two
    discriminators with their optimizers, optionally resumes from a
    checkpoint, and launches ``Trainer.Trainer``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS', type=str, default='refuge',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--datasetT', type=str, default='Drishti-GS',
                        help='refuge / Drishti-GS/ RIM-ONE_r3')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num', type=int, default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200,
                        help='max epoch')
    parser.add_argument('--stop-epoch', type=int, default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch', type=int, default=-1,
                        help='warmup epoch begin train GAN')
    parser.add_argument('--interval-validate', type=int, default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.99, help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride', type=int, default=16, help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn', type=bool, default=True, help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn', type=bool, default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    args = parser.parse_args()

    args.model = 'FCN8s'

    # Timestamped output directory keyed by the target dataset.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before CUDA is queried below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    # Heavy augmentation for training; random crop only for validation.
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512), tr.Normalize_tf(), tr.ToTensor()])

    # Source-domain loader (shuffled), target-domain loader, and a
    # target-domain validation loader with the lighter transform.
    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS, split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain, batch_size=args.batch_size,
                                shuffle=True, num_workers=2, pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT, split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T, batch_size=args.batch_size,
                                shuffle=False, num_workers=2, pin_memory=True)
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT, split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size,
                                   shuffle=False, num_workers=2,
                                   pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2, backbone='mobilenet',
                        output_stride=args.out_stride, sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(model_gen.parameters(), lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(), lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(), lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model_gen.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model_gen.load_state_dict(model_dict)

        # Same filtered partial-load for the boundary discriminator.
        pretrained_dict = checkpoint['model_dis_state_dict']
        model_dict = model_dis.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis.load_state_dict(model_dict)

        # ...and for the uncertainty discriminator.
        pretrained_dict = checkpoint['model_dis2_state_dict']
        model_dict = model_dis2.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_dis2.load_state_dict(model_dict)

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # BUG FIX: removed `optim_adv.load_state_dict(
        # checkpoint['optim_adv_state_dict'])` — `optim_adv` was never
        # defined, so every resume crashed with NameError.

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
def main():
    """Train a domain-generalization DeepLab model (with per-domain feature
    centroids) on several source fundus domains and validate on a held-out
    target domain.

    Optionally resumes from a checkpoint and re-initializes the model's
    centroid memory from the training data via ``centroids_init``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    # BUG FIX: nargs='+' defaults must be lists — with the old `default=1`,
    # running with defaults crashed at `args.datasetTest[0]` (int is not
    # subscriptable).
    parser.add_argument(
        '--datasetTrain', nargs='+', type=int, default=[1],
        help='train folder id contain images ROIs to train range from [1,2,3,4]'
    )
    parser.add_argument(
        '--datasetTest', nargs='+', type=int, default=[1],
        help='test folder id contain images ROIs to test one of [1,2,3,4]')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num', type=int, default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=120,
                        help='max epoch')
    parser.add_argument('--stop-epoch', type=int, default=80,
                        help='stop epoch')
    parser.add_argument('--interval-validate', type=int, default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument('--lr-decrease-rate', type=float, default=0.2,
                        help='ratio multiplied to initial lr')
    parser.add_argument(
        '--lam', type=float, default=0.9, help='momentum of memory update',
    )
    parser.add_argument('--data-dir', default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride', type=int, default=16, help='out-stride of deeplabv3+',
    )
    args = parser.parse_args()

    # Output directory: logs/test<target-id>/lam<lam>/<timestamp>.
    now = datetime.now()
    args.out = osp.join(local_path, 'logs', 'test' + str(args.datasetTest[0]),
                        'lam' + str(args.lam),
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before CUDA is queried below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(256),
        # tr.RandomCrop(512),
        # tr.RandomRotate(),
        # tr.RandomFlip(),
        # tr.elastic_transform(),
        # tr.add_salt_pepper_noise(),
        # tr.adjust_light(),
        # tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(256), tr.Normalize_tf(), tr.ToTensor()])

    # Train on the union of the source domains; validate on the target domain.
    domain = DL.FundusSegmentation(base_dir=args.data_dir, phase='train',
                                   splitid=args.datasetTrain,
                                   transform=composed_transforms_tr)
    train_loader = DataLoader(domain, batch_size=args.batch_size,
                              shuffle=True, num_workers=2, pin_memory=True)

    domain_val = DL.FundusSegmentation(base_dir=args.data_dir, phase='test',
                                       splitid=args.datasetTest,
                                       transform=composed_transforms_ts)
    val_loader = DataLoader(domain_val, batch_size=args.batch_size,
                            shuffle=False, num_workers=2, pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2, num_domain=3, backbone='mobilenet',
                    output_stride=args.out_stride, lam=args.lam).cuda()
    print('parameter numer:', sum([p.numel() for p in model.parameters()]))

    # load weights
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        # Re-initialize the centroid memory from the resumed weights.
        print('Before ', model.centroids.data)
        model.centroids.data = centroids_init(model, args.data_dir,
                                              args.datasetTrain,
                                              composed_transforms_ts)
        # BUG FIX: this second diagnostic previously printed 'Before ' again,
        # although it reports the state after centroids_init.
        print('After ', model.centroids.data)
    # model.freeze_para()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr,
                             betas=(0.9, 0.99))

    trainer = Trainer.Trainer(
        cuda=cuda,
        model=model,
        lr=args.lr,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optim,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()