import torch
import torch.nn as nn
from torch.distributions import kl

# Unet, AxisAlignedConvGaussian, Fcomb and the global `device` are assumed to
# be defined elsewhere in this repository.


class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: a list containing the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoder of prior and posterior
    """

    def __init__(
        self,
        input_channels=1,
        num_classes=1,
        num_filters=[32, 64, 128, 192],
        latent_dim=6,
        no_convs_fcomb=4,
        beta=10.0,
    ):
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels  # number of input image channels
        self.num_classes = num_classes  # number of segmentation classes
        self.num_filters = num_filters  # number of filters per layer
        self.latent_dim = latent_dim  # dimension of the latent space
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {"w": "he_normal", "b": "normal"}  # weight/bias init
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(
            self.input_channels,
            self.num_classes,
            self.num_filters,
            self.initializers,
            apply_last_layer=False,
            padding=True,
        ).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
        ).to(device)
        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            self.initializers,
            posterior=True,
        ).to(device)
        self.fcomb = Fcomb(
            self.num_filters,
            self.latent_dim,
            self.input_channels,
            self.num_classes,
            self.no_convs_fcomb,
            {"w": "orthogonal", "b": "normal"},
            use_tile=True,
        ).to(device)

    def forward(self, patch, segm, training=True):
        """
        Construct the prior latent space for the patch and run the patch through
        the UNet; if training is True, also construct the posterior latent space.
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample
        and combining this with the UNet features.
        """
        if not testing:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose between a sample and the mean here. For the GED it
            # is important to take a sample.
            # z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior
        sample) and the UNet feature map.

        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: sample from the posterior latent space instead of
        using a provided sample
        """
        if use_posterior_mean:
            z_posterior = self.posterior_latent_space.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence KL(Q||P) between the posterior and the prior.

        analytic: calculate the KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL, we can
        sample here or supply a sample
        """
        if analytic:
            # Need to add this to the torch source code, see:
            # https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space,
                                      self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).
        """
        # size_average/reduce are deprecated (and reduction=None is invalid);
        # reduction="none" keeps per-element losses for the sum/mean below.
        criterion = nn.BCEWithLogitsLoss(reduction="none")
        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(
            self.kl_divergence(analytic=analytic_kl,
                               calculate_posterior=False,
                               z_posterior=z_posterior))

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(
            use_posterior_mean=reconstruct_posterior_mean,
            calculate_posterior=False,
            z_posterior=z_posterior,
        )

        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)
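# Illustrative sketch (not part of the original file): one training step for
# ProbabilisticUnet. Since elbo() already returns -(reconstruction_loss +
# beta * KL), the training loss is its negation. `example_train_step` is a
# hypothetical helper; `net`, `patch`, `mask` and `optimizer` are assumed to
# be supplied by the caller (e.g. optimizer = torch.optim.Adam(net.parameters())).
def example_train_step(net, patch, mask, optimizer):
    net.forward(patch, mask, training=True)  # builds prior/posterior latent spaces + UNet features
    loss = -net.elbo(mask)                   # negative ELBO = reconstruction loss + beta * KL
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach().cpu().item()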
import os
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn import metrics
from cytomine import Cytomine
from cytomine.models import AnnotationCollection, ImageInstance

# AnnotationCrop, RemoteAnnotationTrainDataset, Unet, find_intersecting_annotations,
# predict_roi, segmentation_transform and worker_init are assumed to come from
# local modules in this repository.


def main(argv):
    """
    IMAGES VALID:
    * 005-TS_13C08351_2-2014-02-12 12.22.44.ndpi | id : 77150767
    * 024-12C07162_2A-2012-08-14-17.21.05.jp2 | id : 77150761
    * 019-CP_12C04234_2-2012-08-10-12.49.26.jp2 | id : 77150809
    IMAGES TEST:
    * 004-PF_08C11886_1-2012-08-09-19.05.53.jp2 | id : 77150623
    * 011-TS_13C10153_3-2014-02-13 15.22.21.ndpi | id : 77150611
    * 018-PF_07C18435_1-2012-08-17-00.55.09.jp2 | id : 77150755
    """
    with Cytomine.connect_from_cli(argv):
        parser = ArgumentParser()
        parser.add_argument("-b", "--batch_size", dest="batch_size", default=4, type=int)
        parser.add_argument("-j", "--n_jobs", dest="n_jobs", default=1, type=int)
        parser.add_argument("-e", "--epochs", dest="epochs", default=1, type=int)
        parser.add_argument("-d", "--device", dest="device", default="cpu")
        parser.add_argument("-o", "--overlap", dest="overlap", default=0, type=int)
        parser.add_argument("-t", "--tile_size", dest="tile_size", default=256, type=int)
        parser.add_argument("-z", "--zoom_level", dest="zoom_level", default=0, type=int)
        parser.add_argument("--lr", dest="lr", default=0.01, type=float)
        parser.add_argument("--init_fmaps", dest="init_fmaps", default=16, type=int)
        parser.add_argument("--data_path", "--dpath", dest="data_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-w", "--working_path", "--wpath", dest="working_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        parser.add_argument("-s", "--save_path", dest="save_path",
                            default=os.path.join(str(Path.home()), "tmp"))
        args, _ = parser.parse_known_args(argv)

        os.makedirs(args.save_path, exist_ok=True)
        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.working_path, exist_ok=True)

        # fetch annotations (filter val/test sets + other annotations)
        all_annotations = AnnotationCollection(project=77150529, showWKT=True,
                                               showMeta=True, showTerm=True).fetch()
        val_ids = {77150767, 77150761, 77150809}
        test_ids = {77150623, 77150611, 77150755}
        val_test_ids = val_ids.union(test_ids)
        train_collection = all_annotations.filter(lambda a: (
            a.user in {55502856} and len(a.term) > 0
            and a.term[0] in {35777351, 35777321, 35777459}
            and a.image not in val_test_ids))
        val_rois = all_annotations.filter(lambda a: (
            a.user in {142954314} and a.image in val_ids
            and len(a.term) > 0 and a.term[0] in {154890363}))
        val_foreground = all_annotations.filter(lambda a: (
            a.user in {142954314} and a.image in val_ids
            and len(a.term) > 0 and a.term[0] in {154005477}))

        train_wsi_ids = list({an.image for an in all_annotations}.difference(val_test_ids))
        val_wsi_ids = list(val_ids)

        download_path = os.path.join(args.data_path, "crops-{}".format(args.tile_size))
        images = {
            _id: ImageInstance().fetch(_id)
            for _id in (train_wsi_ids + val_wsi_ids)
        }

        train_crops = [
            AnnotationCrop(images[annot.image], annot, download_path,
                           args.tile_size, zoom_level=args.zoom_level)
            for annot in train_collection
        ]
        val_crops = [
            AnnotationCrop(images[annot.image], annot, download_path,
                           args.tile_size, zoom_level=args.zoom_level)
            for annot in val_rois
        ]

        for crop in train_crops + val_crops:
            crop.download()

        np.random.seed(42)
        dataset = RemoteAnnotationTrainDataset(train_crops,
                                               seg_trans=segmentation_transform)
        loader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size,
                            num_workers=args.n_jobs, worker_init_fn=worker_init)

        # network
        device = torch.device(args.device)
        unet = Unet(args.init_fmaps, n_classes=1)
        unet.train()
        unet.to(device)

        optimizer = Adam(unet.parameters(), lr=args.lr)
        loss_fn = BCEWithLogitsLoss(reduction="mean")

        results = {
            "train_losses": [],
            "val_losses": [],
            "val_metrics": [],
            "save_path": []
        }

        for e in range(args.epochs):
            print("########################")
            print(" Epoch {}".format(e))
            print("########################")

            epoch_losses = list()
            unet.train()
            for i, (x, y) in enumerate(loader):
                x, y = (t.to(device) for t in [x, y])
                y_pred = unet.forward(x)
                loss = loss_fn(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # running mean over the six most recent batch losses
                epoch_losses = [loss.detach().cpu().item()] + epoch_losses[:5]
                print("{} - {:1.5f}".format(i, np.mean(epoch_losses)))
            results["train_losses"].append(epoch_losses[0])

            unet.eval()
            # validation (np.float/np.int are deprecated NumPy aliases, use builtins)
            val_losses = np.zeros(len(val_rois), dtype=float)
            val_roc_auc = np.zeros(len(val_rois), dtype=float)
            val_cm = np.zeros([len(val_rois), 2, 2], dtype=int)
            for i, roi in enumerate(val_crops):
                foregrounds = find_intersecting_annotations(roi.annotation,
                                                            val_foreground)
                with torch.no_grad():
                    y_pred, y_true = predict_roi(
                        roi, foregrounds, unet, device,
                        in_trans=transforms.ToTensor(),
                        batch_size=args.batch_size,
                        tile_size=args.tile_size,
                        overlap=args.overlap,
                        n_jobs=args.n_jobs,
                        zoom_level=args.zoom_level)

                val_losses[i] = metrics.log_loss(y_true.flatten(), y_pred.flatten())
                val_roc_auc[i] = metrics.roc_auc_score(y_true.flatten(),
                                                       y_pred.flatten())
                val_cm[i] = metrics.confusion_matrix(
                    y_true.flatten().astype(np.uint8),
                    (y_pred.flatten() > 0.5).astype(np.uint8))

            print("------------------------------")
            print("Epoch {}:".format(e))
            val_loss = np.mean(val_losses)
            roc_auc = np.mean(val_roc_auc)
            print("> val_loss: {:1.5f}".format(val_loss))
            print("> roc_auc : {:1.5f}".format(roc_auc))
            cm = np.sum(val_cm, axis=0)
            cnt = np.sum(val_cm)
            print("CM at 0.5 threshold")
            print("> {:3.2f}% {:3.2f}%".format(100 * cm[0, 0] / cnt,
                                               100 * cm[0, 1] / cnt))
            print("> {:3.2f}% {:3.2f}%".format(100 * cm[1, 0] / cnt,
                                               100 * cm[1, 1] / cnt))
            print("------------------------------")

            filename = "{}_e_{}_val_{:0.4f}_roc_{:0.4f}_z{}_s{}.pth".format(
                datetime.now().timestamp(), e, val_loss, roc_auc,
                args.zoom_level, args.tile_size)
            torch.save(unet.state_dict(), os.path.join(args.save_path, filename))

            results["val_losses"].append(val_loss)
            results["val_metrics"].append(roc_auc)
            results["save_path"].append(filename)

        return results