Example #1
def test_reshape_shape_forward(filters):

    n_dims = (1, 28, 28, 1)
    new_shape = _get_new_shapes(28, 28, 1, filters)
    params_rng, data_rng = jax.random.split(KEY, 2)

    x = jax.random.uniform(data_rng, shape=n_dims)

    # create the layer initializer
    init_func = Squeeze(filter_shape=filters, collapse=None, return_outputs=True)

    # initialize the layer; return_outputs=True also returns the transformed inputs
    z_, params, forward_f, inverse_f = init_func(rng=params_rng, shape=n_dims, inputs=x)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_tree_all_close(z, z_)
    chex.assert_equal_shape([z, log_abs_det, z_])
    chex.assert_rank(z, 4)
    chex.assert_equal(z.shape[1:], new_shape)

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x_approx, x])
    chex.assert_tree_all_close(x_approx, x)
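
For reference, the shape assertion above encodes the standard squeeze bookkeeping: spatial resolution is traded for channels. A minimal sketch of a _get_new_shapes helper consistent with it (a hypothetical reconstruction; the test module's real helper is not shown):

def _get_new_shapes(height, width, channels, filters):
    # Assumption: a squeeze with an (fh, fw) filter maps
    # (H, W, C) -> (H // fh, W // fw, C * fh * fw), preserving the element count.
    fh, fw = filters
    return (height // fh, width // fw, channels * fh * fw)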
Example #2
def test_numpyro_marginal_ll_tfp_priors_type(n_samples, n_features, n_latents, dtype):

    # create sample data
    ds = _gen_training_data(n_samples, n_features, n_latents)

    # cast the dataset to the target dtype
    ds = jax.tree_util.tree_map(lambda x: x.astype(dtype), ds)

    # initialize parameters
    params, posterior = _get_conjugate_posterior_params()

    # convert to numpyro-style params
    numpyro_params = numpyro_dict_params(params)

    # convert to priors
    numpyro_params = add_priors(numpyro_params, tfd.LogNormal(0.0, 10.0))

    # initialize numpyro-style GP model
    npy_model = numpyro_marginal_ll(posterior, numpyro_params)

    # do one forward pass inside a seeded handler context
    with numpyro.handlers.seed(rng_seed=KEY):
        pred = npy_model(ds)

        chex.assert_equal(pred.dtype, ds.y.dtype)
Example #3
  def test_simple(self):
    """Check if the leaf key is `a`."""
    mask = create_mask(lambda path, _: path[-1] == 'a')
    data = {'a': 4, 'b': {'a': 5, 'c': 1}, 'c': {'a': {'b': 1}}}

    truth = {'a': True, 'b': {'a': True, 'c': False}, 'c': {'a': {'b': False}}}

    chex.assert_equal(mask(data), truth)
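
For context, a minimal sketch of the create_mask factory this test assumes (a hypothetical reconstruction, not the project's actual code): it replaces every leaf of a nested dict with the result of a predicate applied to that leaf's key path.

def create_mask(predicate):
    def mask(tree):
        def walk(subtree, path):
            if isinstance(subtree, dict):
                return {k: walk(v, path + (k,)) for k, v in subtree.items()}
            # leaf: the predicate receives the full key path and the value
            return predicate(path, subtree)
        return walk(tree, ())
    return mask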
Example #4
    def test_with_different_event_ndims(self):
        dx_bij = Lambda(forward=lambda x: x.reshape(x.shape[:-1] + (2, 3)),
                        inverse=lambda y: y.reshape(y.shape[:-2] + (6, )),
                        forward_log_det_jacobian=lambda _: 0,
                        inverse_log_det_jacobian=lambda _: 0,
                        is_constant_jacobian=True,
                        event_ndims_in=1,
                        event_ndims_out=2)
        tfp_bij = tfp_compatible_bijector(dx_bij)

        with self.subTest('forward_event_ndims'):
            assert tfp_bij.forward_event_ndims(1) == 2
            assert tfp_bij.forward_event_ndims(2) == 3

        with self.subTest('inverse_event_ndims'):
            assert tfp_bij.inverse_event_ndims(2) == 1
            assert tfp_bij.inverse_event_ndims(3) == 2

        with self.subTest('forward_event_ndims with incorrect input'):
            with self.assertRaises(ValueError):
                tfp_bij.forward_event_ndims(0)

        with self.subTest('inverse_event_ndims with incorrect input'):
            with self.assertRaises(ValueError):
                tfp_bij.inverse_event_ndims(0)

            with self.assertRaises(ValueError):
                tfp_bij.inverse_event_ndims(1)

        with self.subTest('forward_event_shape'):
            y_shape = tfp_bij.forward_event_shape((6, ))
            y_shape_tensor = tfp_bij.forward_event_shape_tensor((6, ))
            self.assertEqual(y_shape, (2, 3))
            np.testing.assert_array_equal(y_shape_tensor, jnp.array((2, 3)))

        with self.subTest('inverse_event_shape'):
            x_shape = tfp_bij.inverse_event_shape((2, 3))
            x_shape_tensor = tfp_bij.inverse_event_shape_tensor((2, 3))
            self.assertEqual(x_shape, (6, ))
            np.testing.assert_array_equal(x_shape_tensor, jnp.array((6, )))

        with self.subTest('TransformedDistribution with correct event_ndims'):
            base = tfd.MultivariateNormalDiag(np.zeros(6), np.ones(6))
            dist = tfd.TransformedDistribution(base, tfp_bij)
            chex.assert_equal(dist.event_shape, (2, 3))

            sample = dist.sample(seed=jax.random.PRNGKey(0))
            chex.assert_shape(sample, (2, 3))

            log_prob = dist.log_prob(sample)
            chex.assert_shape(log_prob, ())

        with self.subTest(
                'TransformedDistribution with incorrect event_ndims'):
            base = tfd.Normal(np.zeros(6), np.ones(6))
            dist = tfd.TransformedDistribution(base, tfp_bij)
            with self.assertRaises(ValueError):
                _ = dist.event_shape
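
The event_ndims arithmetic exercised above follows directly from the bijector's declared ranks. As a sketch (assuming the distrax convention that a bijector consumes event_ndims_in dimensions and produces event_ndims_out):

def forward_event_ndims(event_ndims, event_ndims_in=1, event_ndims_out=2):
    # Assumed bookkeeping: the caller-supplied event rank is shifted by the
    # difference between output and input event ranks; ranks below the
    # bijector's own minimum are rejected, matching the ValueError cases.
    if event_ndims < event_ndims_in:
        raise ValueError('event_ndims must be >= event_ndims_in')
    return event_ndims - event_ndims_in + event_ndims_out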
Example #5
def test_numpyro_add_constraints_str(variable, constraint):

    gpjax_params = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add constraint
    new_numpyro_params = add_constraints(numpyro_params, variable, constraint)

    # check the constraint was added to the new dictionary
    chex.assert_equal(new_numpyro_params[variable]["constraint"], constraint)

    # check we didn't modify original dictionary
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
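
A sketch of an add_constraints helper consistent with both call patterns in these tests (an assumption, not the library's actual implementation): the two-argument form applies one constraint to every variable, the three-argument form to a single named variable, and the input dictionary is never mutated.

def add_constraints(params, variable_or_constraint, constraint=None):
    # Copy each entry so the caller's dictionary stays unmodified
    # (the tests assert this).
    new_params = {name: dict(entry) for name, entry in params.items()}
    if constraint is None:
        # add_constraints(params, constraint): constrain every variable
        for entry in new_params.values():
            entry['constraint'] = variable_or_constraint
    else:
        # add_constraints(params, variable, constraint): constrain one variable
        new_params[variable_or_constraint]['constraint'] = constraint
    return new_params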
Example #6
def test_numpyro_dict_params_defaults_float():

    demo_params = {
        "lengthscale": 1.0,
        "variance": 1.0,
        "obs_noise": 1.0,
    }

    numpyro_params = numpyro_dict_params(demo_params)

    assert set(numpyro_params) == set(demo_params.keys())
    for ikey, iparam in demo_params.items():
        # check keys exist for param
        assert set(numpyro_params[ikey].keys()) == set(
            ("init_value", "constraint", "param_type"))
        # check init_value matches the value passed in
        chex.assert_equal(numpyro_params[ikey]["init_value"], iparam)
        # check default constraint is positive
        chex.assert_equal(numpyro_params[ikey]["constraint"],
                          constraints.positive)
        # check the default param type is "param"
        chex.assert_equal(numpyro_params[ikey]["param_type"], "param")

    # check we didn't modify original dictionary
    chex.assert_equal(
        demo_params,
        {
            "lengthscale": 1.0,
            "variance": 1.0,
            "obs_noise": 1.0,
        },
    )
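
Taken together with the prior-handling tests below, the defaults asserted here suggest the following shape for numpyro_dict_params (a minimal sketch under assumptions, not the actual implementation):

from numpyro.distributions import constraints

def numpyro_dict_params(params):
    converted = {}
    for name, value in params.items():
        if isinstance(value, float):
            # assumed default: floats become positive-constrained trainable params
            converted[name] = {
                'init_value': value,
                'constraint': constraints.positive,
                'param_type': 'param',
            }
        else:
            # assumed default: distribution-valued entries become priors
            converted[name] = {'prior': value, 'param_type': 'prior'}
    return converted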
Example #7
def test_numpyro_add_constraints_all(constraint):

    gpjax_params = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add constraint
    new_numpyro_params = add_constraints(numpyro_params, constraint)
    for iparams in new_numpyro_params.values():

        # check the constraint was applied to this variable
        chex.assert_equal(iparams["constraint"], constraint)

    # check we didn't modify original dictionary
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
Example #8
  def test_batched_bijector_shapes(self, batch_shape, sample_shape):
    base = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
    bijector = block.Block(tfb.Scale(jnp.ones(batch_shape + (3,))), 1)
    dist = transformed.Transformed(base, bijector)

    with self.subTest('batch_shape'):
      chex.assert_equal(dist.batch_shape, batch_shape)

    with self.subTest('sample.shape'):
      sample = dist.sample(seed=self.seed, sample_shape=sample_shape)
      chex.assert_equal(sample.shape, sample_shape + batch_shape + (3,))

    with self.subTest('sample_and_log_prob sample.shape'):
      sample, log_prob = dist.sample_and_log_prob(
          seed=self.seed, sample_shape=sample_shape)
      chex.assert_equal(sample.shape, sample_shape + batch_shape + (3,))

    with self.subTest('sample_and_log_prob log_prob.shape'):
      sample, log_prob = dist.sample_and_log_prob(
          seed=self.seed, sample_shape=sample_shape)
      chex.assert_equal(log_prob.shape, sample_shape + batch_shape)

    with self.subTest('sample_and_log_prob log_prob value'):
      sample, log_prob = dist.sample_and_log_prob(
          seed=self.seed, sample_shape=sample_shape)
      np.testing.assert_allclose(log_prob, dist.log_prob(sample))
Example #9
def test_numpyro_dict_priors_defaults_tfp():

    demo_priors = {
        "lengthscale": tfd.LogNormal(loc=0.0, scale=1.0),
        "variance": tfd.LogNormal(loc=0.0, scale=1.0),
        "obs_noise": tfd.LogNormal(loc=0.0, scale=1.0),
    }

    numpyro_params = numpyro_dict_params(demo_priors)

    assert set(numpyro_params) == set(demo_priors.keys())
    for ikey, iparam in demo_priors.items():
        # check keys exist for param
        assert set(numpyro_params[ikey].keys()) == set(("prior", "param_type"))
        # check the stored prior is the distribution passed in
        chex.assert_equal(numpyro_params[ikey]["prior"], iparam)
Example #10
  def test_simple(self):
    """Check that the correct tags are removed."""
    mask = create_weight_decay_mask()
    data = {
        'bias': {
            'BatchNorm_0': 4,
            'bias': 5,
            'a': 0
        },
        'BatchNorm_0': {
            'b': 4
        },
        'a': {
            'b': {
                'BatchNorm_0': 0,
                'bias': 0
            },
            'c': 0
        }
    }
    truth = {
        'bias': {
            'BatchNorm_0': False,
            'bias': False,
            'a': False
        },
        'BatchNorm_0': {
            'b': False
        },
        'a': {
            'b': {
                'BatchNorm_0': True,
                'bias': False
            },
            'c': True
        }
    }

    chex.assert_equal(mask(data), truth)
Example #11
 def test_adam(self):
   init_fn, update_fn = optimizers.get_optimizer(
       ConfigDict({
           'optimizer': 'adam',
           'l2_decay_factor': None,
           'batch_size': 50,
           'total_accumulated_batch_size': 100,  # Use gradient accumulation.
           'opt_hparams': {
               'beta1': 0.9,
               'beta2': 0.999,
               'epsilon': 1e-7,
               'weight_decay': 0.0,
           }
       }))
   del update_fn
   optimizer_state = init_fn({'foo': jnp.ones(10)})
   # Test that we can extract 'count'.
   chex.assert_type(extract_field(optimizer_state, 'count'), int)
   # Test that we can extract 'nu'.
   chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
   # Test that we can extract 'mu'.
   chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
   # Test that attempting to extract a nonexistent field "abc" returns None.
   chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
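
extract_field is only exercised through these assertions; one sketch consistent with them (an assumption about nested optax NamedTuple states, not the actual helper):

def extract_field(state, field_name):
    # Depth-first search through (possibly nested) optax state tuples for the
    # first NamedTuple field with the given name; None if it is absent.
    if hasattr(state, '_fields') and field_name in state._fields:
        return getattr(state, field_name)
    if isinstance(state, (tuple, list)):
        for sub_state in state:
            found = extract_field(sub_state, field_name)
            if found is not None:
                return found
    return None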
Example #12
def test_numpyro_add_priors_all(prior):

    gpjax_params = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add a prior to every parameter
    new_numpyro_params = add_priors(numpyro_params, prior)
    for iparams in new_numpyro_params.values():

        # check the prior was applied to this variable
        chex.assert_equal(iparams["param_type"], "prior")
        chex.assert_equal(iparams["prior"], prior)

    # check we didn't modify original dictionary
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
Example #13
def test_numpyro_add_priors_dict(variable, prior):

    gpjax_params = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # create new dictionary
    new_param_dict = {str(variable): prior}

    # add the prior for the given variable
    new_numpyro_params = add_priors(numpyro_params, new_param_dict)

    # check the prior was set in the new dictionary
    chex.assert_equal(new_numpyro_params[variable]["param_type"], "prior")
    chex.assert_equal(new_numpyro_params[variable]["prior"], prior)

    # check we didn't modify original dictionary
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
Example #14
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""
        b, c = x.shape[0], x.shape[3]
        k = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        sigma *= self.state("sigma_multiplier",
                            shape=(),
                            initializer=nn.initializers.ones).value

        stats = {"x": x, "sigma": sigma}

        feature_extractor = models.ResNet50.shared(train=train,
                                                   name="ResNet_0")

        rpn_feature = feature_extractor(x)
        rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                            communication=Communication(
                                                config.communication),
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores is a list of score images. We keep track of the structure
        # because it is used in the aggregation step later on.
        rpn_scores_shapes = [s.shape for s in rpn_scores]
        rpn_scores_flat = jnp.concatenate(
            [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
        top_k_indicators = sample_patches.select_patches_perturbed_topk(
            rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
        top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
        offset = 0
        weights = []
        for sh in rpn_scores_shapes:
            cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
            cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
            weights.append(cur)
            offset += sh[1] * sh[2]
        chex.assert_equal(offset, top_k_indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs, (b * k, 224, 224, c))
        stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

        part_features = feature_extractor(part_imgs)
        part_features = jnp.mean(part_features,
                                 axis=[1, 2])  # GAP the spatial dims

        part_features = nn.dropout(  # features from parts
            jnp.reshape(part_features, [b * k, 2048]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())
        features = nn.dropout(  # features from whole image
            jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())

        # Mean pool all part features, add it to features and predict logits.
        concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                              axis=1) + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [b, k, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }
        # add the RPN score entropy to the stats for entropy regularization.
        stats["rpn_scores_entropy"] = jax.scipy.special.entr(
            jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
        return all_logits, stats
Example #15
 def inner(x, *args, **kwargs):
     out = f(x, *args, **kwargs)
     chex.assert_equal(x.dtype, out.dtype)
     return out
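
This fragment is the inner closure of a dtype-preserving wrapper. A self-contained sketch of the enclosing decorator (the outer name preserves_dtype is assumed):

import functools

import chex

def preserves_dtype(f):
    # Wrap f and assert on every call that the output keeps the input's dtype.
    @functools.wraps(f)
    def inner(x, *args, **kwargs):
        out = f(x, *args, **kwargs)
        chex.assert_equal(x.dtype, out.dtype)
        return out
    return inner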
Example #16
 def test_batch_shape(self):
     chex.assert_equal(self.wrapped_dist.batch_shape,
                       self.base_dist.batch_shape)
Example #17
 def test_event_shape(self):
     chex.assert_equal(self.wrapped_dist.event_shape,
                       self.base_dist.event_shape)