def test_reshape_shape_forward(filters):
    """Squeeze layer: check forward output values/shape and the inverse round trip.

    The layer is initialized with ``return_outputs=True``, so initialization
    also yields the transformed inputs; those must agree with an explicit
    forward call. Applying the inverse to the forward output must recover the
    original inputs.
    """
    height, width, channels = 28, 28, 1
    input_shape = (1, height, width, channels)
    expected_shape = _get_new_shapes(height, width, channels, filters)

    layer_rng, sample_rng = jax.random.split(KEY, num=2)
    inputs = jax.random.uniform(sample_rng, shape=input_shape)

    # Build the layer constructor, then initialize it on the sample inputs.
    layer_init = Squeeze(filter_shape=filters, collapse=None, return_outputs=True)
    init_outputs, params, forward_fn, inverse_fn = layer_init(
        rng=layer_rng, shape=input_shape, inputs=inputs
    )

    # Forward pass: must match the outputs produced during initialization.
    outputs, forward_log_det = forward_fn(params, inputs)
    chex.assert_tree_all_close(outputs, init_outputs)
    chex.assert_equal_shape([outputs, forward_log_det, init_outputs])
    chex.assert_rank(outputs, 4)
    chex.assert_equal(outputs.shape[1:], expected_shape)

    # Inverse pass: must reconstruct the inputs (shape and values).
    reconstructed, _ = inverse_fn(params, outputs)
    chex.assert_equal_shape([reconstructed, inputs])
    chex.assert_tree_all_close(reconstructed, inputs)
def test_numpyro_marginal_ll_tfp_priors_type(n_samples, n_features, n_latents, dtype):
    """The numpyro marginal log-likelihood with TFP priors keeps the data dtype."""
    training_data = _gen_training_data(n_samples, n_features, n_latents)
    # Cast every leaf of the dataset to the requested dtype.
    training_data = jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype), training_data
    )

    gp_params, posterior = _get_conjugate_posterior_params()
    # Convert to numpyro-style entries, then give each one the same LogNormal prior.
    prior_params = add_priors(
        numpyro_dict_params(gp_params), tfd.LogNormal(0.0, 10.0)
    )

    model = numpyro_marginal_ll(posterior, prior_params)
    # A single forward pass needs a seeded numpyro context.
    with numpyro.handlers.seed(rng_seed=KEY):
        log_marginal = model(training_data)

    chex.assert_equal(log_marginal.dtype, training_data.y.dtype)
def test_simple(self):
    """A leaf is selected exactly when the last key on its path is `a`."""

    def last_key_is_a(path, _):
        return path[-1] == 'a'

    mask = create_mask(last_key_is_a)
    tree = {'a': 4, 'b': {'a': 5, 'c': 1}, 'c': {'a': {'b': 1}}}
    expected = {
        'a': True,
        'b': {'a': True, 'c': False},
        'c': {'a': {'b': False}},
    }
    chex.assert_equal(mask(tree), expected)
def test_with_different_event_ndims(self):
    """TFP wrapper around a Lambda bijector mapping event rank 1 to rank 2.

    The forward map reshapes a trailing axis of size 6 into (2, 3), so event
    ndims shift by one in each direction and the event shapes change
    accordingly.
    """

    def flat_to_matrix(x):
        return x.reshape(x.shape[:-1] + (2, 3))

    def matrix_to_flat(y):
        return y.reshape(y.shape[:-2] + (6,))

    def zero_log_det(_):
        return 0

    dx_bij = Lambda(
        forward=flat_to_matrix,
        inverse=matrix_to_flat,
        forward_log_det_jacobian=zero_log_det,
        inverse_log_det_jacobian=zero_log_det,
        is_constant_jacobian=True,
        event_ndims_in=1,
        event_ndims_out=2,
    )
    tfp_bij = tfp_compatible_bijector(dx_bij)

    with self.subTest('forward_event_ndims'):
        for in_ndims, out_ndims in ((1, 2), (2, 3)):
            assert tfp_bij.forward_event_ndims(in_ndims) == out_ndims

    with self.subTest('inverse_event_ndims'):
        for out_ndims, in_ndims in ((2, 1), (3, 2)):
            assert tfp_bij.inverse_event_ndims(out_ndims) == in_ndims

    with self.subTest('forward_event_ndims with incorrect input'):
        with self.assertRaises(ValueError):
            tfp_bij.forward_event_ndims(0)

    with self.subTest('inverse_event_ndims with incorrect input'):
        # Both values are below the bijector's minimum output rank of 2.
        for too_small in (0, 1):
            with self.assertRaises(ValueError):
                tfp_bij.inverse_event_ndims(too_small)

    with self.subTest('forward_event_shape'):
        out_shape = tfp_bij.forward_event_shape((6,))
        out_shape_tensor = tfp_bij.forward_event_shape_tensor((6,))
        self.assertEqual(out_shape, (2, 3))
        np.testing.assert_array_equal(out_shape_tensor, jnp.array((2, 3)))

    with self.subTest('inverse_event_shape'):
        in_shape = tfp_bij.inverse_event_shape((2, 3))
        in_shape_tensor = tfp_bij.inverse_event_shape_tensor((2, 3))
        self.assertEqual(in_shape, (6,))
        np.testing.assert_array_equal(in_shape_tensor, jnp.array((6,)))

    with self.subTest('TransformedDistribution with correct event_ndims'):
        mvn = tfd.MultivariateNormalDiag(np.zeros(6), np.ones(6))
        dist = tfd.TransformedDistribution(mvn, tfp_bij)
        chex.assert_equal(dist.event_shape, (2, 3))
        draw = dist.sample(seed=jax.random.PRNGKey(0))
        chex.assert_shape(draw, (2, 3))
        chex.assert_shape(dist.log_prob(draw), ())

    with self.subTest(
            'TransformedDistribution with incorrect event_ndims'):
        # A scalar-event Normal cannot feed a bijector expecting rank-1 events.
        scalar_normal = tfd.Normal(np.zeros(6), np.ones(6))
        dist = tfd.TransformedDistribution(scalar_normal, tfp_bij)
        with self.assertRaises(ValueError):
            _ = dist.event_shape
def test_numpyro_add_constraints_str(variable, constraint):
    """Adding a constraint to one named variable stores it on that entry.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)`` pair,
    so only the parameter dictionary is converted to numpyro form and compared.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # Add the constraint to the single requested variable.
    new_numpyro_params = add_constraints(numpyro_params, variable, constraint)

    # The constraint must be recorded on that variable's entry.
    chex.assert_equal(new_numpyro_params[variable]["constraint"], constraint)

    # The original GPJax parameter dictionary must be left untouched.
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_dict_params_defaults_float():
    """Plain float params become positive-constrained numpyro 'param' entries.

    Each entry must carry exactly ``init_value``, ``constraint`` and
    ``param_type``; the input dictionary must not be mutated.
    """

    def make_params():
        return {"lengthscale": 1.0, "variance": 1.0, "obs_noise": 1.0}

    expected_fields = {"init_value", "constraint", "param_type"}
    demo_params = make_params()
    numpyro_params = numpyro_dict_params(demo_params)

    assert set(numpyro_params) == set(demo_params)

    for name, value in demo_params.items():
        entry = numpyro_params[name]
        assert set(entry.keys()) == expected_fields
        # The starting value is copied through unchanged.
        chex.assert_equal(entry["init_value"], value)
        # Floats default to a positivity constraint.
        chex.assert_equal(entry["constraint"], constraints.positive)
        chex.assert_equal(entry["param_type"], "param")

    # The caller's dictionary is still in its original state.
    chex.assert_equal(demo_params, make_params())
def test_numpyro_add_constraints_all(constraint):
    """Passing a bare constraint should apply it to every numpyro parameter.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)`` pair,
    so only the parameter dictionary is converted to numpyro form and compared.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # No variable name is given, so the constraint is expected on every entry.
    new_numpyro_params = add_constraints(numpyro_params, constraint)

    for iparams in new_numpyro_params.values():
        chex.assert_equal(iparams["constraint"], constraint)

    # The original GPJax parameter dictionary must be left untouched.
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_batched_bijector_shapes(self, batch_shape, sample_shape):
    """Shapes and log-probs of a Transformed dist whose bijector is batched.

    The base distribution is unbatched with a 3-dimensional event. The Block
    bijector's scale carries ``batch_shape``, which must propagate to the
    batch shape of samples and log-probabilities.
    """
    base = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
    bijector = block.Block(tfb.Scale(jnp.ones(batch_shape + (3,))), 1)
    dist = transformed.Transformed(base, bijector)

    def draw_with_log_prob():
        return dist.sample_and_log_prob(
            seed=self.seed, sample_shape=sample_shape)

    with self.subTest('batch_shape'):
        chex.assert_equal(dist.batch_shape, batch_shape)

    with self.subTest('sample.shape'):
        x = dist.sample(seed=self.seed, sample_shape=sample_shape)
        chex.assert_equal(x.shape, sample_shape + batch_shape + (3,))

    with self.subTest('sample_and_log_prob sample.shape'):
        x, _ = draw_with_log_prob()
        chex.assert_equal(x.shape, sample_shape + batch_shape + (3,))

    with self.subTest('sample_and_log_prob log_prob.shape'):
        _, lp = draw_with_log_prob()
        chex.assert_equal(lp.shape, sample_shape + batch_shape)

    with self.subTest('sample_and_log_prob log_prob value'):
        # The joint call must agree with a separate log_prob evaluation.
        x, lp = draw_with_log_prob()
        np.testing.assert_allclose(lp, dist.log_prob(x))
def test_numpyro_dict_priors_defaults_tfp():
    """TFP distributions passed as params become numpyro 'prior' entries.

    Each entry must carry exactly ``prior`` and ``param_type``, and the stored
    prior must be the distribution that was passed in.
    """
    demo_priors = {
        name: tfd.LogNormal(loc=0.0, scale=1.0)
        for name in ("lengthscale", "variance", "obs_noise")
    }
    numpyro_params = numpyro_dict_params(demo_priors)

    assert set(numpyro_params) == set(demo_priors.keys())

    for name, distribution in demo_priors.items():
        entry = numpyro_params[name]
        assert set(entry.keys()) == {"prior", "param_type"}
        # The stored prior is the distribution supplied for this parameter.
        chex.assert_equal(entry["prior"], distribution)
def test_simple(self):
    """Check that the correct tags are removed.

    Compares ``create_weight_decay_mask()`` against hand-written expected
    masks: ``False`` marks a leaf excluded from weight decay, ``True`` a leaf
    that keeps it.
    """
    mask = create_weight_decay_mask()
    data = {
        'bias': {'BatchNorm_0': 4, 'bias': 5, 'a': 0},
        'BatchNorm_0': {'b': 4},
        'a': {'b': {'BatchNorm_0': 0, 'bias': 0}, 'c': 0},
    }
    truth = {
        'bias': {'BatchNorm_0': False, 'bias': False, 'a': False},
        'BatchNorm_0': {'b': False},
        'a': {'b': {'BatchNorm_0': True, 'bias': False}, 'c': True},
    }
    chex.assert_equal(mask(data), truth)

    # A dict literal can hold only one 'bias' key: a repeated key silently
    # replaces the earlier one. So the top-level `bias` subtree with a plain
    # leaf needs its own check.
    chex.assert_equal(mask({'bias': {'b': 4}}), {'bias': {'b': False}})
def test_adam(self):
    """Named fields of an Adam optimizer state can be pulled out by extract_field."""
    hparams = ConfigDict({
        'optimizer': 'adam',
        'l2_decay_factor': None,
        'batch_size': 50,
        'total_accumulated_batch_size': 100,  # Use gradient accumulation.
        'opt_hparams': {
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-7,
            'weight_decay': 0.0,
        },
    })
    init_fn, _ = optimizers.get_optimizer(hparams)
    optimizer_state = init_fn({'foo': jnp.ones(10)})

    # The step counter must be extractable and integer-typed.
    chex.assert_type(extract_field(optimizer_state, 'count'), int)
    # Both moment estimates are stored per parameter, with the parameter's shape.
    for moment in ('nu', 'mu'):
        chex.assert_shape(extract_field(optimizer_state, moment)['foo'], (10,))
    # Requesting a field that does not exist ("abc") yields None.
    chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
def test_numpyro_add_priors_all(prior):
    """Passing a bare prior should turn every numpyro parameter into a prior.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)`` pair,
    so only the parameter dictionary is converted to numpyro form and compared.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # No variable name is given, so the prior is expected on every entry.
    new_numpyro_params = add_priors(numpyro_params, prior)

    for iparams in new_numpyro_params.values():
        # Each entry switches to a prior-type entry that holds `prior`.
        chex.assert_equal(iparams["param_type"], "prior")
        chex.assert_equal(iparams["prior"], prior)

    # The original GPJax parameter dictionary must be left untouched.
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_add_priors_dict(variable, prior):
    """A ``{name: prior}`` mapping assigns that prior to the named variable.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)`` pair,
    so only the parameter dictionary is converted to numpyro form and compared.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # Describe the prior to add as a single-entry mapping keyed by name.
    new_param_dict = {str(variable): prior}
    new_numpyro_params = add_priors(numpyro_params, new_param_dict)

    # The named entry becomes a prior-type entry holding `prior`.
    chex.assert_equal(new_numpyro_params[variable]["param_type"], "prior")
    chex.assert_equal(new_numpyro_params[variable]["prior"], prior)

    # The original GPJax parameter dictionary must be left untouched.
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def apply(self, x, config, num_classes, train=True):
    """Creates a model definition.

    Runs the shared ResNet50 feature extractor on the full image and scores
    candidate regions with ProposalNet. It selects ``config.k`` parts per
    example via perturbed top-k, aggregates those parts into part images and
    re-encodes them with the same feature extractor. Finally it predicts
    logits from the whole image, from each part, and from their combination.

    Args:
      x: input images. Only ``x.shape[0]`` (batch) and ``x.shape[3]``
        (channels) are read, and part images are asserted to be
        ``(b * k, 224, 224, c)``. So x is presumably NHWC — TODO confirm.
      config: reads ``k``, ``ptopk_sigma``, ``ptopk_num_samples`` and
        ``communication``.
      num_classes: output width of each ``nn.Dense`` head.
      train: forwarded to the feature extractor and ProposalNet. When
        False, both dropouts are deterministic.

    Returns:
      ``(all_logits, stats)``. ``all_logits`` has "raw_logits" (whole image),
      "concat_logits" (whole image plus the mean of part features) and
      "part_logits" (reshaped to ``(b, k, -1)``). ``stats`` holds "x", the
      effective "sigma", everything ProposalNet returned, "part_imgs" and
      "rpn_scores_entropy".

    Raises:
      chex assertion failures if the score-map positions do not add up to
      the indicator length, or if the part images have an unexpected shape.
    """
    b, c = x.shape[0], x.shape[3]
    k = config.k
    sigma = config.ptopk_sigma
    num_samples = config.ptopk_num_samples

    # Scale the configured sigma by a state variable initialized to one.
    # NOTE(review): "mutiplier" is misspelled, but it is the stored state name;
    # renaming it would presumably break loading existing checkpoints.
    sigma *= self.state("sigma_mutiplier", shape=(),
                        initializer=nn.initializers.ones).value

    stats = {"x": x, "sigma": sigma}

    # The same shared backbone encodes both the full image and the part images.
    feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

    rpn_feature = feature_extractor(x)
    # stop_gradient: gradients from the proposal branch do not flow back into
    # the backbone through this input.
    rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                        communication=Communication(
                                            config.communication),
                                        train=train)
    stats.update(rpn_stats)

    # rpn_scores are a list of score images. We keep track of the structure
    # because it is used in the aggregation step later-on.
    rpn_scores_shapes = [s.shape for s in rpn_scores]
    # Flatten each score map per example and concatenate them, so top-k runs
    # over all positions of all maps at once.
    rpn_scores_flat = jnp.concatenate(
        [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
    top_k_indicators = sample_patches.select_patches_perturbed_topk(
        rpn_scores_flat,
        k=k,
        sigma=sigma,
        num_samples=num_samples)
    # After this transpose the layout is (b, k, num_positions). The reshape to
    # [b, k, ...] below relies on that layout.
    top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
    # Split the flat indicators back into one (b, k, h, w) weight map per score
    # map. This presumes each score map holds b * sh[1] * sh[2] elements
    # (e.g. (b, h, w) or (b, h, w, 1)) — TODO confirm.
    offset = 0
    weights = []
    for sh in rpn_scores_shapes:
        cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
        cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
        weights.append(cur)
        offset += sh[1] * sh[2]
    # Every flattened position must have been consumed exactly once.
    chex.assert_equal(offset, top_k_indicators.shape[-1])

    part_imgs = weighted_anchor_aggregator(x, weights)
    chex.assert_shape(part_imgs, (b * k, 224, 224, c))
    # For visualization/inspection: tile each example's k parts along the
    # height axis. Assumes part_imgs is ordered example-major — TODO confirm.
    stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

    part_features = feature_extractor(part_imgs)
    part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims

    # NOTE(review): each nn.make_rng() call draws a fresh key. Keep the call
    # order unchanged so each dropout keeps receiving the same key.
    # 2048 is presumably the ResNet50 output feature width.
    part_features = nn.dropout(  # features from parts
        jnp.reshape(part_features, [b * k, 2048]),
        0.5,
        deterministic=not train,
        rng=nn.make_rng())
    features = nn.dropout(  # features from whole image
        jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
        0.5,
        deterministic=not train,
        rng=nn.make_rng())

    # Mean pool all part features, add it to features and predict logits.
    concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                          axis=1) + features
    # Presumably three separate Dense heads, since no name/shared is given —
    # confirm they are meant not to share weights.
    concat_logits = nn.Dense(concat_out, num_classes)
    raw_logits = nn.Dense(features, num_classes)
    part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

    all_logits = {
        "raw_logits": raw_logits,
        "concat_logits": concat_logits,
        "part_logits": part_logits,
    }
    # add entropy into it for entropy regularization.
    # "raw_scores" is not set in this function, so it must come from
    # ProposalNet's rpn_stats. Assumes it is (b, num_positions), giving a
    # per-example entropy averaged over the batch — TODO confirm.
    stats["rpn_scores_entropy"] = jax.scipy.special.entr(
        jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
    return all_logits, stats
def inner(x, *args, **kwargs):
    """Call the wrapped ``f`` and assert its result keeps ``x``'s dtype.

    ``f`` comes from the enclosing scope. The result is returned unchanged.
    """
    result = f(x, *args, **kwargs)
    input_dtype, output_dtype = x.dtype, result.dtype
    chex.assert_equal(input_dtype, output_dtype)
    return result
def test_batch_shape(self):
    """The wrapper reports the same batch shape as the distribution it wraps."""
    actual = self.wrapped_dist.batch_shape
    expected = self.base_dist.batch_shape
    chex.assert_equal(actual, expected)
def test_event_shape(self):
    """The wrapper reports the same event shape as the distribution it wraps."""
    actual = self.wrapped_dist.event_shape
    expected = self.base_dist.event_shape
    chex.assert_equal(actual, expected)