def infer_intensity_base(self, thickness, intensity):
    """Abduct the exogenous (base) noise for `intensity`.

    Builds the conditional flow for intensity, conditions it on `thickness`,
    and pushes `intensity` backwards through the composed transforms to
    recover the base-distribution value.

    NOTE(review): this conditions on `thickness` directly, whereas the
    `ConditionalFlowSEM.infer_intensity_base` later in this file conditions on
    the constraint-inverted `thickness_` — confirm which parameterisation the
    flow was trained with.
    """
    intensity_base_dist = Normal(self.intensity_base_loc,
                                 self.intensity_base_scale)
    # `.condition(...)` resolves the conditional transforms into concrete
    # ones; ComposeTransform then exposes a single invertible mapping.
    cond_intensity_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms).condition(
                thickness).transforms)
    return cond_intensity_transforms.inv(intensity)
class ConditionalDecoderVISEM(BaseVISEM):
    """VI-based SEM where thickness and intensity are modelled by
    unconditional flows and only the image decoder is conditioned on them
    (via the latent concatenation in `model`).
    """

    # thickness + intensity are concatenated to the latent code
    context_dim = 2

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        # exp(lognorm(.)) maps the spline output onto the positive reals
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        # sigmoid squashes to (0, 1) before renormalising to the data range
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        # kept as a plain list (TransformedDistribution accepts a list)
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

    @pyro_method
    def pgm_model(self):
        """Sample (thickness, intensity) from the causal graphical model."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Full generative model: PGM variables -> latent z -> image x."""
        thickness, intensity = self.pgm_model()

        # un-constrain the attributes before feeding them to the decoder
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, thickness_, intensity_], 1)

        x_dist = self._get_transformed_x_dist(latent)

        x = pyro.sample('x', x_dist)

        return x, z, thickness, intensity

    @pyro_method
    def guide(self, x, thickness, intensity):
        """Amortised variational posterior q(z | x, thickness, intensity)."""
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            thickness_ = self.thickness_flow_constraint_transforms.inv(
                thickness)
            intensity_ = self.intensity_flow_norm.inv(intensity)

            hidden = torch.cat([hidden, thickness_, intensity_], 1)
            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Abduct the exogenous noise behind `thickness`."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Abduct the exogenous noise behind `intensity`."""
        return self.intensity_flow_transforms.inv(intensity)
class ConditionalFlowSEM(BaseFlowSEM):
    """Fully flow-based SEM: intensity is a conditional flow given thickness,
    and the image is a conditional multi-scale coupling flow given both.
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        # whether to append a conditional affine transform on e_x
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        intensity_net = DenseNN(1, [1], param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]
        # build flow as s_affine_w * t * e_s + b -> depends on t though

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the multi-scale conditional coupling flow for x.

        `self.x_transforms` is the list of transforms used by the
        distribution; `self.trans_modules` additionally registers the
        learnable ones as submodules so their parameters are tracked.
        """
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # couplings operate channel-last; transpose around them
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # undo the squeezes back to a (1, 32, 32) image
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)

    @pyro_method
    def pgm_model(self):
        """Sample (thickness, intensity); intensity is conditioned on
        the unconstrained thickness."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(
            intensity_base_dist,
            self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to w_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Full generative model: PGM variables -> conditional image flow."""
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # condition resolves the conditional transforms; `.inv` because the
        # flow is defined data -> base, while sampling goes base -> data
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Abduct the exogenous noise behind `thickness`."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        """Abduct the exogenous noise behind `intensity`, conditioning the
        intensity flow on the unconstrained thickness."""
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                intensity_base_dist, self.intensity_flow_transforms).condition(
                    thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Abduct the exogenous noise behind the image `x`."""
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)
        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)

    @classmethod
    def add_arguments(cls, parser):
        """Extend the base argparse parser with this model's options."""
        parser = super().add_arguments(parser)

        parser.add_argument(
            '--use_affine_ex',
            default=False,
            action='store_true',
            help=
            "whether to use conditional affine transformation on e_x (default: %(default)s)"
        )

        return parser
class ConditionalSTNVISEM(BaseVISEM):
    """VI-based SEM for brain imaging whose decoder output is warped by a
    spatial transformer (affine grid predicted from the latent code).
    Attributes: sex, age, ventricle_volume, brain_volume.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # decoder parts
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        # predicts the 6 parameters of a 2x3 affine warp from the latent
        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        # initialise to the identity transform
        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # ventricle_volume flow
        # TODO: decide on how many things to condition on
        ventricle_volume_net = DenseNN(2, [8, 16],
                                       param_dims=[1, 1],
                                       nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=ventricle_volume_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm, ExpTransform()])
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # brain_volume flow
        # TODO: decide on how many things to condition on
        brain_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=brain_volume_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm, ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run counterfactuals?
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

    @pyro_method
    def pgm_model(self):
        """Sample the attribute PGM: sex, age -> ventricle/brain volume."""
        sex_dist = Bernoulli(self.sex_logits).to_event(1)
        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)
        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo call to age_flow_transforms to register with pyro
        _ = self.age_flow_transforms

        # both volumes are conditioned on (sex, unconstrained age)
        context = torch.cat([sex, age_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(context)
        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_transforms to register with pyro
        _ = self.ventricle_volume_flow_transforms

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(context)
        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_transforms to register with pyro
        _ = self.brain_volume_flow_transforms

        return age, sex, ventricle_volume, brain_volume

    @pyro_method
    def model(self):
        """Full generative model: attributes -> z -> decoded, STN-warped x."""
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        # affine_grid/grid_sample emit align_corners deprecation warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume

    @pyro_method
    def guide(self, x, age, sex, ventricle_volume, brain_volume):
        """Amortised posterior q(z | x, age, ventricle_volume, brain_volume).

        NOTE(review): `sex` is accepted but not used as an encoder input —
        confirm this is intentional.
        """
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                ventricle_volume)
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                brain_volume)
            age_ = self.age_flow_constraint_transforms.inv(age)
            hidden = torch.cat(
                [hidden, age_, ventricle_volume_, brain_volume_], 1)
            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z
class IndependentFlowSEM(BaseFlowSEM):
    """Flow-based SEM in which thickness, intensity and the image flow are
    all unconditional (no attribute enters the image flow).
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): stored but not read by this class's
        # `_build_image_flow` (it has no conditional affine on e_x) —
        # kept for signature parity with ConditionalFlowSEM.
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the (unconditional) multi-scale coupling flow for x."""
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # couplings operate channel-last; transpose around them
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = AffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2)))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # undo the squeezes back to a (1, 32, 32) image
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

    @pyro_method
    def pgm_model(self):
        """Sample (thickness, intensity) independently of each other."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Full generative model; the image flow ignores the attributes."""
        thickness, intensity = self.pgm_model()

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # `.inv` because the flow is defined data -> base
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform(self.x_transforms).inv)

        x = pyro.sample('x', x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Abduct the exogenous noise behind `thickness`."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Abduct the exogenous noise behind `intensity`."""
        return self.intensity_flow_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Abduct the exogenous noise behind `x` (attributes unused here,
        kept for a uniform interface with the conditional SEMs)."""
        return ComposeTransform(self.x_transforms)(x)
class FlowSCM(PyroModule):  # BUG FIX: base class was misspelled `pyroModule`
    """Flow-based SCM over a binary `glasses` attribute and an image `x`.

    The original version executed the whole generative process inside
    `__init__` — including `pyro.sample` calls and a `return x, glasses`
    statement (returning non-None from `__init__` raises TypeError) — and
    conditioned the image flow on an undefined name `context`.  Construction
    and sampling are now separated, mirroring the sibling SEM classes in
    this file: `__init__` builds modules/buffers, `model()` draws samples.
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        # Pull image-flow configuration out of kwargs before delegating,
        # since PyroModule.__init__ does not accept arbitrary keywords.
        # NOTE(review): these attributes are read by `_build_image_flow` but
        # were never set anywhere in the original class (AttributeError at
        # build time); defaults chosen to make the class constructible —
        # confirm against BaseFlowSEM / the training configuration.
        flows_per_scale = kwargs.pop('flows_per_scale', 2)
        use_actnorm = kwargs.pop('use_actnorm', False)
        hidden_channels = kwargs.pop('hidden_channels', 256)
        # BUG FIX: was `super.__init__(**kwargs)` (missing call parentheses).
        super().__init__(**kwargs)

        # BUG FIX: the parameter was previously ignored, but
        # `_build_image_flow` reads `self.use_affine_ex`.
        self.use_affine_ex = use_affine_ex
        self.num_scales = 2
        self.flows_per_scale = flows_per_scale
        self.use_actnorm = use_actnorm
        self.hidden_channels = hidden_channels

        # base-distribution parameters for the scalar glasses variable
        self.register_buffer("glasses_base_loc",
                             torch.zeros([1, ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([1, ], requires_grad=False))

        # normalisation constants for the glasses flow
        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        # BUG FIX: `self.glasses_flow_lognorm` was used below but never
        # defined; build it from the buffers registered above (same pattern
        # as the lognorm transforms in the sibling classes).
        self.glasses_flow_lognorm = AffineTransform(
            loc=self.glasses_flow_lognorm_loc,
            scale=self.glasses_flow_lognorm_scale)

        # spline flow, then lognorm + sigmoid squashing onto (0, 1)
        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm, SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        # conditional image flow; populates self.trans_modules and
        # self.x_transforms.
        # BUG FIX: was `self.x_transforms = self._build_image_flow()`, which
        # overwrote the transform list with the method's None return value.
        self._build_image_flow()

        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))

    @pyro_method
    def model(self):
        """Sample (x, glasses) from the SCM.

        Returns:
            x: image sampled from the conditional flow.
            glasses: Bernoulli indicator sampled from the continuous score.
        """
        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        glasses_dist = TransformedDistribution(glasses_base_dist,
                                               self.glasses_flow_transforms)

        # continuous glasses probability in (0, 1), then the binary label
        glasses_ = pyro.sample("glasses_", glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        # BUG FIX: the original conditioned on an undefined name `context`;
        # the computed (but unused) `glasses_context` was clearly intended.
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # `.inv` because the flow is defined data -> base, while sampling
        # goes base -> data
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample("x", cond_x_dist)

        return x, glasses

    def _build_image_flow(self):
        """Assemble the multi-scale conditional coupling flow for x.

        NOTE(review): `_get_preprocess_transforms` is not defined on this
        class; it presumably comes from a base/experiment mixin — verify.
        Also note the final ReshapeTransform targets (1, 32, 32) while the
        x base distribution is (1, 64, 64) — confirm the intended image size.
        """
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # couplings operate channel-last; transpose around them
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # undo the squeezes back to a single-channel image
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)