Example #1
def plate_custom_model(subsample):
    with pyro.plate('plate', 20, subsample=subsample) as batch:
        result = batch
    return result
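A minimal usage sketch of the helper above, assuming pyro and torch are imported: with an explicit index tensor the plate hands the same indices back, and with subsample=None it yields all 20 indices.

import torch
import pyro

idx = torch.tensor([2, 5, 11])
assert (plate_custom_model(idx) == idx).all()   # the given subsample is returned as-is
assert len(plate_custom_model(None)) == 20      # no subsampling: all indices are returned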
Example #2
def predict(args, data, samples, truth=None):
    logging.info("Forecasting {} steps ahead...".format(args.forecast))
    particle_plate = pyro.plate("particles", args.num_samples, dim=-1)

    # First we sample discrete auxiliary variables from the continuous
    # variables sampled in vectorized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter backward-sample
    # algorithm. We'll add these new samples to the existing dict of samples.
    model = poutine.condition(continuous_model, samples)
    model = particle_plate(model)
    model = infer_discrete(model, first_available_dim=-2)
    with poutine.trace() as tr:
        model(args, data)
    samples = OrderedDict((name, site["value"])
                          for name, site in tr.trace.nodes.items()
                          if site["type"] == "sample")

    # Next we'll run the forward generative process in discrete_model. This
    # samples time steps [duration:duration+forecast]. Again we'll update the
    # dict of samples.
    extended_data = list(data) + [None] * args.forecast
    model = poutine.condition(discrete_model, samples)
    model = particle_plate(model)
    with poutine.trace() as tr:
        model(args, extended_data)
    samples = OrderedDict((name, site["value"])
                          for name, site in tr.trace.nodes.items()
                          if site["type"] == "sample")

    # Finally we'll concatenate the sequentially sampled values into contiguous
    # tensors. This operates on the entire time interval [0:duration+forecast].
    for key in ("S", "I", "S2I", "I2R"):
        pattern = key + "_[0-9]+"
        series = [
            value for name, value in samples.items()
            if re.match(pattern, name)
        ]
        assert len(series) == args.duration + args.forecast
        series[0] = series[0].expand(series[1].shape)
        samples[key] = torch.stack(series, dim=-1)
    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    logging.info(
        "Median prediction of new infections (starting on day 0):\n{}".format(
            " ".join(map(str, map(int, median)))))

    # Optionally plot the latent and forecasted series of new infections.
    if args.plot:
        import matplotlib.pyplot as plt
        plt.figure()
        time = torch.arange(args.duration + args.forecast)
        p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)),
                           dim=0).values
        p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)),
                           dim=0).values
        plt.fill_between(time,
                         p05,
                         p95,
                         color="red",
                         alpha=0.3,
                         label="90% CI")
        plt.plot(time, median, "r-", label="median")
        plt.plot(time[:args.duration], data, "k.", label="observed")
        if truth is not None:
            plt.plot(time, truth, "k--", label="truth")
        plt.axvline(args.duration - 0.5, color="gray", lw=1)
        plt.xlim(0, len(time) - 1)
        plt.ylim(0, None)
        plt.xlabel("day after first infection")
        plt.ylabel("new infections per day")
        plt.title("New infections in population of {}".format(args.population))
        plt.legend(loc="upper left")
        plt.tight_layout()

    return samples
Example #3
def model(subsample):
    with pyro.plate("data", len(data), subsample_size, subsample) as ind:
        x = data[ind]
        z = pyro.sample("z", Normal(0, 1))
        pyro.sample("x", Normal(z, 1), obs=x)
Example #4
def guide(data):
    data = torch.reshape(data, [60000, 50, 50])

    pyro.module('rnn', rnn)
    pyro.module('bl_rnn', bl_rnn)
    pyro.module('predict_l1', predict_l1)
    pyro.module('predict_l2', predict_l2)
    pyro.module('encode_l1', encode_l1)
    pyro.module('encode_l2', encode_l2)
    pyro.module('bl_predict_l1', bl_predict_l1)
    pyro.module('bl_predict_l2', bl_predict_l2)

    pyro.param('h_init', h_init)
    pyro.param('c_init', c_init)
    pyro.param('z_where_init', z_where_init)
    pyro.param('z_what_init', z_what_init)
    pyro.param('bl_h_init', bl_h_init)
    pyro.param('bl_c_init', bl_c_init)

    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]

        flattened_batch = torch.Tensor.view(batch, [64, 2500])
        # inputs_raw = batch
        # inputs_embed = flattened_batch
        # inputs_bl_embed = flattened_batch

        state_h = torch.Tensor.expand(h_init, [64, 256])
        state_c = torch.Tensor.expand(c_init, [64, 256])
        state_bl_h = torch.Tensor.expand(bl_h_init, [64, 256])
        state_bl_c = torch.Tensor.expand(bl_c_init, [64, 256])
        state_z_pres = torch.ones(64, 1)
        state_z_where = torch.Tensor.expand(z_where_init, [64, 3])
        state_z_what = torch.Tensor.expand(z_what_init, [64, 50])

        z_pres = []
        z_where = []

        for t in range(3):
            #=========== guide_step
            # prev_h = state_h
            # prev_c = state_c
            # prev_bl_h = state_bl_h
            # prev_bl_c = state_bl_c
            # prev_z_pres = state_z_pres
            # prev_z_where = state_z_where
            # prev_z_what = state_z_what

            # size = [64, 2554]
            rnn_input = torch.cat(
                (flattened_batch, state_z_where, state_z_what, state_z_pres),
                1)
            # size = [64, 256], [64, 256]
            state_h, state_c = rnn(rnn_input, (state_h, state_c))

            #===== predict
            # size = [64, 7]
            out = predict_l2(F.relu(predict_l1(state_h)))
            # size = [64, 1]
            z_pres_p = torch.sigmoid(out[:, 0:1])
            # size = [64, 3]
            z_where_loc = out[:, 1:4]
            # size = [64, 3]
            z_where_scale = F.softplus(out[:, 4:])
            #===== predict

            #===== baseline_step
            # size = [64, 2554]
            rnn_input = torch.cat(
                (flattened_batch, torch.Tensor.detach(state_z_where),
                 torch.Tensor.detach(state_z_what),
                 torch.Tensor.detach(state_z_pres)), 1)
            # size = [64, 256], [64, 256]
            state_bl_h, state_bl_c = bl_rnn(rnn_input,
                                            (state_bl_h, state_bl_c))

            #===== bl_predict
            # size = [64, 1]
            bl_value = bl_predict_l2(F.relu(bl_predict_l1(state_bl_h)))
            #===== bl_predict

            bl_value = bl_value * state_z_pres
            infer_dict = dict(baseline=dict(
                baseline_value=torch.squeeze(bl_value, -1)))
            #===== baseline_step

            # size = [64, 1]
            cur_z_pres =\
                pyro.sample('z_pres_{}'.format(t),
                            Bernoulli(z_pres_p * state_z_pres).to_event(1),
                            infer=infer_dict)

            # sample_mask = cur_z_pres
            # size = [64, 3]
            cur_z_where =\
                pyro.sample('z_where_{}'.format(t),
                            Normal(z_where_loc + z_where_loc_prior,
                                   z_where_scale * z_where_scale_prior)
                            .mask(cur_z_pres)
                            .to_event(1))

            #===== image_to_window
            # images = batch

            #===== z_where_inv
            # size = [64, 3]
            out = torch.cat((torch.ones(64, 1), -cur_z_where[:, 1:]), 1)
            out = out / cur_z_where[:, 0:1]
            cur_z_where_inv = out
            #===== z_where_inv
            #===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where_inv), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            out = torch.Tensor.view(out, [64, 2, 3])
            theta_inv = out
            #===== expand_z_where

            # size = [64, 28, 28, 2]
            grid = F.affine_grid(theta_inv, [64, 1, 28, 28])
            # size = [64, 1, 28, 28]
            out = F.grid_sample(torch.Tensor.view(batch, [64, 1, 50, 50]),
                                grid)

            x_att = torch.Tensor.view(out, [64, 784])
            #===== image_to_window

            #===== encode
            # size = [64, 100]
            a = encode_l2(F.relu(encode_l1(x_att)))
            # size = [64, 50]
            z_what_loc = a[:, 0:50]
            # size = [64, 50]
            z_what_scale = F.softplus(a[:, 50:])
            #===== encode

            # size = [64, 50]
            cur_z_what =\
                pyro.sample('z_what_{}'.format(t),
                            Normal(z_what_loc, z_what_scale)
                            .mask(cur_z_pres)
                            .to_event(1))

            # state_h = h
            # state_c = c
            # state_bl_h = bl_h
            # state_bl_c = bl_c
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            state_z_what = cur_z_what
            #=========== guide_step

            z_where.append(state_z_where)
            z_pres.append(state_z_pres)

        return z_where, z_pres
Example #5
def model():
    with pyro.plate_stack("plates", shape[:dim]):
        with pyro.plate("particles", 10000):
            pyro.sample(
                "x",
                dist.Normal(loc, scale).expand(shape).to_event(-dim))
Example #6
    def guide(self,
              x,
              temp_id=None,
              anneal_id=None,
              anneal_t=None,
              anneal_dynamics=None):
        pyro.module('vdsm_seq', self)
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        bs, seq_len, pixels = x.view(x.shape[0], x.shape[1],
                                     self.imsize**2 * self.nc).shape
        h_0_enc = self.h_0_enc.expand(2 * self.num_layers_rnn, bs,
                                      self.hid_dim).contiguous()
        c_0_enc = self.c_0_enc.expand(2 * self.num_layers_rnn, bs,
                                      self.hid_dim).contiguous()
        z_prev_ = self.z_q_0.expand(bs, 1, -1)  # z0
        dec_inp_0 = self.dec_inp_0.expand(bs, 1, -1)
        x = x.view(bs * seq_len, self.nc, self.imsize, self.imsize)

        pre_z, _, ID_loc, ID_scale = self.image_enc(x)
        ID_loc, ID_scale = self.id_layers(ID_loc,
                                          ID_scale)  # extra trainable layer
        ID_loc = torch.mean(ID_loc.view(bs, seq_len, -1), 1).unsqueeze(1)[:, 0]
        ID_scale = torch.mean(ID_scale.view(bs, seq_len, -1),
                              1).unsqueeze(1)[:, 0]
        pre_z = pre_z.view(bs, seq_len, -1)

        # from https://github.com/yatindandi/Disentangled-Sequential-Autoencoder/blob/master/model.py self.encode_f
        sequence = pre_z.permute(1, 0, 2)
        _, h, _, rnn_enc_raw, out = self.seq2seq_enc(sequence, h_0_enc,
                                                     c_0_enc)
        h = h.permute(1, 0, 2).contiguous()
        h = h.view(bs, self.hid_dim * 2 * self.num_layers_rnn)

        d_params = self.cats(h)
        dz_loc = self.act(d_params[:, :self.dynamics_dim])
        dz_scale = self.softplus(d_params[:, self.dynamics_dim:])

        # infer dynamics and identity from data
        with pyro.plate('ID_plate', bs):
            IDdist = dist.Normal(ID_loc, ID_scale).to_event(1)
            dz_dist = dist.Normal(dz_loc, dz_scale).to_event(1)
            with poutine.scale(scale=anneal_id):
                ID = pyro.sample('ID', IDdist) * temp_id  # static factors
            with poutine.scale(scale=anneal_dynamics):
                dz = pyro.sample("dz", dz_dist)  # dynamics z

        h_dec = self.dz_to_dec_h(dz).view(-1, self.num_layers_rnn,
                                          self.hid_dim).permute(1, 0, 2)
        c_dec = self.dz_to_dec_c(dz).view(-1, self.num_layers_rnn,
                                          self.hid_dim).permute(1, 0, 2)

        for i in pyro.plate('batch_loop', bs):
            dec_inp = dec_inp_0[None, i].contiguous()
            z_prev = z_prev_[None, i].contiguous()
            h = h_dec[:, None, i]
            c = c_dec[:, None, i]
            dz_dec = dz[None, i, None, :]
            for t in pyro.markov(range(seq_len)):
                dec_inp, (h, c) = self.seq2seq_dec(dec_inp, (h, c))
                z_loc, z_scale = self.comb(z_prev, dec_inp, dz_dec)
                z_dist = dist.Normal(z_loc[0], z_scale[0]).to_event(1)
                with poutine.scale(scale=anneal_t):
                    z = pyro.sample('z_{}_{}'.format(i, t), z_dist)
                z_prev = z.view(1, 1, -1)
Example #7
def _fn(*args, **kwargs):
    with pyro.plate("num_particles_vectorized",
                    num_samples,
                    dim=-max_plate_nesting):
        return fn(*args, **kwargs)
Example #8
File: dmm.py  Project: ucals/pyro
    def guide(self,
              mini_batch,
              mini_batch_reversed,
              mini_batch_mask,
              mini_batch_seq_lengths,
              annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                     self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # we enclose all the sample statements in the guide in a plate.
        # this marks that each datapoint is conditionally independent of the others.
        with pyro.plate("z_minibatch", len(mini_batch)):
            # sample the latents z one time step at a time
            # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use multiple samples from the guide at each z
            for t in pyro.markov(range(1, T_max + 1)):
                # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # if we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution defined in the previous line
                # to yield a transformed distribution that we use for q(z_t|...)
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(
                        dist.Normal(z_loc, z_scale), self.iafs)
                    assert z_dist.event_shape == (self.z_q_0.size(0), )
                    assert z_dist.batch_shape[-1:] == (len(mini_batch), )
                else:
                    z_dist = dist.Normal(z_loc, z_scale)
                    assert z_dist.event_shape == ()
                    assert z_dist.batch_shape[-2:] == (len(mini_batch),
                                                       self.z_q_0.size(0))

                # sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    if len(self.iafs) > 0:
                        # in output of normalizing flow, all dimensions are correlated (event shape is not empty)
                        z_t = pyro.sample(
                            "z_%d" % t, z_dist.mask(mini_batch_mask[:, t - 1]))
                    else:
                        # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
                        z_t = pyro.sample(
                            "z_%d" % t,
                            z_dist.mask(mini_batch_mask[:,
                                                        t - 1:t]).to_event(1))
                # the latent sampled at this time step will be conditioned upon in the next time step
                # so keep track of it
                z_prev = z_t
Example #9
def model():
    with pyro.plate("particles", 20000):
        return pyro.sample("x", dist.Stable(stability, skew))
Example #10
    def guide(self, mini_batch, mini_batch_reversed, annealing_factor=1.0):
        """
        The inference model q(z_{1:T} | y_{1:T})
        """

        # Number of time steps through mini-batch
        T_max = mini_batch.size(1)

        # Register all PyTorch modules with Pyro
        pyro.module("dkf", self)

        # Contiguous hidden state
        h_0 = self.h_0.expand(1, mini_batch.size(0),
                              self.rnn.hidden_size).contiguous()

        # Flatten and reverse input y
        batch_size = mini_batch_reversed.shape[0]
        seq_len = mini_batch_reversed.shape[1]
        flat_mini_batch_reversed = torch.zeros(
            batch_size, seq_len, self.rnn_input_dim).to(self.device)
        for t in range(seq_len):
            flat_mini_batch_reversed[:, t, :] = self.flatten(
                mini_batch_reversed[:, t, :, :, :])

        # Feed y through RNN;
        rnn_output, _ = self.rnn(flat_mini_batch_reversed, h_0)

        # Backwards to take future observations into account
        rnn_output = reversed_input(rnn_output)

        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # We enclose all the sample statements in the guide in a plate
        # for conditional independence
        with pyro.plate("z_test", len(mini_batch)):
            # Sample the latents z
            for t in range(1, T_max + 1):
                # Mean and variance for the distribution q(z_t | z_{t-1}, y_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # If we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(
                        dist.Normal(z_loc, z_scale), self.iafs)
                else:
                    z_dist = dist.Normal(z_loc, z_scale)

                # Sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    if len(self.iafs) > 0:
                        z_t = pyro.sample("z_%d" % t, z_dist)
                    else:
                        # When no normalizing flow is used, ".to_event(1)"
                        # indicates latent dimensions are independent
                        z_t = pyro.sample("z_%d" % t, z_dist.to_event(1))

                # Update time step
                z_prev = z_t

        return z_t
Example #11
def forward(self, data):
    loc, log_scale = self.z.unbind(-1)
    with pyro.plate("data"):
        pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data)
Example #12
def iplate_cuda_model(subsample_size):
    loc = torch.zeros(20).cuda()
    scale = torch.ones(20).cuda()
    for i in pyro.plate("data", 20, subsample_size, device=loc.device):
        pyro.sample("x_{}".format(i), dist.Normal(loc[i], scale[i]))
Example #13
def plate_cuda_model(subsample_size):
    loc = torch.zeros(20).cuda()
    scale = torch.ones(20).cuda()
    with pyro.plate("data", 20, subsample_size, device=loc.device) as batch:
        pyro.sample("x", dist.Normal(loc[batch], scale[batch]))
Example #14
def iplate_custom_model(subsample):
    result = []
    for i in pyro.plate('plate', 20, subsample=subsample):
        result.append(i)
    return result
Example #15
    def fit(self,
            model_name,
            model_param_names,
            data_input,
            fitter=None,
            init_values=None):
        # verbose is passed through from orbit.models.base_estimator
        verbose = self.verbose
        message = self.message
        learning_rate = self.learning_rate
        learning_rate_total_decay = self.learning_rate_total_decay
        num_sample = self.num_sample
        seed = self.seed
        num_steps = self.num_steps

        pyro.set_rng_seed(seed)
        if fitter is None:
            fitter = get_pyro_model(model_name)  # abstract
        model = fitter(data_input)  # concrete

        # Perform stochastic variational inference using an auto guide.
        pyro.clear_param_store()
        guide = AutoLowRankMultivariateNormal(model)
        optim = ClippedAdam({
            "lr": learning_rate,
            "lrd": learning_rate_total_decay**(1 / num_steps)
        })
        elbo = Trace_ELBO(num_particles=self.num_particles,
                          vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

        for step in range(num_steps):
            loss = svi.step()
            if verbose and step % message == 0:
                scale_rms = guide._loc_scale()[1].detach().pow(
                    2).mean().sqrt().item()
                print("step {: >4d} loss = {:0.5g}, scale = {:0.5g}".format(
                    step, loss, scale_rms))

        # Extract samples.
        vectorize = pyro.plate("samples",
                               num_sample,
                               dim=-1 - model.max_plate_nesting)
        with pyro.poutine.trace() as tr:
            samples = vectorize(guide)()
        with pyro.poutine.replay(trace=tr.trace):
            samples.update(vectorize(model)())

        # Convert from torch.Tensors to numpy.ndarrays.
        extract = {
            name: value.detach().squeeze().numpy()
            for name, value in samples.items()
        }

        # make sure that model param names are a subset of stan extract keys
        invalid_model_param = set(model_param_names) - set(list(
            extract.keys()))
        if invalid_model_param:
            raise EstimatorException(
                "Pyro model definition does not contain required parameters")

        # `stan.optimizing` automatically returns all defined parameters
        # filter out unnecessary keys
        extract = {param: extract[param] for param in model_param_names}

        return extract
Example #16
def model():
    with pyro.plate("plate", 10):
        with poutine.reparam(config={"x": Reparam()}):
            return pyro.sample("x", dist.Stable(1.5, 0))
Example #17
File: SPIRE.py  Project: ianmbus/XID_plus
def spire_model(priors, sub=1):

    if len(priors) != 3:
        raise ValueError
    band_plate = pyro.plate('bands', len(priors), dim=-2)
    src_plate = pyro.plate('nsrc', priors[0].nsrc, dim=-1)
    psw_plate = pyro.plate('psw_pixels',
                           priors[0].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[0].sim.size).astype(int))
    pmw_plate = pyro.plate('pmw_pixels',
                           priors[1].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[1].sim.size).astype(int))
    plw_plate = pyro.plate('plw_pixels',
                           priors[2].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[2].sim.size).astype(int))
    pointing_matrices = [
        torch.sparse.FloatTensor(torch.LongTensor([p.amat_row, p.amat_col]),
                                 torch.Tensor(p.amat_data),
                                 torch.Size([p.snpix, p.nsrc])) for p in priors
    ]

    bkg_prior = torch.tensor([p.bkg[0] for p in priors])
    bkg_prior_sig = torch.tensor([p.bkg[1] for p in priors])
    nsrc = priors[0].nsrc

    f_low_lim = torch.tensor([p.prior_flux_lower for p in priors],
                             dtype=torch.float)
    f_up_lim = torch.tensor([p.prior_flux_upper for p in priors],
                            dtype=torch.float)

    with band_plate as ind_band:
        sigma_conf = pyro.sample(
            'sigma_conf',
            dist.HalfCauchy(torch.tensor([1.0]), torch.tensor([0.5])).expand(
                [1]).to_event(1)).squeeze(-1)
        bkg = pyro.sample('bkg',
                          dist.Normal(-5,
                                      0.5).expand([1]).to_event(1)).squeeze(-1)
        with src_plate as ind_src:
            src_f = pyro.sample('src_f',
                                dist.Uniform(0, 1).expand(
                                    [1]).to_event(1)).squeeze(-1)
    f_vec = (f_up_lim - f_low_lim) * src_f + f_low_lim
    db_hat_psw = torch.sparse.mm(pointing_matrices[0],
                                 f_vec[0, ...].unsqueeze(-1)) + bkg[0]
    db_hat_pmw = torch.sparse.mm(pointing_matrices[1].to_dense(),
                                 f_vec[1, ...].unsqueeze(-1)) + bkg[1]
    db_hat_plw = torch.sparse.mm(pointing_matrices[2].to_dense(),
                                 f_vec[2, ...].unsqueeze(-1)) + bkg[2]
    sigma_tot_psw = torch.sqrt(
        torch.pow(torch.tensor(priors[0].snim), 2) +
        torch.pow(sigma_conf[0], 2))
    sigma_tot_pmw = torch.sqrt(
        torch.pow(torch.tensor(priors[1].snim), 2) +
        torch.pow(sigma_conf[1], 2))
    sigma_tot_plw = torch.sqrt(
        torch.pow(torch.tensor(priors[2].snim), 2) +
        torch.pow(sigma_conf[2], 2))
    with psw_plate as ind_psw:
        psw_map = pyro.sample("obs_psw",
                              dist.Normal(db_hat_psw.squeeze()[ind_psw],
                                          sigma_tot_psw[ind_psw]),
                              obs=torch.tensor(priors[0].sim[ind_psw]))
    with pmw_plate as ind_pmw:
        pmw_map = pyro.sample("obs_pmw",
                              dist.Normal(db_hat_pmw.squeeze()[ind_pmw],
                                          sigma_tot_pmw[ind_pmw]),
                              obs=torch.tensor(priors[1].sim[ind_pmw]))
    with plw_plate as ind_plw:
        plw_map = pyro.sample("obs_plw",
                              dist.Normal(db_hat_plw.squeeze()[ind_plw],
                                          sigma_tot_plw[ind_plw]),
                              obs=torch.tensor(priors[2].sim[ind_plw]))
    return psw_map, pmw_map, plw_map
Example #18
def create_plates():
    return pyro.plate("plate", 10, subsample_size=3)
Example #19
File: neutra.py  Project: nwjnwj/pyro
def main(args):
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[2, 0])
    ax5 = fig.add_subplot(gs[3, 0])
    ax6 = fig.add_subplot(gs[1, 1])
    ax7 = fig.add_subplot(gs[2, 1])
    ax8 = fig.add_subplot(gs[3, 1])
    xlim = tuple(int(x) for x in args.x_lim.strip().split(','))
    ylim = tuple(int(x) for x in args.y_lim.strip().split(','))
    assert len(xlim) == 2
    assert len(ylim) == 2

    # 1. Plot samples drawn from BananaShaped distribution
    x1, x2 = torch.meshgrid([torch.linspace(*xlim, 100), torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    ax1.contourf(x1, x2, p, cmap='OrRd',)
    ax1.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    ax2.contourf(x1, x2, p, cmap='OrRd')
    ax2.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(vanilla HMC)')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax3.contourf(x1, x2, p, cmap='OrRd')
    ax3.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(DiagNormal autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3)

    # 3(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4)
    ax4.set(xlabel='x0', ylabel='x1',
            title='Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax5.contourf(x1, x2, p, cmap='OrRd')
    ax5.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior (transformed) \n(DiagNormal + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5)

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, block_autoregressive))
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax6.contourf(x1, x2, p, cmap='OrRd')
    ax6.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(BNAF autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6)

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7)
    ax7.set(xlabel='x0', ylabel='x1', title='Posterior (warped) samples \n(BNAF + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax8.contourf(x1, x2, p, cmap='OrRd')
    ax8.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior (transformed) \n(BNAF + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
Example #20
def model():
    with pyro.plate_stack("plates", shape):
        with pyro.plate("particles", 200000):
            return pyro.sample("x", dist.Stable(stability, 0, scale, loc))
Example #21
def guide(self, data):
    pyro.module("encoder", self.encoder)
    x_observation = data[0]
    with pyro.plate("data", x_observation.shape[0]):
        z_loc, z_scale = self.encoder.forward(x_observation)
        pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
Example #22
    def sample(self, n_samples=1):
        with pyro.plate('observations', n_samples):
            samples = self.model()

        return (*samples, )
Example #23
def model(data):
    data = torch.reshape(data, [60000, 50, 50])

    pyro.module("decode_l1", decode_l1)
    pyro.module("decode_l2", decode_l2)

    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]

        #================= prior
        state_x = torch.zeros([64, 50, 50])
        state_z_pres = torch.ones([64, 1])
        state_z_where = None

        z_pres = []
        z_where = []

        for t in range(3):
            #==================== prior_step
            # size = [64, 50, 50]
            prev_x = state_x
            # size = [64, 1]
            prev_z_pres = state_z_pres
            # size = None or [64, 3]
            prev_z_where = state_z_where

            # size = [64, 1]
            cur_z_pres =\
                pyro.sample('z_pres_{}'.format(t),
                            Bernoulli(trial_probs[t] * prev_z_pres)
                            .to_event(1))

            sample_mask = cur_z_pres
            # size = [64, 3]
            cur_z_where =\
                pyro.sample('z_where_{}'.format(t),
                            Normal(torch.Tensor.expand(z_where_loc_prior, [64, 3]),
                                   torch.Tensor.expand(z_where_scale_prior, [64, 3]))
                            .mask(sample_mask)
                            .to_event(1))

            # size = [64, 50]
            cur_z_what =\
                pyro.sample('z_what_{}'.format(t),
                            Normal(torch.zeros([64, 50]),
                                   torch.ones([64, 50]))
                            .mask(sample_mask)
                            .to_event(1))

            #===== decode
            # size = [64, 784]
            y_att = torch.sigmoid(
                decode_l2(F.relu(decode_l1(cur_z_what))) - 2.0)
            #===== decode

            #===== window_to_image
            windows = y_att

            #===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            # size = [64, 2, 3]
            out = torch.Tensor.view(out, [64, 2, 3])
            theta = out
            #===== expand_z_where
            # size = [64, 50, 50, 2]
            grid = F.affine_grid(theta, [64, 1, 50, 50])
            # size = [64, 1, 50, 50]
            out = F.grid_sample(torch.Tensor.view(windows, [64, 1, 28, 28]),
                                grid)

            y = torch.Tensor.view(out, [64, 50, 50])
            #===== window_to_image

            # size = [64, 50, 50]
            cur_x = prev_x + (y * torch.Tensor.view(cur_z_pres, [64, 1, 1]))

            state_x = cur_x
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            #==================== prior_step

            z_where.append(state_z_where)
            z_pres.append(state_z_pres)

        # size = [64, 50, 50]
        x = state_x
        #================== prior

        pyro.sample('obs',
                    Normal(torch.Tensor.view(x, [64, 2500]),
                           (0.3 * torch.ones(64, 2500))).to_event(1),
                    obs=torch.Tensor.view(batch, [64, 2500]))
Example #24
    def sample_scm(self, n_samples=1):
        with pyro.plate('observations', n_samples):
            samples = self.scm()

        return (*samples, )
Example #25
def model():
    with pyro.plate_stack("plates", shape[:dim]):
        with pyro.plate("particles", 10000):
            pyro.sample("x",
                        dist.Uniform(0, 1).expand(shape).to_event(-dim))
Example #26
def sample_pgm(num_samples):
    with pyro.plate('observations', num_samples):
        return self.pyro_model.pgm_model()
Example #27
def model(data):
    y_prob = pyro.sample("y_prob", dist.Beta(1., 1.))
    with pyro.plate("data", data.shape[0]):
        y = pyro.sample("y", dist.Bernoulli(y_prob))
        z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
        pyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)
Example #28
def model():
    lambda_latent = pyro.sample("lambda_latent", Gamma(alpha0, beta0))
    with pyro.plate("data", n_data):
        pyro.sample("obs", dist.Poisson(lambda_latent), obs=data)
    return lambda_latent
Example #29
def model():
    with pyro.plate("data", len(data), subsample_size) as ind:
        x = data[ind]
        z = pyro.sample("z", Normal(0, 1).expand_by(x.shape))
        pyro.sample("x", Normal(z, 1), obs=x)
Example #30
def neals_funnel(dim=10):
    y = pyro.sample("y", dist.Normal(0, 3))
    with pyro.plate("D", dim):
        return pyro.sample("x", dist.Normal(0, torch.exp(y / 2)))