def model(self, noise):
    """Draw one joint sample of every variable in the signalling model.

    Every variable is registered as a named sample site.

    * Each noise site ``N_<var>`` is drawn from
      ``Normal(noise['N_<var>'][0], noise['N_<var>'][1])``.
    * The two root variables ``SARS_COV2`` and ``TOCI`` are drawn from a
      ``Normal`` with scale ``1.0`` centred on ``self.f_<var>(50, 10, N_<var>)``.
    * Every other variable is a ``Delta`` (deterministic) site whose value is
      ``self.f_<var>(<parents>..., N_<var>)``.

    Site names match the Python variable names, so a site can be addressed
    by its variable name.

    Args:
        noise: Mapping from noise-site name (e.g. ``'N_ACE2'``) to an
            indexable pair whose items ``[0]`` and ``[1]`` are passed as the
            two positional parameters of ``Normal``.

    Returns:
        A pair ``(samples, noise_samples)``.

        ``samples`` depends on ``self.noise_type``:

        * ``NOISE_TYPE_OBSERVATIONAL``: a dict keyed ``'a(<var>)'`` whose
          values are the ``.numpy()`` conversions of the sampled values.
        * ``NOISE_TYPE_SAMPLES``: a tuple of the raw sampled values in the
          order SARS_COV2, PRR, ACE2, AngII, AGTR1, ADAM17, TOCI, TNF,
          sIL_6_alpha, EGF, EGFR, NF_xB, IL6_STAT3, IL_6_AMP, cytokine.

        ``noise_samples`` is a tuple of the noise draws for the non-root
        variables only. ``N_SARS_COV2`` and ``N_TOCI`` are not included.

    Raises:
        InvalidNoiseType: if ``self.noise_type`` is neither supported value.
    """

    def sample_noise(name):
        # Register the noise site `name`, parameterised by noise[name][0] and
        # noise[name][1] (positional parameters of Normal).
        params = noise[name]
        return sample(name, Normal(params[0], params[1]))

    # Noise sites, drawn in a fixed order so the trace layout is stable.
    N_SARS_COV2 = sample_noise('N_SARS_COV2')
    N_TOCI = sample_noise('N_TOCI')
    N_ACE2 = sample_noise('N_ACE2')
    N_PRR = sample_noise('N_PRR')
    N_AngII = sample_noise('N_AngII')
    N_AGTR1 = sample_noise('N_AGTR1')
    N_ADAM17 = sample_noise('N_ADAM17')
    N_sIL_6_alpha = sample_noise('N_sIL_6_alpha')
    N_TNF = sample_noise('N_TNF')
    N_EGF = sample_noise('N_EGF')
    N_EGFR = sample_noise('N_EGFR')
    N_IL6_STAT3 = sample_noise('N_IL6_STAT3')
    N_NF_xB = sample_noise('N_NF_xB')
    N_IL_6_AMP = sample_noise('N_IL_6_AMP')
    N_cytokine = sample_noise('N_cytokine')

    # Root variables: Normal with scale 1.0 around f_<var>(50, 10, noise).
    # NOTE(review): the meaning of the constants 50 and 10 is defined by
    # f_SARS_COV2 / f_TOCI, which are not visible here.
    SARS_COV2 = sample('SARS_COV2', Normal(self.f_SARS_COV2(50, 10, N_SARS_COV2), 1.0))
    TOCI = sample('TOCI', Normal(self.f_TOCI(50, 10, N_TOCI), 1.0))

    # Downstream variables: deterministic functions of their parents and
    # their own noise term, registered as Delta sites.
    PRR = sample('PRR', Delta(self.f_PRR(SARS_COV2, N_PRR)))
    ACE2 = sample('ACE2', Delta(self.f_ACE2(SARS_COV2, N_ACE2)))
    AngII = sample('AngII', Delta(self.f_AngII(ACE2, N_AngII)))
    AGTR1 = sample('AGTR1', Delta(self.f_AGTR1(AngII, N_AGTR1)))
    # This site was previously misnamed 'ADAM1Spike7'. Every other site is
    # named after its variable, so addressing the site as 'ADAM17'
    # (e.g. when conditioning or intervening) could not reach it.
    ADAM17 = sample('ADAM17', Delta(self.f_ADAM17(AGTR1, N_ADAM17)))
    TNF = sample('TNF', Delta(self.f_TNF(ADAM17, N_TNF)))
    sIL_6_alpha = sample('sIL_6_alpha', Delta(self.f_sIL_6_alpha(ADAM17, TOCI, N_sIL_6_alpha)))
    EGF = sample('EGF', Delta(self.f_EGF(ADAM17, N_EGF)))
    EGFR = sample('EGFR', Delta(self.f_EGFR(EGF, N_EGFR)))
    NF_xB = sample('NF_xB', Delta(self.f_NF_xB(PRR, EGFR, TNF, N_NF_xB)))
    IL6_STAT3 = sample('IL6_STAT3', Delta(self.f_IL6_STAT3(sIL_6_alpha, N_IL6_STAT3)))
    IL_6_AMP = sample('IL_6_AMP', Delta(self.f_IL_6_AMP(NF_xB, IL6_STAT3, N_IL_6_AMP)))
    cytokine = sample('cytokine', Delta(self.f_cytokine(IL_6_AMP, N_cytokine)))

    # Noise terms of the non-root variables only; the root noises
    # N_SARS_COV2 and N_TOCI are deliberately absent from this tuple.
    # NOTE(review): confirm callers do not expect the root noises here.
    noise_samples = (
        N_PRR,
        N_ACE2,
        N_AngII,
        N_AGTR1,
        N_ADAM17,
        N_TNF,
        N_sIL_6_alpha,
        N_EGF,
        N_EGFR,
        N_NF_xB,
        N_IL6_STAT3,
        N_IL_6_AMP,
        N_cytokine,
    )

    if self.noise_type == NOISE_TYPE_OBSERVATIONAL:
        # Dictionary form, used when generating an observational dataset.
        samples = {
            'a(SARS_COV2)': SARS_COV2.numpy(),
            'a(PRR)': PRR.numpy(),
            'a(ACE2)': ACE2.numpy(),
            'a(AngII)': AngII.numpy(),
            'a(AGTR1)': AGTR1.numpy(),
            'a(ADAM17)': ADAM17.numpy(),
            'a(TOCI)': TOCI.numpy(),
            'a(TNF)': TNF.numpy(),
            'a(sIL_6_alpha)': sIL_6_alpha.numpy(),
            'a(EGF)': EGF.numpy(),
            'a(EGFR)': EGFR.numpy(),
            'a(IL6_STAT3)': IL6_STAT3.numpy(),
            'a(NF_xB)': NF_xB.numpy(),
            # Key spelling differs from the site name 'IL_6_AMP'; kept as-is
            # because consumers of the dataset may rely on it.
            'a(IL6_AMP)': IL_6_AMP.numpy(),
            'a(cytokine)': cytokine.numpy(),
        }
    elif self.noise_type == NOISE_TYPE_SAMPLES:
        # Positional form, used when generating samples for causal-effect
        # estimation.
        samples = (
            SARS_COV2,
            PRR,
            ACE2,
            AngII,
            AGTR1,
            ADAM17,
            TOCI,
            TNF,
            sIL_6_alpha,
            EGF,
            EGFR,
            NF_xB,
            IL6_STAT3,
            IL_6_AMP,
            cytokine,
        )
    else:
        raise InvalidNoiseType(self.noise_type)

    return samples, noise_samples
Beispiel #2
0
def wrapped_model(x_data, y_data):
    """Record the output of ``model(x_data, y_data)`` as a deterministic site.

    The value returned by ``model`` is wrapped in a ``Delta`` distribution and
    registered under the sample-site name ``"prediction"``. The function
    itself returns ``None``.
    """
    predicted = model(x_data, y_data)
    point_mass = Delta(predicted)
    pyro.sample("prediction", point_mass)
Beispiel #3
0
 def get_posterior(self, *args, **kwargs):
     """Return a Delta distribution placed on the current SVGD particles.

     The particles are the ``"svgd_particles"`` parameter, fetched through
     ``pyro.param`` and seeded with ``self._init_loc``. ``event_dim=1`` makes
     the Delta treat the last dimension of the particles as its event
     dimension. ``*args`` and ``**kwargs`` are accepted but ignored.
     """
     # NOTE(review): presumably the particles are shaped
     # (num_particles, latent_dim); confirm against self._init_loc.
     particles = pyro.param("svgd_particles", self._init_loc)
     posterior = Delta(particles, event_dim=1)
     return posterior