def __init__(self, use_affine_ex: bool = True, **kwargs):
    """Set up the decoder-side covariate flows and the image flow.

    Args:
        use_affine_ex: whether ``_build_image_flow`` appends a conditional
            affine transform to the image flow.
        **kwargs: forwarded to the base class.
    """
    super().__init__(**kwargs)

    self.use_affine_ex = use_affine_ex

    # decoder parts

    # Flow for modelling t Gamma
    # spline flow -> exp(lognorm-affine(.)) constraint, so thickness > 0
    self.thickness_flow_components = ComposeTransformModule([Spline(1)])
    self.thickness_flow_constraint_transforms = ComposeTransform(
        [self.thickness_flow_lognorm, ExpTransform()])
    self.thickness_flow_transforms = ComposeTransform([
        self.thickness_flow_components,
        self.thickness_flow_constraint_transforms
    ])

    # affine flow for s normal
    # NOTE: stored as a plain list (not ComposeTransform) — consumers compose it.
    self.intensity_flow_components = ComposeTransformModule(
        [LearnedAffineTransform(), Spline(1)])
    self.intensity_flow_constraint_transforms = ComposeTransform(
        [SigmoidTransform(), self.intensity_flow_norm])
    self.intensity_flow_transforms = [
        self.intensity_flow_components,
        self.intensity_flow_constraint_transforms
    ]

    # realnvp or so for x
    self._build_image_flow()
def _get_transformed_x_dist(self, latent, ctx=None):
    """Return the image likelihood p(x | latent) as a transformed distribution.

    The decoder's predicted distribution reparameterises a fixed base
    distribution: a Cholesky affine for (low-rank) multivariate normals, a
    diagonal affine for Independent normals; the preprocessing transform is
    composed on top.
    """
    pred = self.decoder.predict(latent, ctx)  # predicted image distribution

    # Fixed base distribution; the last 3 dims (C, H, W) are event dims.
    base_cls = Laplace if self.laplace_likelihood else Normal
    base_dist = base_cls(self.x_base_loc, self.x_base_scale).to_event(3)

    preprocess = self._get_preprocess_transforms()

    if isinstance(pred, (MultivariateNormal, LowRankMultivariateNormal)):
        # Full-covariance case: flatten, apply the Cholesky affine, unflatten.
        flatten = ReshapeTransform(self.img_shape,
                                   (np.prod(self.img_shape), ))
        chol = LowerCholeskyAffine(pred.loc, pred.scale_tril)
        reparam = ComposeTransform([flatten, chol, flatten.inv])
    elif isinstance(pred, Independent):
        inner = pred.base_dist
        reparam = AffineTransform(inner.loc, inner.scale, 3)
    else:
        raise ValueError(f'{pred} not valid.')

    return TransformedDistribution(
        base_dist, ComposeTransform([reparam, preprocess]))
def __init__(self, use_affine_ex: bool = True, **kwargs):
    """Set up covariate flows (intensity conditioned on thickness) and the image flow.

    Args:
        use_affine_ex: whether ``_build_image_flow`` appends a conditional
            affine transform to the image flow.
        **kwargs: forwarded to the base class.
    """
    super().__init__(**kwargs)

    self.use_affine_ex = use_affine_ex

    # decoder parts

    # Flow for modelling t Gamma
    self.thickness_flow_components = ComposeTransformModule([Spline(1)])
    self.thickness_flow_constraint_transforms = ComposeTransform(
        [self.thickness_flow_lognorm, ExpTransform()])
    self.thickness_flow_transforms = ComposeTransform([
        self.thickness_flow_components,
        self.thickness_flow_constraint_transforms
    ])

    # affine flow for s normal
    # conditional affine: scale/shift of intensity are predicted from the
    # (1-dim) context by a minimal linear network
    intensity_net = DenseNN(1, [1],
                            param_dims=[1, 1],
                            nonlinearity=torch.nn.Identity())
    self.intensity_flow_components = ConditionalAffineTransform(
        context_nn=intensity_net, event_dim=0)
    self.intensity_flow_constraint_transforms = ComposeTransform(
        [SigmoidTransform(), self.intensity_flow_norm])
    self.intensity_flow_transforms = [
        self.intensity_flow_components,
        self.intensity_flow_constraint_transforms
    ]
    # build flow as s_affine_w * t * e_s + b -> depends on t though

    # realnvp or so for x
    self._build_image_flow()
def __init__(self, **kwargs): super().__init__(**kwargs) # Flow for modelling t Gamma self.thickness_flow_components = ComposeTransformModule([Spline(1)]) self.thickness_flow_constraint_transforms = ComposeTransform( [self.thickness_flow_lognorm, ExpTransform()]) self.thickness_flow_transforms = ComposeTransform([ self.thickness_flow_components, self.thickness_flow_constraint_transforms ]) # affine flow for s normal intensity_net = DenseNN(1, [1], param_dims=[1, 1], nonlinearity=torch.nn.Identity()) self.intensity_flow_components = ConditionalAffineTransform( context_nn=intensity_net, event_dim=0) self.intensity_flow_constraint_transforms = ComposeTransform( [SigmoidTransform(), self.intensity_flow_norm]) self.intensity_flow_transforms = [ self.intensity_flow_components, self.intensity_flow_constraint_transforms ]
def infer_intensity_base(self, thickness, intensity):
    """Invert the thickness-conditioned intensity flow back to its base noise."""
    base = Normal(self.intensity_base_loc, self.intensity_base_scale)
    conditioned = ConditionalTransformedDistribution(
        base, self.intensity_flow_transforms).condition(thickness)
    return ComposeTransform(conditioned.transforms).inv(intensity)
def __init__(self, use_affine_ex=True, **kwargs):
    """Build the glasses flow and the conditional image flow, then draw a sample.

    NOTE(review): this ``__init__`` calls ``pyro.sample`` and ends with
    ``return x, glasses`` — returning non-None from ``__init__`` raises a
    TypeError at instantiation, so this body looks like model code pasted
    into a constructor; confirm intent before relying on it.
    """
    # BUG FIX: was `super.__init__(**kwargs)` (calling on the builtin `super`
    # itself), which raises a TypeError.
    super().__init__(**kwargs)
    self.num_scales = 2
    # BUG FIX: the flag was silently dropped; `_build_image_flow` presumably
    # reads `self.use_affine_ex` (as in the sibling implementation) — confirm.
    self.use_affine_ex = use_affine_ex

    self.register_buffer("glasses_base_loc",
                         torch.zeros([1, ], requires_grad=False))
    self.register_buffer("glasses_base_scale",
                         torch.ones([1, ], requires_grad=False))
    self.register_buffer("glasses_flow_lognorm_loc",
                         torch.zeros([], requires_grad=False))
    self.register_buffer("glasses_flow_lognorm_scale",
                         torch.ones([], requires_grad=False))

    self.glasses_flow_components = ComposeTransformModule([Spline(1)])
    self.glasses_flow_constraint_transforms = ComposeTransform(
        [self.glasses_flow_lognorm, SigmoidTransform()])
    self.glasses_flow_transforms = ComposeTransform([
        self.glasses_flow_components, self.glasses_flow_constraint_transforms
    ])

    glasses_base_dist = Normal(self.glasses_base_loc,
                               self.glasses_base_scale).to_event(1)
    self.glasses_dist = TransformedDistribution(glasses_base_dist,
                                                self.glasses_flow_transforms)
    glasses_ = pyro.sample("glasses_", self.glasses_dist)
    glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
    glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

    # NOTE(review): `_build_image_flow` may assign self.x_transforms itself
    # and return None — confirm its return value before this assignment.
    self.x_transforms = self._build_image_flow()
    self.register_buffer("x_base_loc",
                         torch.zeros([1, 64, 64], requires_grad=False))
    self.register_buffer("x_base_scale",
                         torch.ones([1, 64, 64], requires_grad=False))
    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    # BUG FIX: conditioned on undefined name `context`; the conditioning
    # variable computed above is `glasses_context`.
    cond_x_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            x_base_dist,
            self.x_transforms).condition(glasses_context).transforms).inv
    cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)
    x = pyro.sample("x", cond_x_dist)
    return x, glasses
def generate(self, x, num_particles):
    """Reconstruct ``x`` by decoding ``num_particles`` samples and averaging them."""
    z = self.encoder.predict(x).sample()
    pred = self.decoder.predict(z)

    batch = x.shape[0]
    base_dist = dist.Normal(
        torch.zeros_like(x, requires_grad=False).view(batch, -1),
        torch.ones_like(x, requires_grad=False).view(batch, -1),
    ).to_event(1)

    # Pick the reparameterising transform according to the decoder head.
    if ('normal' in self.decoder_output
            or self.decoder_output == 'deepvar'
            or self.decoder_output == 'deepmean'):
        transform = AffineTransform(pred.mean, pred.stddev, 1)
    elif self.decoder_output == 'low_rank_mvn':
        transform = LowerCholeskyAffine(pred.loc, pred.scale_tril)
    else:
        raise Exception('Unknown decoder output')

    x_dist = dist.TransformedDistribution(base_dist,
                                          ComposeTransform([transform]))
    recons = [
        pyro.sample('x', x_dist).view(batch, self.shape, 3)
        for _ in range(num_particles)
    ]
    return torch.stack(recons).mean(0)
def test_reparam_log_joint(model, kwargs):
    """Check NeuTra reparameterisation: the reparameterised potential energy
    equals the original potential energy minus the flow's log|det J|."""
    guide = AutoIAFNormal(model)
    guide(**kwargs)  # run once so the guide's parameters get initialised
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)
    # Potential-energy functions for the original and reparameterised models.
    _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(
        reparam_model, model_kwargs=kwargs
    )
    latent_x = list(init_params.values())[0]
    transformed_params = neutra.transform_sample(latent_x)
    pe_transformed = pe_fn_neutra(init_params)
    # Apply the guide's flow manually to obtain the jacobian correction.
    neutra_transform = ComposeTransform(guide.get_posterior(**kwargs).transforms)
    latent_y = neutra_transform(latent_x)
    log_det_jacobian = neutra_transform.log_abs_det_jacobian(latent_x, latent_y)
    pe = pe_fn({k: transforms[k](v) for k, v in transformed_params.items()})
    assert_close(pe_transformed, pe - log_det_jacobian)
def _get_preprocess_transforms(self): alpha = 0.05 num_bits = 8 if self.preprocessing == 'glow': # Map to [-0.5,0.5] a1 = AffineTransform(-0.5, (1. / 2 ** num_bits)) preprocess_transform = ComposeTransform([a1]) elif self.preprocessing == 'realnvp': # Map to [0,1] a1 = AffineTransform(0., (1. / 2 ** num_bits)) # Map into unconstrained space as done in RealNVP a2 = AffineTransform(alpha, (1 - alpha)) s = SigmoidTransform() preprocess_transform = ComposeTransform([a1, a2, s.inv]) else: raise ValueError(f'{self.preprocessing} not valid.') return preprocess_transform
def model(n_samples=None, scale=2.):
    """Toy SCM: thickness ~ Gamma(10, 5); width depends on thickness through a
    sigmoid-affine flow over a Normal."""
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 2
        width_transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(10, 15)])
        width_dist = TransformedDistribution(Normal(loc, scale),
                                             width_transforms)
        width = pyro.sample('width', width_dist)

    return thickness, width
def model(self):
    """Sample (x, thickness, intensity): covariates first, then the image
    through the inverted composed image flow."""
    thickness, intensity = self.pgm_model()

    base = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    image_flow = ComposeTransform(self.x_transforms).inv
    x = pyro.sample('x', TransformedDistribution(base, image_flow))

    return x, thickness, intensity
def infer_x_base(self, thickness, intensity, x):
    """Map an observed image to its exogenous base value, conditioning the
    image flow on the unconstrained covariates."""
    base = Normal(self.x_base_loc, self.x_base_scale)
    context = torch.cat([
        self.thickness_flow_constraint_transforms.inv(thickness),
        self.intensity_flow_norm.inv(intensity),
    ], 1)
    conditioned = ConditionalTransformedDistribution(
        base, self.x_transforms).condition(context)
    return ComposeTransform(conditioned.transforms)(x)
def _get_transformed_x_dist(self, latent):
    """Return the image likelihood p(x | latent) as a transformed distribution.

    Raises:
        ValueError: if the decoder returns a distribution type not handled
            here. (BUG FIX: previously unsupported types fell through both
            branches and raised a confusing NameError on
            ``x_reparam_transform``; the sibling implementation raises
            ValueError, so do the same here for consistency.)
    """
    x_pred_dist = self.decoder.predict(latent)
    # Fixed base distribution; the last 3 dims (C, H, W) are event dims.
    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess_transform = self._get_preprocess_transforms()

    if isinstance(x_pred_dist, (MultivariateNormal, LowRankMultivariateNormal)):
        # Full-covariance case: flatten, apply the Cholesky affine, unflatten.
        chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                             x_pred_dist.scale_tril)
        reshape_transform = ReshapeTransform(self.img_shape,
                                             (np.prod(self.img_shape), ))
        x_reparam_transform = ComposeTransform(
            [reshape_transform, chol_transform, reshape_transform.inv])
    elif isinstance(x_pred_dist, Independent):
        x_pred_dist = x_pred_dist.base_dist
        x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                              x_pred_dist.scale, 3)
    else:
        raise ValueError(f'{x_pred_dist} not valid.')

    return TransformedDistribution(
        x_base_dist,
        ComposeTransform([x_reparam_transform, preprocess_transform]))
def __init__(self, **kwargs): super().__init__(**kwargs) # Flow for modelling t Gamma self.thickness_flow_components = ComposeTransformModule([Spline(1)]) self.thickness_flow_constraint_transforms = ComposeTransform( [self.thickness_flow_lognorm, ExpTransform()]) self.thickness_flow_transforms = ComposeTransform([ self.thickness_flow_components, self.thickness_flow_constraint_transforms ]) # affine flow for s normal self.intensity_flow_components = ComposeTransformModule( [LearnedAffineTransform(), Spline(1)]) self.intensity_flow_constraint_transforms = ComposeTransform( [SigmoidTransform(), self.intensity_flow_norm]) self.intensity_flow_transforms = [ self.intensity_flow_components, self.intensity_flow_constraint_transforms ]
def model(n_samples=None, scale=0.5, invert=False):
    """Toy SCM: thickness ~ 0.5 + Gamma(10, 5); intensity depends (positively or,
    with ``invert``, negatively) on thickness through a sigmoid-affine flow."""
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2

        intensity_transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])
        intensity_dist = TransformedDistribution(Normal(loc, scale),
                                                 intensity_transforms)
        intensity = pyro.sample('intensity', intensity_dist)

    return thickness, intensity
def infer_exogeneous(self, **obs):
    """Recover the exogenous base value of every transformed sample site.

    Conditions the model on ``obs``, traces it, and inverts each
    TransformedDistribution site's transform stack.
    """
    # assuming that we use transformed distributions for everything:
    cond_sample = pyro.condition(self.sample, data=obs)
    trace = pyro.poutine.trace(cond_sample).get_trace(obs['x'].shape[0])

    output = {}
    for name, node in trace.nodes.items():
        if 'fn' not in node.keys():
            continue  # non-sample node
        fn = node['fn']
        if isinstance(fn, Independent):
            fn = fn.base_dist
        if isinstance(fn, TransformedDistribution):
            inverse = ComposeTransform(fn.transforms).inv
            output[name + '_base'] = inverse(node['value'])
    return output
def model(self):
    """Sample (x, thickness, intensity) with the image flow conditioned on the
    unconstrained covariates."""
    thickness, intensity = self.pgm_model()

    context = torch.cat([
        self.thickness_flow_constraint_transforms.inv(thickness),
        self.intensity_flow_norm.inv(intensity),
    ], 1)

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    cond_x_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            x_base_dist, self.x_transforms).condition(context).transforms).inv
    x = pyro.sample('x', TransformedDistribution(x_base_dist,
                                                 cond_x_transforms))

    return x, thickness, intensity
def infer_exogenous(self, obs):
    """Recover the exogenous base value of every transformed sample site,
    skipping latent 'z' sites."""
    # assuming that we use transformed distributions for everything
    cond_sample = pyro.condition(self.sample, data=obs)
    trace = pyro.poutine.trace(cond_sample).get_trace(obs['x'].shape[0])

    output = {}
    for name, node in trace.nodes.items():
        if 'z' in name or 'fn' not in node.keys():
            continue  # latent or non-sample node
        fn = node['fn']
        if isinstance(fn, Independent):
            fn = fn.base_dist
        if isinstance(fn, TransformedDistribution):
            # compute exogenous base distribution at all sites. base dist created with TransformReparam
            inverse = ComposeTransform(fn.transforms).inv
            output[name + '_base'] = inverse(node['value'])
    return output
def model(self):
    """Generative model p(x, z, age, sex, ventricle_volume, brain_volume).

    Decodes a latent+covariate vector to an image mean, warps it with a
    predicted affine grid, and wraps the result in a transformed Normal.
    """
    age, sex, ventricle_volume, brain_volume = self.pgm_model()
    # Map covariates back to unconstrained (flow-base) space for the decoder.
    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
        ventricle_volume)
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        brain_volume)
    age_ = self.age_flow_constraint_transforms.inv(age)
    z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))
    latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)
    x_loc = self.decoder_mean(self.decoder(latent))
    x_scale = torch.exp(self.decoder_logstd)
    # Predict a per-sample 2x3 affine matrix and spatially warp the decoded mean.
    theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)
    with warnings.catch_warnings():
        # silence torch warnings from affine_grid/grid_sample here
        # (presumably the align_corners deprecation — TODO confirm)
        warnings.simplefilter("ignore")
        grid = nn.functional.affine_grid(theta, x_loc.size())
        x_loc_deformed = nn.functional.grid_sample(x_loc, grid)
    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess_transform = self._get_preprocess_transforms()
    x_dist = TransformedDistribution(
        x_base_dist,
        ComposeTransform([
            AffineTransform(x_loc_deformed, x_scale, 3), preprocess_transform
        ]))
    x = pyro.sample('x', x_dist)
    return x, z, age, sex, ventricle_volume, brain_volume
class FlowSCM(pyroModule):
    """definition of FlowSCM class

    NOTE(review): ``__init__`` calls ``pyro.sample`` and ends with
    ``return x, glasses`` — returning non-None from ``__init__`` raises a
    TypeError at instantiation; this looks like model code pasted into the
    constructor. Confirm intent.
    """

    def __init__(self, use_affine_ex=True, **kwargs):
        # BUG FIX: was `super.__init__(**kwargs)` (calling on the builtin
        # `super` itself), which raises a TypeError.
        super().__init__(**kwargs)
        self.num_scales = 2
        # BUG FIX: the flag was dropped, but `_build_image_flow` reads
        # `self.use_affine_ex`.
        self.use_affine_ex = use_affine_ex

        self.register_buffer("glasses_base_loc",
                             torch.zeros([1, ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([1, ], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm, SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        # BUG FIX: `_build_image_flow` previously returned None, so this
        # assignment clobbered the list it had just built; it now returns
        # `self.x_transforms`.
        self.x_transforms = self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # BUG FIX: conditioned on undefined name `context`; the conditioning
        # variable computed above is `glasses_context`.
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)
        x = pyro.sample("x", cond_x_dist)
        return x, glasses

    def _build_image_flow(self):
        """Build the multi-scale conditional image flow.

        Returns:
            The list of transforms (also stored on ``self.x_transforms``).
        """
        self.trans_modules = ComposeTransformModule([])
        self.x_transforms = []
        self.x_transforms += [self._get_preprocess_transforms()]
        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4
            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)
                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)
                # the coupling net operates channel-last; transpose around it
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))
                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))
            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]
        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)
            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)
        # BUG FIX: return the transforms so `self.x_transforms = self._build_image_flow()`
        # in __init__ no longer overwrites the list with None.
        return self.x_transforms
def infer_x_base(self, thickness, intensity, x):
    """Push the observed image through the forward composed image flow to its
    base value (covariate arguments are unused here)."""
    forward_flow = ComposeTransform(self.x_transforms)
    return forward_flow(x)
class IndependentFlowSEM(BaseFlowSEM):
    """SEM with unconditional (independent) covariate flows and a multi-scale
    affine-coupling flow for the image."""

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        """Set up the covariate flows and build the image flow.

        Args:
            use_affine_ex: whether ``_build_image_flow`` appends a conditional
                affine transform (NOTE(review): this implementation's
                ``_build_image_flow`` does not read the flag — confirm).
            **kwargs: forwarded to the base class.
        """
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex
        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])
        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]
        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Build the multi-scale image flow into ``self.x_transforms``.

        Per scale: squeeze, then ``flows_per_scale`` blocks of
        (optional actnorm, channel permute, transposed affine coupling),
        then a final channel permute; a reshape restores (1, 32, 32).
        """
        self.trans_modules = ComposeTransformModule([])
        self.x_transforms = []
        self.x_transforms += [self._get_preprocess_transforms()]
        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4
            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)
                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)
                # the coupling net operates channel-last; transpose around it
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))
                ac = AffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2)))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))
            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

    @pyro_method
    def pgm_model(self):
        """Sample the covariates (thickness, intensity) from their flows."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)
        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)
        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components
        return thickness, intensity

    @pyro_method
    def model(self):
        """Sample (x, thickness, intensity); the image flow is unconditional."""
        thickness, intensity = self.pgm_model()
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform(self.x_transforms).inv)
        x = pyro.sample('x', x_dist)
        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        # invert the thickness flow back to its base value
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        # invert the intensity flow back to its base value
        return self.intensity_flow_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        # forward pass of the image flow (covariates unused in this SEM)
        return ComposeTransform(self.x_transforms)(x)
def __init__(self,
             latent_dim: int,
             prior_components: int = 1,
             posterior_components: int = 1,
             logstd_init: float = -5,
             enc_filters: Tuple[int] = (16, 32, 64, 128),
             dec_filters: Tuple[int] = (128, 64, 32, 16),
             num_convolutions: int = 3,
             use_upconv: bool = False,
             decoder_type: str = 'fixed_var',
             decoder_cov_rank: int = 10,
             img_shape: Tuple[int] = (128, 128),
             use_nvae=False,
             use_weight_norm=False,
             use_spectral_norm=False,
             laplace_likelihood=False,
             eps=0.1,
             n_prior_flows=3,
             n_posterior_flows=3,
             use_autoregressive=False,
             use_swish=False,
             use_spline=False,
             use_stable=False,
             pseudo3d=False,
             head_filters=(16, 16),
             **kwargs):
    """Build encoder/decoder, latent prior/posterior flows and covariate flows.

    Args:
        latent_dim: dimensionality of the latent code z.
        prior_components / posterior_components: >1 enables mixture models.
        logstd_init: initial log-std for the decoder likelihood.
        decoder_type / decoder_cov_rank: decoder likelihood parameterisation.
        n_prior_flows / n_posterior_flows: number of latent flow blocks
            (0 disables the flow; >1 additionally enables permutations and
            batch-norm between blocks).
        **kwargs: forwarded to the base class.
    """
    super().__init__(**kwargs)
    # pseudo3d inputs carry 3 channels (e.g. neighbouring slices), else 1.
    self.encoder_shape = ((3, ) if pseudo3d else (1, )) + tuple(img_shape)
    self.decoder_shape = (head_filters[0], ) + tuple(img_shape)
    self.img_shape = (1, ) + tuple(img_shape)
    self.latent_dim = latent_dim
    self.prior_components = prior_components
    self.posterior_components = posterior_components
    self.logstd_init = logstd_init
    self.enc_filters = enc_filters
    self.dec_filters = dec_filters
    self.head_filters = head_filters
    self.num_convolutions = num_convolutions
    self.use_upconv = use_upconv
    self.decoder_type = decoder_type
    self.decoder_cov_rank = decoder_cov_rank
    self.use_nvae = use_nvae
    self.use_weight_norm = use_weight_norm
    self.use_spectral_norm = use_spectral_norm
    self.laplace_likelihood = laplace_likelihood
    self.eps = eps
    self.n_prior_flows = n_prior_flows
    self.n_posterior_flows = n_posterior_flows
    self.use_autoregressive = use_autoregressive
    self.use_spline = use_spline
    self.use_swish = use_swish
    self.use_stable = use_stable
    self.pseudo3d = pseudo3d
    self.annealing_factor = [
        1.
    ]  # initialize here; will be changed during training
    self.n_levels = 0

    # decoder parts
    if use_nvae:
        decoder = NDecoder(num_convolutions=self.num_convolutions,
                           filters=self.dec_filters,
                           latent_dim=self.latent_dim + self.context_dim,
                           output_size=self.decoder_shape)
    else:
        decoder = Decoder(
            num_convolutions=self.num_convolutions,
            filters=self.dec_filters,
            latent_dim=self.latent_dim + self.context_dim,
            upconv=self.use_upconv,
            output_size=self.decoder_shape,
            use_weight_norm=self.use_weight_norm,
            use_spectral_norm=self.use_spectral_norm,
        )
    self._create_decoder(decoder)

    # encoder parts
    if self.use_nvae:
        self.encoder = NEncoder(num_convolutions=self.num_convolutions,
                                filters=self.enc_filters,
                                latent_dim=self.latent_dim,
                                input_size=self.encoder_shape)
    else:
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim,
                               input_size=self.encoder_shape,
                               use_weight_norm=self.use_weight_norm,
                               use_spectral_norm=self.use_spectral_norm)

    nonlinearity = Swish() if self.use_swish else torch.nn.LeakyReLU(0.1)
    # head mapping encoder features + covariate context to posterior params
    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + self.context_dim,
                        self.latent_dim), nonlinearity)
    if self.posterior_components > 1:
        self.latent_encoder = DeepIndepMixtureNormal(
            latent_layers, self.latent_dim, self.latent_dim,
            self.posterior_components)
    else:
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

    # latent prior: learnable mixture vs fixed standard normal
    if self.prior_components > 1:
        self.z_loc = torch.nn.Parameter(
            torch.randn([self.prior_components, self.latent_dim]))
        self.z_scale = torch.nn.Parameter(
            torch.randn([self.latent_dim]).clamp(min=-1.,
                                                 max=None))  # log scale
        self.register_buffer(
            'z_components',  # don't be bayesian about the mixture components
            ((1 / self.prior_components) * torch.ones(
                [self.prior_components], requires_grad=False)).log())
    else:
        self.register_buffer(
            'z_loc', torch.zeros([
                latent_dim,
            ], requires_grad=False))
        self.register_buffer(
            'z_scale', torch.ones([
                latent_dim,
            ], requires_grad=False))
        self.z_components = None

    # priors
    self.sex_logits = torch.nn.Parameter(torch.zeros([
        1,
    ]))
    self.register_buffer('slice_number_min',
                         torch.zeros([
                             1,
                         ], requires_grad=False))
    self.register_buffer(
        'slice_number_max', 241. * torch.ones([
            1,
        ], requires_grad=False) + 1.)
    # base loc/scale buffers for every continuous covariate
    for k in self.required_data - {'sex', 'x', 'slice_number'}:
        self.register_buffer(f'{k}_base_loc',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer(f'{k}_base_scale',
                             torch.ones([
                                 1,
                             ], requires_grad=False))
    self.register_buffer('x_base_loc',
                         torch.zeros(self.img_shape, requires_grad=False))
    self.register_buffer('x_base_scale',
                         torch.ones(self.img_shape, requires_grad=False))
    # per-covariate lognorm-affine parameters (filled in externally)
    for k in self.required_data - {'sex', 'x', 'slice_number'}:
        self.register_buffer(f'{k}_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer(f'{k}_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

    # fixed random permutations between flow blocks, stored as buffers
    perm = lambda: torch.randperm(
        self.latent_dim, dtype=torch.long, requires_grad=False)
    self.use_prior_flow = self.n_prior_flows > 0
    self.use_prior_permutations = self.n_prior_flows > 1
    if self.use_prior_permutations:
        for i in range(self.n_prior_flows):
            self.register_buffer(f'prior_flow_permutation_{i}', perm())
    self.use_posterior_flow = self.n_posterior_flows > 0
    self.use_posterior_permutations = self.n_posterior_flows > 1
    if self.use_posterior_permutations:
        for i in range(self.n_posterior_flows):
            self.register_buffer(f'posterior_flow_permutation_{i}', perm())

    # age flow
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    self.age_flow_lognorm = AffineTransform(
        loc=self.age_flow_lognorm_loc.item(),
        scale=self.age_flow_lognorm_scale.item())
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])

    # other flows shared components
    self.ventricle_volume_flow_lognorm = AffineTransform(
        loc=self.ventricle_volume_flow_lognorm_loc.item(),
        scale=self.ventricle_volume_flow_lognorm_scale.item(
        ))  # noqa: E501
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm,
         ExpTransform()])
    self.brain_volume_flow_lognorm = AffineTransform(
        loc=self.brain_volume_flow_lognorm_loc.item(),
        scale=self.brain_volume_flow_lognorm_scale.item())
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
    self.lesion_volume_flow_lognorm = AffineTransform(
        loc=self.lesion_volume_flow_lognorm_loc.item(),
        scale=self.lesion_volume_flow_lognorm_scale.item())
    # shift by -eps so exp(.) - eps can reach exactly zero
    self.lesion_volume_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.lesion_volume_flow_constraint_transforms = ComposeTransform([
        self.lesion_volume_flow_lognorm,
        ExpTransform(), self.lesion_volume_flow_eps
    ])
    self.duration_flow_lognorm = AffineTransform(
        loc=self.duration_flow_lognorm_loc.item(),
        scale=self.duration_flow_lognorm_scale.item())
    self.duration_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.duration_flow_constraint_transforms = ComposeTransform([
        self.duration_flow_lognorm,
        ExpTransform(), self.duration_flow_eps
    ])
    self.edss_flow_lognorm = AffineTransform(
        loc=self.edss_flow_lognorm_loc.item(),
        scale=self.edss_flow_lognorm_scale.item())
    self.edss_flow_eps = AffineTransform(loc=-eps, scale=1.)
    self.edss_flow_constraint_transforms = ComposeTransform(
        [self.edss_flow_lognorm,
         ExpTransform(), self.edss_flow_eps])

    # latent flow family: spline vs affine, autoregressive vs coupling
    hidden_dims = (3 * self.latent_dim +
                   1, ) if self.use_autoregressive else (2 * self.latent_dim,
                                                         2 * self.latent_dim)
    flow_kwargs = dict(hidden_dims=hidden_dims, nonlinearity=nonlinearity)
    if self.use_spline:
        flow_ = spline_autoregressive if self.use_autoregressive else spline_coupling
    else:
        flow_ = affine_autoregressive if self.use_autoregressive else affine_coupling
    if self.use_autoregressive:
        flow_kwargs['stable'] = self.use_stable

    if self.use_prior_permutations:
        self.prior_affine = iterated(
            self.n_prior_flows, batchnorm, self.latent_dim,
            momentum=0.05) if self.use_prior_flow else []
        self.prior_permutations = [
            Permute(getattr(self, f'prior_flow_permutation_{i}'))
            for i in range(self.n_prior_flows)
        ]
        self.prior_flow_components = iterated(
            self.n_prior_flows, flow_, self.latent_dim, **
            flow_kwargs) if self.use_prior_flow else []
        # interleave as (permute, batchnorm, flow) per block
        self.prior_flow_transforms = [
            x for c in zip(self.prior_permutations, self.prior_affine,
                           self.prior_flow_components) for x in c
        ]
    else:
        self.prior_affine = []
        self.prior_permutations = []
        self.prior_flow_components = flow_(
            self.latent_dim, **flow_kwargs) if self.use_prior_flow else []
        self.prior_flow_transforms = [self.prior_flow_components]

    if self.use_posterior_permutations:
        self.posterior_affine = iterated(self.n_posterior_flows,
                                         batchnorm,
                                         self.latent_dim,
                                         momentum=0.05)
        self.posterior_permutations = [
            Permute(getattr(self, f'posterior_flow_permutation_{i}'))
            for i in range(self.n_posterior_flows)
        ]
        self.posterior_flow_components = iterated(self.n_posterior_flows,
                                                  flow_, self.latent_dim,
                                                  **flow_kwargs)
        # interleave as (permute, batchnorm, flow) per block
        self.posterior_flow_transforms = [
            x for c in zip(self.posterior_permutations,
                           self.posterior_affine,
                           self.posterior_flow_components) for x in c
        ]
    else:
        self.posterior_affine = []
        self.posterior_permutations = []
        self.posterior_flow_components = flow_(
            self.latent_dim, **
            flow_kwargs) if self.use_posterior_flow else []
        self.posterior_flow_transforms = [self.posterior_flow_components]
class ConditionalDecoderVISEM(BaseVISEM):
    """VI-SEM whose decoder and guide are conditioned on (thickness, intensity)."""

    # two covariates (thickness, intensity) are concatenated to the latent code
    context_dim = 2

    def __init__(self, **kwargs):
        """Set up the thickness and intensity flows.

        Args:
            **kwargs: forwarded to the base class.
        """
        super().__init__(**kwargs)

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

    @pyro_method
    def pgm_model(self):
        """Sample the covariates (thickness, intensity) from their flows."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)
        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)
        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components
        return thickness, intensity

    @pyro_method
    def model(self):
        """Sample (x, z, thickness, intensity); the decoder sees z plus the
        unconstrained covariates."""
        thickness, intensity = self.pgm_model()
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)
        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))
        latent = torch.cat([z, thickness_, intensity_], 1)
        x_dist = self._get_transformed_x_dist(latent)
        x = pyro.sample('x', x_dist)
        return x, z, thickness, intensity

    @pyro_method
    def guide(self, x, thickness, intensity):
        """Amortised posterior q(z | x, thickness, intensity)."""
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)
            # condition the posterior on the unconstrained covariates
            thickness_ = self.thickness_flow_constraint_transforms.inv(
                thickness)
            intensity_ = self.intensity_flow_norm.inv(intensity)
            hidden = torch.cat([hidden, thickness_, intensity_], 1)
            latent_dist = self.latent_encoder.predict(hidden)
            z = pyro.sample('z', latent_dist)
        return z

    @pyro_method
    def infer_thickness_base(self, thickness):
        # invert the thickness flow back to its base value
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        # invert the intensity flow back to its base value
        return self.intensity_flow_transforms.inv(intensity)
def __init__(self, latent_dim: int, logstd_init: float = -5,
             enc_filters: str = '16,32,64,128',
             dec_filters: str = '128,64,32,16',
             num_convolutions: int = 2, use_upconv: bool = False,
             decoder_type: str = 'fixed_var', decoder_cov_rank: int = 10,
             **kwargs):
    """Build the VISEM encoder/decoder stack, priors and covariate flows.

    Args:
        latent_dim: dimensionality of the latent variable z.
        logstd_init: initial value for the decoder log-std heads.
        enc_filters / dec_filters: comma-separated filter counts per stage.
        num_convolutions: convolutions per encoder/decoder stage.
        use_upconv: use transposed convolutions in the decoder.
        decoder_type: one of 'fixed_var', 'learned_var',
            'independent_gaussian', 'multivariate_gaussian',
            'sharedvar_multivariate_gaussian', 'lowrank_multivariate_gaussian',
            'sharedvar_lowrank_multivariate_gaussian'.
        decoder_cov_rank: covariance rank for the low-rank decoders.

    Raises:
        ValueError: if ``decoder_type`` is not one of the supported values.
    """
    super().__init__(**kwargs)

    # self.downsample is assumed to come from the superclass — 192px images,
    # optionally downsampled. TODO confirm against base class.
    self.img_shape = (1, 192 // self.downsample,
                      192 // self.downsample) if self.downsample > 0 else (1, 192, 192)
    self.latent_dim = latent_dim
    self.logstd_init = logstd_init

    self.enc_filters = tuple(int(f.strip()) for f in enc_filters.split(','))
    self.dec_filters = tuple(int(f.strip()) for f in dec_filters.split(','))
    self.num_convolutions = num_convolutions
    self.use_upconv = use_upconv
    self.decoder_type = decoder_type
    self.decoder_cov_rank = decoder_cov_rank

    # decoder parts
    decoder = Decoder(num_convolutions=self.num_convolutions,
                      filters=self.dec_filters,
                      latent_dim=self.latent_dim + self.context_dim,
                      upconv=self.use_upconv,
                      output_size=self.img_shape)

    if self.decoder_type == 'fixed_var':
        # Pixel-wise Normal with a frozen, constant log-variance.
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)

        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = False

        torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                self.logstd_init)
        self.decoder.logvar_head.bias.requires_grad = False
    elif self.decoder_type == 'learned_var':
        # Single learned log-variance shared across pixels (bias only).
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)

        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = False

        torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                self.logstd_init)
        self.decoder.logvar_head.bias.requires_grad = True
    elif self.decoder_type == 'independent_gaussian':
        # Fully learned pixel-wise log-variance.
        self.decoder = Conv2dIndepNormal(decoder, 1, 1)

        torch.nn.init.zeros_(self.decoder.logvar_head.weight)
        self.decoder.logvar_head.weight.requires_grad = True

        torch.nn.init.normal_(self.decoder.logvar_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logvar_head.bias.requires_grad = True
    elif self.decoder_type == 'multivariate_gaussian':
        # Full-covariance Gaussian over the flattened image.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                              np.prod(self.img_shape))
    elif self.decoder_type == 'sharedvar_multivariate_gaussian':
        # Full covariance with input-independent (bias-only) scale heads.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepMultivariateNormal(seq, np.prod(self.img_shape),
                                              np.prod(self.img_shape))

        torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
        self.decoder.logdiag_head.weight.requires_grad = False

        torch.nn.init.zeros_(self.decoder.lower_head.weight)
        self.decoder.lower_head.weight.requires_grad = False

        torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logdiag_head.bias.requires_grad = True
    elif self.decoder_type == 'lowrank_multivariate_gaussian':
        # Low-rank + diagonal covariance over the flattened image.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepLowRankMultivariateNormal(
            seq, np.prod(self.img_shape), np.prod(self.img_shape),
            decoder_cov_rank)
    elif self.decoder_type == 'sharedvar_lowrank_multivariate_gaussian':
        # Low-rank covariance with input-independent (bias-only) scale heads.
        seq = torch.nn.Sequential(decoder,
                                  Lambda(lambda x: x.view(x.shape[0], -1)))
        self.decoder = DeepLowRankMultivariateNormal(
            seq, np.prod(self.img_shape), np.prod(self.img_shape),
            decoder_cov_rank)

        torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
        self.decoder.logdiag_head.weight.requires_grad = False

        torch.nn.init.zeros_(self.decoder.factor_head.weight)
        self.decoder.factor_head.weight.requires_grad = False

        torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                              self.logstd_init, 1e-1)
        self.decoder.logdiag_head.bias.requires_grad = True
    else:
        # Fix: the original message was the truncated 'unknown ' — report
        # the offending value so misconfigurations are diagnosable.
        raise ValueError(f'unknown decoder type: {self.decoder_type}')

    # encoder parts
    self.encoder = Encoder(num_convolutions=self.num_convolutions,
                           filters=self.enc_filters,
                           latent_dim=self.latent_dim,
                           input_size=self.img_shape)

    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + self.context_dim, self.latent_dim),
        torch.nn.ReLU())
    self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                          self.latent_dim)

    # priors: standard-Normal bases for each covariate and for z/x.
    self.register_buffer('age_base_loc',
                         torch.zeros([1, ], requires_grad=False))
    self.register_buffer('age_base_scale',
                         torch.ones([1, ], requires_grad=False))

    self.sex_logits = torch.nn.Parameter(torch.zeros([1, ]))

    self.register_buffer('ventricle_volume_base_loc',
                         torch.zeros([1, ], requires_grad=False))
    self.register_buffer('ventricle_volume_base_scale',
                         torch.ones([1, ], requires_grad=False))

    self.register_buffer('brain_volume_base_loc',
                         torch.zeros([1, ], requires_grad=False))
    self.register_buffer('brain_volume_base_scale',
                         torch.ones([1, ], requires_grad=False))

    self.register_buffer('z_loc',
                         torch.zeros([latent_dim, ], requires_grad=False))
    self.register_buffer('z_scale',
                         torch.ones([latent_dim, ], requires_grad=False))

    self.register_buffer('x_base_loc',
                         torch.zeros(self.img_shape, requires_grad=False))
    self.register_buffer('x_base_scale',
                         torch.ones(self.img_shape, requires_grad=False))

    # Normalization statistics for the covariate flows; buffers so they move
    # with the module and are checkpointed.
    self.register_buffer('age_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('age_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))

    self.register_buffer('ventricle_volume_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('ventricle_volume_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))

    self.register_buffer('brain_volume_flow_lognorm_loc',
                         torch.zeros([], requires_grad=False))
    self.register_buffer('brain_volume_flow_lognorm_scale',
                         torch.ones([], requires_grad=False))

    # age flow: spline, then lognorm affine + exp for positive support.
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    # .item() snapshots the buffer values at construction time — the
    # AffineTransform does NOT track later buffer updates automatically.
    self.age_flow_lognorm = AffineTransform(
        loc=self.age_flow_lognorm_loc.item(),
        scale=self.age_flow_lognorm_scale.item())
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])

    # other flows shared components
    self.ventricle_volume_flow_lognorm = AffineTransform(
        loc=self.ventricle_volume_flow_lognorm_loc.item(),
        scale=self.ventricle_volume_flow_lognorm_scale.item())  # noqa: E501
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm, ExpTransform()])

    self.brain_volume_flow_lognorm = AffineTransform(
        loc=self.brain_volume_flow_lognorm_loc.item(),
        scale=self.brain_volume_flow_lognorm_scale.item())
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
class ConditionalFlowSEM(BaseFlowSEM):
    """Flow-based SEM: intensity is conditioned on thickness, and the image
    x is modelled by a multiscale conditional normalizing flow conditioned
    on (thickness_, intensity_).
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Whether to append a conditional affine transform on the image
        # noise e_x at the end of the flow.
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t (thickness): spline, then lognorm affine + exp
        # so samples land on positive (Gamma-like) support.
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # Conditional affine flow for s (intensity), conditioned on
        # (unconstrained) thickness; sigmoid + norm constrain the output.
        intensity_net = DenseNN(1, [1],
                                param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        # List (not ComposeTransform) because the conditional transform must
        # be `.condition(...)`-ed before composing.
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]
        # build flow as s_affine_w * t * e_s + b -> depends on t though

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the multiscale conditional image flow.

        Learnable pieces are appended to ``self.trans_modules`` (so pyro/torch
        see their parameters) as well as to ``self.x_transforms`` (the actual
        transform pipeline); the append order defines the flow and must not
        be changed.
        """
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1  # current channel count; quadruples per squeeze
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # Move channels last for the coupling, then back.
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # Undo the squeezes: back to a single-channel 32x32 image.
        # NOTE(review): the 32x32 size is hard-coded here — presumably
        # Morpho-MNIST-sized inputs; confirm against the data pipeline.
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)

    @pyro_method
    def pgm_model(self):
        """Sample (thickness, intensity); intensity conditions on thickness."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo access to thickness_flow_components to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(
            intensity_base_dist,
            self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo access to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Full generative model: covariates, then image via the cond. flow."""
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # Condition the transform pipeline on the covariates, compose, and
        # invert: the flow maps image -> noise, so sampling uses the inverse.
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Invert the thickness flow to recover its base-noise value."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        """Invert the (thickness-conditioned) intensity flow for base noise."""
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                intensity_base_dist,
                self.intensity_flow_transforms).condition(
                    thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Map an image x to its base-noise value under the cond. flow."""
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)
        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)

    @classmethod
    def add_arguments(cls, parser):
        """Add this model's CLI arguments to *parser* and return it."""
        parser = super().add_arguments(parser)

        # NOTE(review): CLI default is False while the __init__ default is
        # True — confirm which one is intended.
        parser.add_argument(
            '--use_affine_ex',
            default=False,
            action='store_true',
            help=
            "whether to use conditional affine transformation on e_x (default: %(default)s)"
        )

        return parser
class ConditionalSTNVISEM(BaseVISEM):
    """VI-SEM with a spatial-transformer (STN) decoder for brain images.

    PGM: sex ~ Bernoulli; age through a spline flow; ventricle volume and
    brain volume through affine flows conditioned on (sex, age). The decoder
    produces an image mean which is then warped by a learned 2x3 affine grid.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # decoder parts: latent_dim + 3 accounts for the appended context
        # (age_, ventricle_volume_, brain_volume_).
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        # Scalar log-std shared across pixels.
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        # Predicts the 6 parameters of a 2x3 affine warp from the latent.
        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        # Initialize the warp to the identity transform.
        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # age flow: spline, then lognorm affine + exp for positive support.
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # ventricle_volume flow, conditioned on a 2-dim context (sex, age_).
        # TODO: decide on how many things to condition on
        ventricle_volume_net = DenseNN(2, [8, 16],
                                       param_dims=[1, 1],
                                       nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=ventricle_volume_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm, ExpTransform()])
        # List (not ComposeTransform): the conditional piece needs
        # `.condition(...)` before composition.
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # brain_volume flow, conditioned on the same 2-dim context.
        # TODO: decide on how many things to condition on
        brain_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=brain_volume_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm, ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run conterfactuals? oO
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

    @pyro_method
    def pgm_model(self):
        """Sample (age, sex, ventricle_volume, brain_volume) from the prior."""
        sex_dist = Bernoulli(self.sex_logits).to_event(1)
        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)
        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo access to age_flow_transforms to register with pyro
        _ = self.age_flow_transforms

        # Context for the conditional volume flows.
        context = torch.cat([sex, age_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(context)
        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo access to ventricle_volume_flow_transforms to register with pyro
        _ = self.ventricle_volume_flow_transforms

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(context)
        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo access to brain_volume_flow_transforms to register with pyro
        _ = self.brain_volume_flow_transforms

        return age, sex, ventricle_volume, brain_volume

    @pyro_method
    def model(self):
        """Full generative model: PGM covariates, latent z, STN-decoded x."""
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        # Unconstrained parameterizations for the decoder input.
        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        # 2x3 affine warp parameters per batch element.
        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        # affine_grid/grid_sample emit align_corners deprecation warnings;
        # suppressed deliberately.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume

    @pyro_method
    def guide(self, x, age, sex, ventricle_volume, brain_volume):
        """Amortized posterior q(z | x, age, sex, volumes)."""
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                ventricle_volume)
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                brain_volume)
            age_ = self.age_flow_constraint_transforms.inv(age)

            hidden = torch.cat(
                [hidden, age_, ventricle_volume_, brain_volume_], 1)
            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z
def __init__(self, **kwargs):
    """Build the STN decoder, covariate flows, and encoder.

    NOTE(review): this body appears to duplicate ConditionalSTNVISEM's
    constructor — presumably a sibling model variant; consider factoring
    the shared setup into a common base.
    """
    super().__init__(**kwargs)

    # decoder parts: latent_dim + 3 accounts for the appended context
    # (age_, ventricle_volume_, brain_volume_).
    self.decoder = Decoder(num_convolutions=self.num_convolutions,
                           filters=self.dec_filters,
                           latent_dim=self.latent_dim + 3,
                           upconv=self.use_upconv)

    self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
    # Scalar log-std shared across pixels.
    self.decoder_logstd = torch.nn.Parameter(
        torch.ones([]) * self.logstd_init)

    # Predicts the 6 parameters of a 2x3 affine warp from the latent.
    self.decoder_affine_param_net = nn.Sequential(
        nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
        nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
        nn.Linear(self.latent_dim, 6))

    # Initialize the warp to the identity transform.
    self.decoder_affine_param_net[-1].weight.data.zero_()
    self.decoder_affine_param_net[-1].bias.data.copy_(
        torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # age flow: spline, then lognorm affine + exp for positive support.
    self.age_flow_components = ComposeTransformModule([Spline(1)])
    self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.age_flow_constraint_transforms = ComposeTransform(
        [self.age_flow_lognorm, ExpTransform()])
    self.age_flow_transforms = ComposeTransform(
        [self.age_flow_components, self.age_flow_constraint_transforms])

    # ventricle_volume flow, conditioned on a 2-dim context.
    # TODO: decide on how many things to condition on
    ventricle_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
    self.ventricle_volume_flow_components = ConditionalAffineTransform(
        context_nn=ventricle_volume_net, event_dim=0)
    self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
        [self.ventricle_volume_flow_lognorm, ExpTransform()])
    # List (not ComposeTransform): the conditional piece needs
    # `.condition(...)` before composition.
    self.ventricle_volume_flow_transforms = [
        self.ventricle_volume_flow_components,
        self.ventricle_volume_flow_constraint_transforms
    ]

    # brain_volume flow, conditioned on the same 2-dim context.
    # TODO: decide on how many things to condition on
    brain_volume_net = DenseNN(2, [8, 16],
                               param_dims=[1, 1],
                               nonlinearity=torch.nn.Identity())
    self.brain_volume_flow_components = ConditionalAffineTransform(
        context_nn=brain_volume_net, event_dim=0)
    self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
    self.brain_volume_flow_constraint_transforms = ComposeTransform(
        [self.brain_volume_flow_lognorm, ExpTransform()])
    self.brain_volume_flow_transforms = [
        self.brain_volume_flow_components,
        self.brain_volume_flow_constraint_transforms
    ]

    # encoder parts
    self.encoder = Encoder(num_convolutions=self.num_convolutions,
                           filters=self.enc_filters,
                           latent_dim=self.latent_dim)

    # TODO: do we need to replicate the PGM here to be able to run conterfactuals? oO
    latent_layers = torch.nn.Sequential(
        torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
        torch.nn.ReLU())
    self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                          self.latent_dim)