Example #1
def model(data, mask):
    with numpyro.plate('N', N):
        x = numpyro.sample('x', dist.Normal(0, 1))
        with handlers.mask(mask=mask):
            numpyro.sample('y', dist.Delta(x, log_density=1.))
            with handlers.scale(scale=2):
                numpyro.sample('obs', dist.Normal(x, 1), obs=data)
Example #2
def model(data, mask):
    with numpyro.plate("N", N):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with handlers.mask(mask=mask):
            numpyro.sample("y", dist.Delta(x, log_density=1.0))
            with handlers.scale(scale=2):
                numpyro.sample("obs", dist.Normal(x, 1), obs=data)
Example #3
def scaled_loss(rng_key, classic_params, stein_params):
    params = {**classic_params, **stein_params}
    loss_val = self.loss.loss(
        rng_key, params,
        handlers.scale(self.model, self.loss_temperature), self.guide,
        *args, **kwargs, **self.static_kwargs)
    return -loss_val
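The snippet above (from NumPyro's Stein VI code) tempers the model by wrapping it whole in handlers.scale, multiplying every site's log probability by loss_temperature. A minimal sketch of that effect, assuming a toy one-site model:

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.util import log_density

def toy_model():
    numpyro.sample("x", dist.Normal(0, 1))

base, _ = log_density(toy_model, (), {}, {"x": 0.5})
tempered, _ = log_density(handlers.scale(toy_model, scale=3.0), (), {}, {"x": 0.5})
# wrapping the whole model multiplies its joint log density by the factor
assert abs(tempered - 3.0 * base) < 1e-5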
Example #4
def minibatch(batch_or_batchsize, num_obs_total=None):
    """Returns a context within which all samples are treated as being a
    minibatch of a larger data set.

    In essence, this marks the (log)likelihood of the sampled examples to be
    scaled to the total loss value over the whole data set.

    :param batch_or_batchsize: An integer indicating the batch size or an array
        indicating the shape of the batch where the length of the first axis
        is interpreted as batch size.
    :param num_obs_total: The total number of examples/observations in the
        full data set. Optional, defaults to the given batch size.
    """
    if is_int_scalar(batch_or_batchsize):
        if not jnp.isscalar(batch_or_batchsize):
            raise TypeError(
                "if a scalar is given for batch_or_batchsize, it "
                "can't be traced through jit. consider using static_argnums "
                "for the jit invocation.")
        batch_size = batch_or_batchsize
    elif is_array(batch_or_batchsize):
        batch_size = example_count(batch_or_batchsize)
    else:
        raise TypeError("batch_or_batchsize must be an array or an integer")
    if num_obs_total is None:
        num_obs_total = batch_size
    return scale(scale=num_obs_total / batch_size)
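A hedged usage sketch for minibatch (the model below is illustrative, not taken from d3p's documentation): inside the returned context, each record's log likelihood is scaled by num_obs_total / batch_size, so the batch term estimates the full-data likelihood while the prior stays untouched.

import numpyro
import numpyro.distributions as dist

def model(batch_x, num_obs_total=None):
    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    # only the likelihood below is scaled; the prior on mu is not
    with minibatch(batch_x, num_obs_total=num_obs_total):
        numpyro.sample('x', dist.Normal(mu, 1.), obs=batch_x)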
Example #5
def test_scale(use_context_manager):
    def model(data):
        x = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    model = model if use_context_manager else handlers.scale(model, 10.)
    data = random.normal(random.PRNGKey(0), (3, ))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(model, (data, ), {}, {'x': x})[0]
    log_prob1 = dist.Normal(0, 1).log_prob(x)
    log_prob2 = dist.Normal(x, 1).log_prob(data).sum()
    expected = (log_prob1 + 10 * log_prob2 if use_context_manager
                else 10 * (log_prob1 + log_prob2))
    assert_allclose(log_joint, expected)
Example #6
def model(data):
    x = numpyro.sample('x', dist.Normal(0, 1))
    with optional(use_context_manager, handlers.scale(scale=10)):
        numpyro.sample('obs', dist.Normal(x, 1), obs=data)
Example #7
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide,
            subsample))(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide,
            subsample))(params)
        # tree_multimap was removed from recent JAX releases; tree_map
        # handles multiple trees with the same semantics
        grads = tree_map(lambda g1, g2: g1 + g2, grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(lambda x: svi.loss.loss(
            svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide,
            subsample))(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(lambda x: svi.loss.loss(
        svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None))(
            params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        'loc': scale * jnp.array([0.5, -2.0]),
        'scale': scale * jnp.array([2.0])
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(actual_grads[name],
                        expected_grads[name],
                        rtol=precision,
                        atol=precision)
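The normalizer above relies on numpyro.plate's automatic subsampling scale: with subsample_size set, each record's log likelihood is multiplied by size / subsample_size. A standalone sketch of that mechanism (toy data, not part of the test):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.util import log_density

data = jnp.array([-0.5, 2.0])

def model(ind):
    # substitute pins the plate's subsample indices, as in the test above
    with handlers.substitute(data={"data": ind}):
        with numpyro.plate("data", len(data), subsample_size=1) as i:
            numpyro.sample("x", dist.Normal(0, 1), obs=data[i])

# each single-record log density is scaled by 2/1, so the two
# minibatch log densities sum to twice the full-data value
ld0, _ = log_density(model, (jnp.array([0]),), {}, {})
ld1, _ = log_density(model, (jnp.array([1]),), {}, {})
full = dist.Normal(0, 1).log_prob(data).sum()
assert abs((ld0 + ld1) / 2 - full) < 1e-5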
Example #8
def model(data):
    x = sample('x', dist.Normal(0, 1))
    with scale(scale=10.):
        sample('obs', dist.Normal(x, 1), obs=data)
Example #9
File: vae.py Project: byzhang/d3p
def main(args):
    # loading data
    (train_init, train_fetch_plain), num_samples = load_dataset(
        MNIST,
        batch_size=args.batch_size,
        split='train',
        batchifier=subsample_batchify_data)
    (test_init,
     test_fetch_plain), _ = load_dataset(MNIST,
                                         batch_size=args.batch_size,
                                         split='test',
                                         batchifier=split_batchify_data)

    def binarize_fetch(fetch_fn):
        @jit
        def fetch_binarized(batch_nr, batchifier_state, binarize_rng):
            batch = fetch_fn(batch_nr, batchifier_state)
            return binarize(binarize_rng, batch[0]), batch[1]

        return fetch_binarized

    train_fetch = binarize_fetch(train_fetch_plain)
    test_fetch = binarize_fetch(test_fetch_plain)

    # setting up optimizer
    optimizer = optimizers.Adam(args.learning_rate)

    # the minibatch environment in our model scales individual
    # records' contributions to the loss up by num_samples.
    # This can cause numerical instabilities so we scale down
    # the loss by 1/num_samples here.
    sc_model = scale(model, scale=1 / num_samples)
    sc_guide = scale(guide, scale=1 / num_samples)

    if args.no_dp:
        svi = SVI(sc_model,
                  sc_guide,
                  optimizer,
                  ELBO(),
                  num_obs_total=num_samples,
                  z_dim=args.z_dim,
                  hidden_dim=args.hidden_dim)
    else:
        q = args.batch_size / num_samples
        target_eps = args.epsilon
        dp_scale, act_eps, _ = approximate_sigma(target_eps=target_eps,
                                                 delta=1 / num_samples,
                                                 q=q,
                                                 num_iter=int(1 / q) *
                                                 args.num_epochs,
                                                 force_smaller=True)
        print(
            f"using noise scale {dp_scale} for epsilon of {act_eps} (targeted: {target_eps})"
        )

        svi = DPSVI(sc_model,
                    sc_guide,
                    optimizer,
                    ELBO(),
                    dp_scale=dp_scale,
                    clipping_threshold=10.,
                    num_obs_total=num_samples,
                    z_dim=args.z_dim,
                    hidden_dim=args.hidden_dim)

    # preparing random number generators and initializing svi
    rng = PRNGKey(0)
    rng, binarize_rng, svi_init_rng, batchifier_rng = random.split(rng, 4)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    sample_batch = train_fetch(0, batchifier_state, binarize_rng)[0]
    svi_state = svi.init(svi_init_rng, sample_batch)

    # functions for training tasks
    @jit
    def epoch_train(svi_state, batchifier_state, num_batches, rng):
        """Trains one epoch

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: overall training loss over the epoch
        """
        def body_fn(i, val):
            svi_state, loss = val
            binarize_rng = random.fold_in(rng, i)
            batch = train_fetch(i, batchifier_state, binarize_rng)[0]
            svi_state, batch_loss = svi.update(svi_state, batch)
            loss += batch_loss / num_batches
            return svi_state, loss

        svi_state, loss = lax.fori_loop(0, num_batches, body_fn,
                                        (svi_state, 0.))
        return svi_state, loss

    @jit
    def eval_test(svi_state, batchifier_state, num_batches, rng):
        """Evaluates current model state on test data.

        :param svi_state: current state of the optimizer
        :param rng: rng key

        :return: loss over the test split
        """
        def body_fn(i, loss_sum):
            binarize_rng = random.fold_in(rng, i)
            batch = test_fetch(i, batchifier_state, binarize_rng)[0]
            batch_loss = svi.evaluate(svi_state, batch)
            loss_sum += batch_loss / num_batches
            return loss_sum

        return lax.fori_loop(0, num_batches, body_fn, 0.)

    def reconstruct_img(epoch, num_epochs, batchifier_state, svi_state, rng):
        """Reconstructs an image for the given epoch

        Obtains a sample from the testing data set and passes it through the
        VAE. Stores the result as image file 'epoch_{epoch}_recons.png' and
        the original input as 'epoch_{epoch}_original.png' in folder '.results'.

        :param epoch: Number of the current epoch
        :param num_epochs: Number of total epochs
        :param opt_state: Current state of the optimizer
        :param rng: rng key
        """
        assert num_epochs > 0
        img = test_fetch_plain(0, batchifier_state)[0][0]
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_original.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img,
                   cmap='gray')
        rng, rng_binarize = random.split(rng, 2)
        test_sample = binarize(rng_binarize, img)
        test_sample = jnp.reshape(test_sample, (1, *jnp.shape(test_sample)))
        params = svi.get_params(svi_state)

        samples = sample_multi_posterior_predictive(
            rng, 10, model,
            (1, args.z_dim, args.hidden_dim, np.prod(test_sample.shape[1:])),
            guide, (test_sample, args.z_dim, args.hidden_dim), params)

        img_loc = samples['obs'][0].reshape([28, 28])
        avg_img_loc = jnp.mean(samples['obs'], axis=0).reshape([28, 28])
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_single.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   img_loc,
                   cmap='gray')
        plt.imsave(os.path.join(
            RESULTS_DIR, "epoch_{:0{}d}_recons_avg.png".format(
                epoch, (int(jnp.log10(num_epochs)) + 1))),
                   avg_img_loc,
                   cmap='gray')

    # main training loop
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng, train_rng = random.split(rng, 3)
        num_train_batches, train_batchifier_state = train_init(
            rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state,
                                            num_train_batches, train_rng)

        rng, test_fetch_rng, test_rng, recons_rng = random.split(rng, 4)
        num_test_batches, test_batchifier_state = test_init(
            rng_key=test_fetch_rng)
        test_loss = eval_test(svi_state, test_batchifier_state,
                              num_test_batches, test_rng)

        reconstruct_img(i, args.num_epochs, test_batchifier_state, svi_state,
                        recons_rng)
        print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
            i, test_loss, train_loss,
            time.time() - t_start))
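Note how the scale(model, 1 / num_samples) wrapper interacts with the minibatch scaling from Example #4. A back-of-the-envelope check (the numbers are hypothetical):

# minibatch(...) scales each record's log likelihood by num_obs_total / batch_size;
# scale(model, 1 / num_samples) then divides the whole joint by num_samples.
# The likelihood term thus becomes an average over the batch, keeping the loss
# magnitude stable regardless of the data set size.
num_samples, batch_size = 60000, 128  # hypothetical MNIST-like values
net_likelihood_factor = (1 / num_samples) * (num_samples / batch_size)
assert abs(net_likelihood_factor - 1 / batch_size) < 1e-12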