def create_model(ema=False):
    """Build the VNet segmentation network and move it to the GPU.

    Args:
        ema: when True, every parameter is detached in place (``detach_``)
            so autograd does not track it. Presumably the caller updates
            these weights manually, e.g. as an EMA/teacher copy; confirm at
            the call site.

    Returns:
        The CUDA-resident ``VNet`` (1 input channel, batch norm, dropout
        enabled). ``n_classes`` is taken from the module-level
        ``num_classes``.
    """
    model = VNet(
        n_channels=1,
        n_classes=num_classes,
        normalization='batchnorm',
        has_dropout=True,
    ).cuda()
    if not ema:
        return model
    for weight in model.parameters():
        weight.detach_()
    return model
def main():
    """Evaluate a saved ABUS segmentation checkpoint on the test split.

    Loads the checkpoint at ``args.resume`` into the architecture selected by
    ``args.arch``. It then (re)creates an output directory for the
    predictions and runs ``test_all_case`` over the test loader with
    (64, 128, 128) patches and stride 64.

    Raises:
        NotImplementedError: if ``args.arch`` is neither 'vnet' nor 'd2unet'.
        FileNotFoundError: if ``args.resume`` is empty or is not a file.
    """
    args = get_args()

    # dataset
    db_test = ABUS(base_dir=args.root_path, split='test')
    testloader = DataLoader(db_test, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)
    args.testloader = testloader

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
                     has_dropout=True, use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implement'.format(args.arch))
    model = model.cuda()

    # Evaluation needs trained weights, and the output path below is derived
    # from the checkpoint. Fail early with a clear message instead of a
    # NameError on `checkpoint` / `save_path` further down.
    if not args.resume or not os.path.isfile(args.resume):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(args.resume))
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    # Inference mode: disables dropout (the VNet is built with
    # has_dropout=True) and makes batch norm use its running statistics.
    model.eval()

    # --- saving path ---
    if 'best' in args.resume:
        file_name = 'model_best_' + str(checkpoint['epoch'])
    elif 'check' in args.resume:
        file_name = 'checkpoint_{}_result'.format(checkpoint['epoch'])
    else:
        # Checkpoint name contains neither "best" nor "check": derive the
        # result folder from the file name instead of leaving it unbound.
        stem = os.path.splitext(os.path.basename(args.resume))[0]
        file_name = '{}_{}_result'.format(stem, checkpoint['epoch'])

    if args.save is not None:
        save_path = os.path.join(args.save, file_name)
    else:
        save_path = os.path.join(os.path.dirname(args.resume), file_name)
    # Start from an empty result directory so stale predictions never mix in.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)

    test_all_case(model, args.testloader, num_classes=args.num_classes,
                  patch_size=(64, 128, 128), stride_xy=64, stride_z=64,
                  save_result=True, test_save_path=save_path)
def test_calculate_metric(epoch_num):
    """Load the ``iter_<epoch_num>.pth`` snapshot and score it on the test images.

    Uses the module-level ``snapshot_path``, ``num_classes``, ``image_list``
    and ``test_save_path``.

    Args:
        epoch_num: iteration number embedded in the checkpoint file name.

    Returns:
        Whatever ``test_all_case`` returns for the evaluated network.
    """
    checkpoint_file = os.path.join(snapshot_path, 'iter_%s.pth' % (epoch_num,))

    model = VNet(
        n_channels=1,
        n_classes=num_classes,
        normalization='batchnorm',
        has_dropout=False,
    ).cuda()
    model.load_state_dict(torch.load(checkpoint_file))
    print("init weight from {}".format(checkpoint_file))
    model.eval()

    return test_all_case(
        model,
        image_list,
        num_classes=num_classes,
        patch_size=(112, 112, 80),
        stride_xy=18,
        stride_z=4,
        save_result=True,
        test_save_path=test_save_path,
    )
shutil.rmtree(snapshot_path + '/code') shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git', '__pycache__'])) logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging.info(str(args)) net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False) net = net.cuda() db_train = LiverTumor(base_dir=train_data_path, split='train', transform=transforms.Compose([ RandomRotFlip(), RandomCrop(patch_size), ToTensor(), ])) def worker_init_fn(worker_id): random.seed(args.seed + worker_id) trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
def test_calculate_metric(iter_nums):
    """Evaluate (or visualize / robustness-test) the configured 3D network.

    Builds the network selected by ``args.net``. For every iteration number
    in ``iter_nums`` it loads ``iter_<n>.pth`` from ``args.checkpoint_dir``
    (when set) and, unless a visualization or robustness mode is requested,
    runs ``test_all_cases`` on ``db_test``. It then prints per-class
    Dice / Jaccard / HD / ASD and tars the saved predictions.

    Relies on module-level names: ``args``, ``db_test``, ``testloader``,
    ``timestamp``, ``has_mask``, ``default_settings``, ``config25d``,
    ``config3d``.

    Args:
        iter_nums: iterable of checkpoint iteration numbers to evaluate.

    Returns:
        The per-class average metrics of the last evaluated checkpoint, or
        ``None`` when nothing was evaluated. That happens with an empty
        ``iter_nums``, when every iteration took the visualization or
        robustness path, or when the early-return paths are used.

    Raises:
        NotImplementedError: if ``args.net`` is not 'vnet', 'unet' or
            'segtran'.
    """
    # NOTE(review): VNet/UNet3D are called with ``num_classes=`` here, while
    # the other VNet/UNet3D call sites in this file use ``n_classes=``.
    # Confirm which keyword these constructors accept.
    if args.net == 'vnet':
        net = VNet(n_channels=1, num_classes=args.num_classes,
                   normalization='batchnorm', has_dropout=False)
    elif args.net == 'unet':
        net = UNet3D(in_channels=1, num_classes=args.num_classes)
    elif args.net == 'segtran':
        get_default(args, 'num_modes', default_settings, -1,
                    [args.net, 'num_modes', args.in_fpn_layers])
        if args.segtran_type == '25d':
            set_segtran25d_config(args)
            net = Segtran25d(config25d)
        else:
            set_segtran3d_config(args)
            net = Segtran3d(config3d)
    else:
        # Any other value would leave `net` unbound at net.cuda() below.
        raise NotImplementedError('model {} not implement'.format(args.net))

    net.cuda()
    net.eval()
    preproc_fn = None

    # Without a checkpoint directory, visualization / robustness evaluation
    # run once on the freshly built network and the function returns.
    if not args.checkpoint_dir:
        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode)
            return

        if args.eval_robustness:
            eval_robustness(net, testloader, args.aug_degree)
            return

    # Stays None if no checkpoint is evaluated by test_all_cases below.
    allcls_avg_metric = None
    for iter_num in iter_nums:
        if args.checkpoint_dir:
            checkpoint_path = os.path.join(args.checkpoint_dir,
                                           'iter_' + str(iter_num) + '.pth')
            load_model(net, args, checkpoint_path)

        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode)
            continue

        if args.eval_robustness:
            eval_robustness(net, testloader, args.aug_degree)
            continue

        save_result = not args.test_interp
        if save_result:
            test_save_dir = "%s-%s-%s-%d" % (args.net, args.job_name,
                                             timestamp, iter_num)
            test_save_path = "../prediction/%s" % (test_save_dir)
            os.makedirs(test_save_path, exist_ok=True)
            test_save_dirs = [test_save_dir]
            test_save_paths = [test_save_path]
        else:
            test_save_paths = [None]
            test_save_dirs = [None]

        # No need to use dataloader to pass data,
        # as one 3D image is split into many patches to do segmentation.
        allcls_avg_metric = test_all_cases(
            net, db_test, task_name=args.task_name, net_type=args.net,
            num_classes=args.num_classes, batch_size=args.batch_size,
            orig_patch_size=args.orig_patch_size,
            input_patch_size=args.input_patch_size,
            stride_xy=args.orig_patch_size[0] // 2,
            stride_z=args.orig_patch_size[2] // 2,
            save_result=save_result, test_save_path=test_save_paths[0],
            preproc_fn=preproc_fn, test_interp=args.test_interp,
            has_mask=has_mask)

        print("%d scores:" % iter_num)
        # Class 0 is skipped; allcls_avg_metric[cls - 1] holds class `cls`.
        for cls in range(1, args.num_classes):
            dice, jc, hd, asd = allcls_avg_metric[cls - 1]
            print('%d: dice: %.3f, jc: %.3f, hd: %.3f, asd: %.3f' %
                  (cls, dice, jc, hd, asd))

        if save_result:
            # Currently only save hard predictions.
            for pred_type, test_save_dir, test_save_path in zip(
                    ('hard', ), test_save_dirs, test_save_paths):
                # subprocess.DEVNULL discards tar's verbose listing without
                # opening (and leaking) a file handle per checkpoint.
                subprocess.run(
                    ["tar", "cvf", "%s.tar" % test_save_dir, test_save_dir],
                    cwd="../prediction",
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.STDOUT)
                print("{} tarball:\n{}.tar".format(
                    pred_type, os.path.abspath(test_save_path)))

    return allcls_avg_metric
def main():
    """Train a VNet on ABUS, blending a segmentation loss with a distance-map boundary loss.

    Reads settings from ``get_args()``, wipes and recreates ``args.save`` for
    logs and checkpoints, and trains for about ``args.max_iterations`` steps.

    Per step:
      * with ``args.use_tm``: ``loss = alpha * (0.1 * CE + 0.9 * Dice + 3 * threshold_loss)
        + (1 - alpha) * boundary_loss``;
      * otherwise: ``loss = alpha * Dice + (1 - alpha) * boundary_loss``.

    ``alpha`` starts at 1.0, drops by 0.005 after every epoch and is floored
    at 0.01. The learning rate is multiplied by 0.1 every 5000 iterations. A
    checkpoint ``iter_<n>.pth`` is written every 1000 iterations, plus a
    final one.
    """
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer: TensorBoard run dir = writer_dir + last path component of args.save
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr
    patch_size = (64, 128, 128)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file (the snapshot directory is wiped on every run)
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True, use_tm=args.use_tm)
    net = net.cuda()

    db_train = ABUS(base_dir=args.root_path,
                    split='train',
                    use_dismap=args.use_dismap,
                    transform = transforms.Compose([RandomRotFlip(use_dismap=args.use_dismap),
                                                    RandomCrop(patch_size, use_dismap=args.use_dismap),
                                                    ToTensor(use_dismap=args.use_dismap)]))

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic Python RNG seed.
        random.seed(args.seed+worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    net.train()
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0  # weight of the segmentation term; (1 - alpha) weights the boundary term
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # NOTE(review): 'dis_map' is read unconditionally and loss_boundary
            # always enters `loss`, so this loop appears to require the dataset
            # to provide a distance map. Confirm ABUS returns 'dis_map' when
            # args.use_dismap is False.
            volume_batch, label_batch, dis_map_batch = sampled_batch['image'], sampled_batch['label'], sampled_batch['dis_map']
            volume_batch, label_batch, dis_map_batch = volume_batch.cuda(), label_batch.cuda(), dis_map_batch.cuda()
            if args.use_tm:
                # With use_tm the network also returns a threshold map (squashed to [0, 1]).
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)
            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            # Dice on the foreground (class 1) probability channel.
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            # Boundary loss against the dataset-supplied distance map.
            loss_boundary = boundary_loss(outputs_soft, dis_map_batch)
            if args.use_tm:
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :], tm[:, 0, ...], label_batch == 1)
                loss_th = (0.1 * loss_seg + 0.9 * loss_seg_dice) + 3 * loss_threshold
                loss = alpha*(loss_th) + (1 - alpha) * loss_boundary
            else:
                loss = alpha * loss_seg_dice + (1-alpha) * loss_boundary
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Hard prediction (argmax over classes) for the logged training Dice.
            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)
            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/alpha', alpha, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)
            if args.use_dismap:
                writer.add_scalar('train/loss_dis', loss_boundary, iter_num)
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))

            if iter_num % 50 == 0:
                # Tile slices 30:71:10 along tensor axis 3 of the first sample in the batch.
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2,0,1,3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, 5)
                writer.add_image('train/Image', grid_image, iter_num)

                # Figure: ground truth (top), foreground probability (middle), distance map (bottom).
                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2,0,1,3)
                grid_image = make_grid(image, 5, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1,2,0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2,0,1,3)
                grid_gt = make_grid(gt, 5, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1,2,0))

                image_tm = dis_map_batch[0, :, :, 30:71:10, :].permute(2,0,1,3)
                grid_tm = make_grid(image_tm, 5, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1,2,0))

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr: step decay, lr = base_lr * 0.1 ** (iter_num // 5000)
            if iter_num % 5000 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 5000)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        # Per-epoch schedule: shift weight from the segmentation term to the boundary term.
        alpha -= 0.005
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
def main():
    """Epoch-based ABUS training with per-epoch validation and checkpointing.

    Reads settings from ``get_args()`` and wipes and recreates ``args.save``.
    It trains the architecture chosen by ``args.arch`` with Adam and the
    ``LR_Scheduler`` schedule, using ``train``/``val`` for one epoch each.
    After every epoch it calls ``save_checkpoint``, with ``is_best`` set
    whenever the value returned by ``val`` exceeds the best seen so far.

    Raises:
        NotImplementedError: if ``args.arch`` is neither 'vnet' nor 'd2unet'.
    """
    ###################
    # init parameters #
    ###################
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # writer: TensorBoard run dir = writer_dir + last path component of args.save
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file (the snapshot directory is wiped on every run)
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
                     has_dropout=True, use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implement'.format(args.arch))
    model = model.cuda()

    # dataset
    patch_size = (64, 128, 128)
    batch_size = args.ngpu * args.batch_size

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic Python RNG seed.
        random.seed(args.seed + worker_id)

    # Train on the *train* split (augmented). This previously read
    # split='val', which trained on the same data used for model selection.
    db_train = ABUS(base_dir=args.root_path,
                    split='train',
                    fold=args.fold,
                    transform=transforms.Compose(
                        [RandomRotFlip(), RandomCrop(patch_size), ToTensor()]))
    db_val = ABUS(base_dir=args.root_path,
                  split='val',
                  fold=args.fold,
                  transform=transforms.Compose(
                      [CenterCrop(patch_size), ToTensor()]))
    train_loader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(db_val, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True,
                            worker_init_fn=worker_init_fn)

    # optimizer
    lr = args.lr
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=args.weight_decay)
    lr_scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epoch,
                                len(train_loader))

    # training
    logging.info('--- start training ---')
    best_pre = 0.
    for epoch in range(args.start_epoch, args.epoch + 1):
        train(args, epoch, model, train_loader, optimizer, writer, lr_scheduler)
        dice = val(args, epoch, model, val_loader, writer)
        is_best = False
        if dice > best_pre:
            is_best = True
            best_pre = dice
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_pre': best_pre
            }, is_best, args.save, args.arch)
    writer.close()
def main():
    """Iteration-based ABUS training (VNet or D2UNet) with Dice (+ optional threshold) loss.

    Reads settings from ``get_args()``, wipes and recreates ``args.save`` for
    logs and checkpoints, and trains for about ``args.max_iterations`` steps
    with SGD.

    The loss is foreground Dice, plus ``3 * threshold_loss`` when
    ``args.use_tm`` is set. The learning rate is multiplied by 0.1 every
    2500 iterations. Checkpoints ``iter_<n>.pth`` are written every 1000
    iterations once past iteration 5000, plus a final one.

    Raises:
        NotImplementedError: if ``args.arch`` is neither 'vnet' nor 'd2unet'.
    """
    ###################
    # init parameters #
    ###################
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # writer: TensorBoard run dir = writer_dir + last path component of args.save
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr
    patch_size = (64, 128, 128)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file (the snapshot directory is wiped on every run)
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # network
    if args.arch == 'vnet':
        net = VNet(n_channels=1, n_classes=num_classes,
                   normalization='batchnorm', has_dropout=True,
                   use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        net = D2UNet()
    else:
        raise NotImplementedError('model {} not implement'.format(args.arch))
    net = net.cuda()

    # dataset
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic Python RNG seed.
        random.seed(args.seed + worker_id)

    db_train = ABUS(base_dir=args.root_path,
                    split='train',
                    fold=args.fold,
                    transform=transforms.Compose(
                        [RandomRotFlip(), RandomCrop(patch_size), ToTensor()]))
    # NOTE(review): the validation split is constructed but not used by
    # this loop.
    db_val = ABUS(base_dir=args.root_path,
                  split='val',
                  fold=args.fold,
                  transform=transforms.Compose(
                      [CenterCrop(patch_size), ToTensor()]))
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
                             num_workers=4, pin_memory=True,
                             worker_init_fn=worker_init_fn)

    # optimizer
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9,
                          weight_decay=0.0001)
    logging.info("{} itertations per epoch".format(len(trainloader)))

    # training
    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    net.train()
    for _ in tqdm(range(max_epoch), ncols=70):
        for sampled_batch in trainloader:
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            if args.use_tm:
                # With use_tm the network also returns a threshold map (squashed to [0, 1]).
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)

            # Cross-entropy is only logged; the optimized loss is Dice (+ threshold).
            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            if args.use_tm:
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :],
                                                tm[:, 0, ...], label_batch == 1)
                loss = loss_seg_dice + 3 * loss_threshold
            else:
                loss = loss_seg_dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # visualization on tensorboard
            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)
            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))

            if iter_num % 50 == 0:
                # Tile slices 30:71:10 along tensor axis 3 of the first sample.
                nrow = 5
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2, 0, 1, 3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, nrow=nrow)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2, 0, 1, 3)
                grid_image = make_grid(image, nrow=nrow, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1, 2, 0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2, 0, 1, 3)
                grid_gt = make_grid(gt, nrow=nrow, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1, 2, 0))

                # Bottom panel: threshold map when available, else the label again.
                if args.use_tm:
                    image_tm = tm[0, :, :, 30:71:10, :].permute(2, 0, 1, 3)
                else:
                    image_tm = gt
                grid_tm = make_grid(image_tm, nrow=nrow, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1, 2, 0))

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr: step decay, lr = base_lr * 0.1 ** (iter_num // 2500)
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0 and iter_num > 5000:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
        if iter_num > max_iterations:
            break

    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(max_iterations + 1) + '.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
n_classes=args.num_classes, normalization='batchnorm', has_dropout=True) elif args.net == 'unet': net = UNet3D(in_channels=1, n_classes=args.num_classes) elif args.net == 'segtran': if args.segtran_type == '3d': set_segtran3d_config(args) net = Segtran3d(config3d) else: set_segtran25d_config(args) net = Segtran25d(config25d) else: breakpoint() net.cuda() if args.opt == 'sgd': optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) elif args.opt == 'adam': optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001) elif args.opt == 'adamw': optimizer = init_optimizer(net, max_epoch, len(trainloader)) if args.checkpoint_path is not None: iter_num = load_model(net, optimizer, args, args.checkpoint_path) start_epoch = math.ceil(iter_num / len(trainloader))
def main():
    """Fully supervised VNet baseline on LA heart data (0.5 * (CE + Dice), step LR decay).

    Reads settings from ``get_args()``, wipes and recreates ``args.save``,
    and trains for about ``args.max_iterations`` steps. It logs scalars and
    sampled slices to TensorBoard and writes ``iter_<n>.pth`` every 1000
    iterations plus a final checkpoint.
    """
    args = get_args()
    data_root = args.root_path

    # TensorBoard run dir = writer_dir + last path component of args.save.
    writer = SummaryWriter(args.writer_dir + args.save[args.save.rfind('/'):])

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr
    patch_size = (112, 112, 80)
    num_classes = 2

    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    # Every run starts from an empty snapshot directory.
    snapshot_path = args.save
    if os.path.exists(snapshot_path):
        shutil.rmtree(snapshot_path)
    os.makedirs(snapshot_path, exist_ok=True)

    logging.basicConfig(
        filename=snapshot_path + "/log.txt",
        level=logging.INFO,
        format='[%(asctime)s.%(msecs)03d] %(message)s',
        datefmt='%H:%M:%S',
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    model = VNet(
        n_channels=1,
        n_classes=num_classes,
        normalization='batchnorm',
        has_dropout=True,
    ).cuda()

    train_set = LAHeart(
        base_dir=data_root,
        split='train',
        num=16,
        transform=transforms.Compose([
            RandomRotFlip(),
            RandomCrop(patch_size),
            ToTensor(),
        ]),
    )
    # The test split is loaded here but not used by this training loop.
    db_test = LAHeart(
        base_dir=data_root,
        split='test',
        transform=transforms.Compose([
            CenterCrop(patch_size),
            ToTensor(),
        ]),
    )

    def seed_worker(worker_id):
        # Deterministic, distinct Python RNG seed per DataLoader worker.
        random.seed(args.seed + worker_id)

    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                        num_workers=4, pin_memory=True,
                        worker_init_fn=seed_worker)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9,
                          weight_decay=0.0001)
    logging.info("{} itertations per epoch".format(len(loader)))

    def log_slices(tag, volume, normalize, step):
        """Log indices 20:61:10 of the last axis of a (1, X, Y, Z) tensor as a 3-channel grid."""
        tiles = volume[..., 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
        writer.add_image(tag, make_grid(tiles, 5, normalize=normalize), step)

    step = 0
    lr_ = base_lr
    epochs = max_iterations // len(loader) + 1
    for _ in tqdm(range(epochs), ncols=70):
        for batch in loader:
            images = batch['image'].cuda()
            labels = batch['label'].cuda()

            logits = model(images)
            loss_ce = F.cross_entropy(logits, labels)
            probs = F.softmax(logits, dim=1)
            # Dice on the foreground (class 1) probability channel.
            loss_dice = dice_loss(probs[:, 1, :, :, :], labels == 1)
            loss = 0.5 * (loss_ce + loss_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            writer.add_scalar('lr', lr_, step)
            writer.add_scalar('loss/loss_seg', loss_ce, step)
            writer.add_scalar('loss/loss_seg_dice', loss_dice, step)
            writer.add_scalar('loss/loss', loss, step)
            logging.info('iteration %d : loss : %f' % (step, loss.item()))

            if step % 50 == 0:
                log_slices('train/Image', images[0, 0:1], True, step)
                log_slices('train/Predicted_label', probs[0, 1:2], False, step)
                log_slices('train/Groundtruth_label', labels[0].unsqueeze(0), False, step)

            # Step decay: lr = base_lr * 0.1 ** (step // 2500).
            if step % 2500 == 0:
                lr_ = base_lr * 0.1 ** (step // 2500)
                for group in optimizer.param_groups:
                    group['lr'] = lr_

            if step % 1000 == 0:
                ckpt = os.path.join(snapshot_path, 'iter_' + str(step) + '.pth')
                torch.save(model.state_dict(), ckpt)
                logging.info("save model to {}".format(ckpt))

            if step > max_iterations:
                break
        if step > max_iterations:
            break

    ckpt = os.path.join(snapshot_path, 'iter_' + str(max_iterations + 1) + '.pth')
    torch.save(model.state_dict(), ckpt)
    logging.info("save model to {}".format(ckpt))
    writer.close()
def main():
    """Train a VNet on LA heart data with CE + Dice blended with a signed-distance boundary loss.

    Reads settings from ``get_args()``, wipes and recreates ``args.save``,
    and trains for about ``args.max_iterations`` steps with
    ``loss = alpha * (CE + Dice) + (1 - alpha) * boundary_loss``.

    The ground-truth signed distance map is computed per batch with
    ``compute_sdf`` (no gradient). ``alpha`` starts at 1.0, drops by 0.01
    after every epoch and is floored at 0.01. The learning rate is multiplied
    by 0.1 every 2500 iterations. Checkpoints are written every 1000
    iterations, plus a final one.
    """
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer: TensorBoard run dir = writer_dir + last path component of args.save
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr
    patch_size = (112, 112, 80)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file (the snapshot directory is wiped on every run)
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # training set (num=16 training cases passed to LAHeart)
    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       num=16,
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor()
                       ]))

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic Python RNG seed.
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=True)
    net = net.cuda()
    net.train()
    optimizer = optim.SGD(net.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0  # weight of CE + Dice; (1 - alpha) weights the boundary loss
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # volume_batch.shape=(b,1,x,y,z) label_patch.shape=(b,x,y,z)
            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            # Dice on the foreground (class 1) probability channel.
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :],
                                      label_batch == 1)
            # compute gt_signed distance function and boundary loss
            with torch.no_grad():
                # default using compute_sdf; however, compute_sdf1_1 is also worth trying.
                # The SDF is built on CPU from the labels; it is given outputs_soft.shape,
                # so presumably it matches the prediction's shape (TODO confirm).
                gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(),
                                         outputs_soft.shape)
                gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(
                    outputs_soft.device.index)
            # Outside no_grad so the boundary loss back-propagates into outputs_soft.
            loss_boundary = boundary_loss(outputs_soft, gt_sdf)
            loss = alpha * (loss_seg + loss_seg_dice) + (1 - alpha) * loss_boundary

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss_boundary', loss_boundary, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/alpha', alpha, iter_num)
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))
            logging.info('iteration %d : loss_seg_dice : %f' %
                         (iter_num, loss_seg_dice.item()))
            logging.info('iteration %d : loss_boundary : %f' %
                         (iter_num, loss_boundary.item()))
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))

            # NOTE(review): images are logged every 2 iterations (the sibling
            # supervised script uses 50). This looks like a debugging
            # leftover and writes 4 image grids per 2 steps to the event file.
            if iter_num % 2 == 0:
                # Tile last-axis indices 20:61:10 of the first sample as 3-channel images.
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image,
                                 iter_num)

                # Channel 1 of the ground-truth SDF.
                image = gt_sdf[0, 1:2, :, :, 20:61:10].permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/gt_sdf', grid_image, iter_num)

            ## change lr: step decay, lr = base_lr * 0.1 ** (iter_num // 2500)
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1**(iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        # Per-epoch schedule: shift weight from CE + Dice to the boundary loss.
        alpha -= 0.01
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(max_iterations + 1) + '.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
def experiment(exp_identifier, max_epoch, training_data, testing_data,
               batch_size=2, supervised_only=False, K=2, T=0.5, alpha=1,
               mixup_mode='all', Lambda=1, Lambda_ramp=None, base_lr=0.01,
               change_lr=None, aug_factor=1, from_saved=None,
               always_do_validation=True, decay=0):
    '''
    Run one semi-supervised VNet training experiment and save the final weights.

    Every epoch draws a permutation of LABELLED_INDEX and an equally sized
    subset of UNLABELLED_INDEX from `training_data`. It applies random
    crop + rotation/flip, trains one epoch with `train_epoch` and optionally
    validates on the center-cropped `testing_data`. The final `net`
    weights are saved to ../saved/<exp_identifier>.pth.

    max_epoch: epochs to run. Going through labeled data once is one epoch.
    batch_size: batch size of labeled data. Unlabeled data is of the same size.
    training_data: data for train_epoch, list of dicts of numpy array.
    testing_data: data for validation, list of dicts of numpy array.
    supervised_only: if True, only do supervised training on LABELLED_INDEX;
        otherwise, use both LABELLED_INDEX and UNLABELLED_INDEX
    from_saved: optional path of a state_dict loaded into `net` (not eval_net).
    always_do_validation: if False, validate only every 50th epoch (epoch % 50 == 0).
    decay: forwarded to train_epoch (presumably the eval_net update rate; confirm there).

    Hyperparameters
    ---------------
    K: repeats of each unlabelled data
    T: temperature of sharpening
    alpha: mixup hyperparameter of beta distribution
    mixup_mode: how mixup is performed --
        '__': no mix up
        'ww': x and u both mixed up with w(x+u)
        'xx': both with x
        'xu': x with x, u with u
        'uu': both with u
        ... _ means no, x means with x, u means with u, w means with w(x+u)
        NOTE: the default 'all' is not in the list above; its meaning is
        defined by train_epoch.
    Lambda: loss = loss_x + Lambda * loss_u, relative weight for unsupervised loss
    base_lr: initial learning rate
    Lambda_ramp: callable or None. Lambda is ignored if this is not None.
        In this case, Lambda = Lambda_ramp(epoch).
    change_lr: dict, {epoch: change_multiplier}

    Returns
    -------
    (training_losses, testing_losses, testing_accuracy): per-epoch training
    losses, plus the dice losses / accuracies of the epochs that were validated.
    '''
    import os

    print(
        f"Experiment {exp_identifier}: max_epoch = {max_epoch}, batch_size = {batch_size}, supervised_only = {supervised_only},"
        f"K = {K}, T = {T}, alpha = {alpha}, mixup_mode = {mixup_mode}, Lambda = {Lambda}, Lambda_ramp = {Lambda_ramp}, base_lr = {base_lr}, aug_factor = {aug_factor}."
    )

    net = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
               has_dropout=True)
    eval_net = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
                    has_dropout=True)

    if from_saved is not None:
        net.load_state_dict(torch.load(from_saved))

    if GPU:
        net = net.cuda()
        eval_net.cuda()

    # eval_net receives no gradients: its parameters are detached in place.
    # Any update to it has to happen inside train_epoch (presumably via
    # `decay`; confirm there).
    for param in eval_net.parameters():
        param.detach_()

    net.train()
    eval_net.train()

    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9,
                          weight_decay=0.0001)
    x_criterion = soft_cross_entropy  # supervised loss is 0.5*(x_criterion + dice_loss)
    u_criterion = nn.MSELoss()  # unsupervised loss

    training_losses = []
    testing_losses = []
    testing_accuracy = []  # dice accuracy

    patch_size = (112, 112, 80)

    testing_data = [
        shape_transform(CenterCrop(patch_size)(sample))
        for sample in testing_data
    ]
    t0 = time.time()

    lr = base_lr

    for epoch in range(max_epoch):
        labelled_index = np.random.permutation(LABELLED_INDEX)
        # Draw as many unlabelled samples as there are labelled ones.
        unlabelled_index = np.random.permutation(
            UNLABELLED_INDEX)[:len(labelled_index)]
        labelled_data = [training_data[i] for i in labelled_index]
        unlabelled_data = [training_data[i] for i in unlabelled_index]

        ## data transformation: random_crop, then rotation / flip
        labelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in labelled_data
        ]
        unlabelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in unlabelled_data
        ]

        if Lambda_ramp is not None:
            Lambda = Lambda_ramp(epoch)
            print(f"Lambda ramp: Lambda = {Lambda}")

        if change_lr is not None and epoch in change_lr:
            lr_ = lr * change_lr[epoch]
            print(f"Learning rate decay at epoch {epoch}, from {lr} to {lr_}")
            lr = lr_
            # change learning rate.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

        training_loss = train_epoch(net=net,
                                    eval_net=eval_net,
                                    labelled_data=labelled_data,
                                    unlabelled_data=unlabelled_data,
                                    batch_size=batch_size,
                                    supervised_only=supervised_only,
                                    optimizer=optimizer,
                                    x_criterion=x_criterion,
                                    u_criterion=u_criterion,
                                    K=K,
                                    T=T,
                                    alpha=alpha,
                                    mixup_mode=mixup_mode,
                                    Lambda=Lambda,
                                    aug_factor=aug_factor,
                                    decay=decay)

        training_losses.append(training_loss)

        if always_do_validation or epoch % 50 == 0:
            testing_dice_loss, accuracy = validation(net=net,
                                                     testing_data=testing_data,
                                                     x_criterion=x_criterion)
            testing_losses.append(testing_dice_loss)
            testing_accuracy.append(accuracy)
            print(
                f"Epoch {epoch+1}/{max_epoch}, time used: {time.time()-t0:.2f},  training loss: {training_loss:.6f}, testing dice_loss: {testing_dice_loss:.6f}, testing accuracy: {100.0*accuracy:.2f}% "
            )
        else:
            # Not validated this epoch: do not report the previous epoch's
            # test metrics as if they were current.
            print(
                f"Epoch {epoch+1}/{max_epoch}, time used: {time.time()-t0:.2f},  training loss: {training_loss:.6f}"
            )

    save_path = f"../saved/{exp_identifier}.pth"
    # torch.save does not create missing directories; make sure the target
    # exists so a finished run is not lost to FileNotFoundError.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(net.state_dict(), save_path)
    print(f"Experiment {exp_identifier} finished. Model saved as {save_path}")
    return training_losses, testing_losses, testing_accuracy