def setUp(self):
  super().setUp()
  self._sample_shape = (np.int32(10),)
  self._seed = 42
  self._key = jax.random.PRNGKey(self._seed)
  self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)
  self.base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
  self.values = jnp.array([1., -1.])
  self.distrax_second_dist = Normal(loc=-1., scale=0.8)
  self.tfp_second_dist = tfd.Normal(loc=-1., scale=0.8)
def setUp(self):
  super().setUp()
  self.base_dist = Transformed(
      distribution=Normal(loc=0., scale=1.), bijector=tfb.Exp())
  self.values = jnp.array([0., 1., 2.])
  self.distrax_second_dist = Transformed(
      distribution=Normal(loc=0.5, scale=0.8), bijector=tfb.Exp())
  self.tfp_second_dist = tfd.TransformedDistribution(
      distribution=tfd.Normal(loc=0.5, scale=0.8), bijector=tfb.Exp())
def test_loc_attr_of_normal(self):
  dist = Normal(loc=0., scale=1.)
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, Normal)
  self.assertIs(wrapped_dist, dist)
  # Access the `loc` attribute of a wrapped Normal.
  np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
def test_attrs_of_transformed_distribution(self):
  dist = Transformed(Normal(loc=0., scale=1.), bijector=lambda x: x)
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, Transformed)
  self.assertIs(wrapped_dist, dist)
  # Access the `distribution` attribute of a wrapped Transformed.
  assert isinstance(wrapped_dist.distribution, Normal)
  # Access the `loc` attribute of the inner Normal of a wrapped Transformed.
  np.testing.assert_almost_equal(wrapped_dist.distribution.loc, 0.)
def score_kl(self, inputs):
    # forward propagation to the latent space
    z = self.forward(inputs)
    # KL divergence between the base distribution and a Normal fitted
    # to the empirical moments of the latent samples
    loss = self.base_dist.kl_divergence(
        Normal(loc=z.mean(axis=0), scale=z.std(axis=0))
    )
    return jnp.mean(loss)
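# For orientation: a minimal, self-contained sketch of the KL term above,
# assuming only `jax` and `distrax`. The `base_dist` and `forward` attributes
# of the surrounding class are replaced by stand-ins here.
import jax
import jax.numpy as jnp
from distrax import Normal

base_dist = Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
z = jax.random.normal(jax.random.PRNGKey(0), (100, 2))  # stand-in latents
# per-dimension KL(base || fitted Normal); averaging mirrors `score_kl`
loss = base_dist.kl_divergence(Normal(loc=z.mean(axis=0), scale=z.std(axis=0)))
print(jnp.mean(loss))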
def test_with_independent(self):
  base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
  wrapped_dist = tfp_compatible_distribution(base_dist)
  meta_dist = tfd.Independent(wrapped_dist, 1, validate_args=True)
  samples = meta_dist.sample((), self._key)
  log_prob = meta_dist.log_prob(samples)
  distrax_meta_dist = Independent(base_dist, 1)
  expected_log_prob = distrax_meta_dist.log_prob(samples)
  self.assertion_fn(log_prob, expected_log_prob)
def test_with_transformed_distribution(self):
  base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
  wrapped_dist = tfp_compatible_distribution(base_dist)
  meta_dist = tfd.TransformedDistribution(
      distribution=wrapped_dist, bijector=tfb.Exp(), validate_args=True)
  samples = meta_dist.sample(seed=self._key)
  log_prob = meta_dist.log_prob(samples)
  distrax_meta_dist = Transformed(distribution=base_dist, bijector=tfb.Exp())
  expected_log_prob = distrax_meta_dist.log_prob(samples)
  self.assertion_fn(log_prob, expected_log_prob)
def test_on_distrax_distribution(self):
  dist = Normal(loc=0., scale=1.)
  wrapped_dist = conversion.to_tfp(dist)
  assert isinstance(wrapped_dist, Normal)
  # Access the `loc` attribute of a wrapped Normal.
  np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
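# A minimal usage sketch of the conversion tested above, via distrax's
# public API (assuming only `numpy` and `distrax` are installed).
import numpy as np
import distrax

dist = distrax.Normal(loc=0., scale=1.)
tfp_dist = distrax.to_tfp(dist)  # usable wherever a TFP distribution is expected
# attribute access is forwarded to the underlying distrax distribution
np.testing.assert_almost_equal(tfp_dist.loc, 0.)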
def train_max_layers_model(
    X: Array,
    rbig_block: RBIGBlockInit,
    max_layers: int = 50,
    verbose: bool = False,
    interval: int = 10,
):
    """Simple training procedure using the iterative scheme.

    Uses `max_layers` as the stopping criterion.

    Parameters
    ----------
    X : Array
        the input data to be transformed
    rbig_block : RBIGBlockInit
        the dataclass providing the per-block initialization functions
    max_layers : int, default=50
        the maximum number of layers to train
    verbose : bool
        whether to print progress
    interval : int
        how often (in layers) to print progress

    Returns
    -------
    X_g : Array
        the transformed (Gaussianized) variable
    rbig_model : GaussianizationFlow
        the trained flow model, wrapping the fitted bijectors
    """
    # initialize
    X_g = X.copy()
    n_features = X.shape[1]
    ilayer = 0
    n_bijectors = len(rbig_block.init_functions)
    bijectors = list()

    t0 = time.time()
    while ilayer < max_layers:
        # fit an RBIG block
        X_g, ibijector = rbig_block.forward_and_bijector(X_g)
        # append its bijectors
        bijectors += ibijector
        ilayer += 1
        if verbose and ilayer % interval == 0:
            print(f"Layer {ilayer} - Elapsed Time: {time.time() - t0:.4f}")

    if verbose:
        print("Completed.")
        print(f"Final Number of layers: {ilayer} (Blocks: {ilayer // n_bijectors})")
        print(f"Elapsed Time: {time.time() - t0:.4f}")

    # ================================
    # Create Gaussianization model
    # ================================
    # create base distribution
    base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

    # create Gaussianization flow model
    rbig_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

    return X_g, rbig_model
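# For orientation: a self-contained sketch (assuming only `jax` and
# `distrax`) of the factorized standard-normal base distribution built
# above. Its batch shape is (n_features,), so `log_prob` is per-feature
# and the joint log-density is the sum over the last axis.
import jax
import jax.numpy as jnp
from distrax import Normal

n_features = 2
base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

x = jax.random.normal(jax.random.PRNGKey(0), (5, n_features))
per_dim = base_dist.log_prob(x)  # shape (5, 2): one term per feature
joint = per_dim.sum(axis=-1)     # shape (5,): joint log-density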
final_loss, bijectors = _remove_layers(
    info_loss=final_loss,
    bijectors=bijectors,
    n_bijectors=n_bijectors,
    n_layers_remove=n_layers_remove,
)
t1 = time.time()

# ================================
# Create Gaussianization model
# ================================
# create base distribution
base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

# create Gaussianization flow model
rbig_model = RBIGFlow(base_dist=base_dist, bijectors=bijectors, info_loss=final_loss)

if verbose:
    print(
        f"Final Number of layers: {final_loss.shape[0]} "
        f"(Blocks: {final_loss.shape[0] // n_bijectors})"
    )
    print(f"Total Time: {t1 - t0:.4f} secs")

return X_g, rbig_model
def init_default_gf_model(
    shape: tuple,
    X: Array = None,
    n_blocks: int = 4,
    n_components: int = 20,
    mixture: str = "logistic",
    init_mixcdf: str = "gmm",
    init_rotation: str = "pca",
    inverse_cdf: str = "logistic",
    n_reflections: int = 10,
    return_transform: bool = False,
):
    n_features = shape[0]
    rng = jax.random.PRNGKey(42)

    # ======================
    # Mixture CDF Transform
    # ======================
    if mixture == "logistic":
        init_mixcdf_f = InitMixtureLogisticCDF(
            n_components=n_components, init_method=init_mixcdf
        )
    elif mixture == "gaussian":
        init_mixcdf_f = InitMixtureGaussianCDF(
            n_components=n_components, init_method=init_mixcdf
        )
    else:
        raise ValueError(f"Unrecognized mixture dist: {mixture}")

    # ======================
    # Inverse CDF Transform
    # ======================
    if inverse_cdf == "logistic":
        # logit transform
        init_icdf_f = InitLogitTransform()
    elif inverse_cdf == "logistictemp":
        # logit transform with temperature
        init_icdf_f = InitLogitTempTransform()
    elif inverse_cdf == "gaussian":
        init_icdf_f = InitInverseGaussCDF()
    else:
        raise ValueError(f"Unrecognized inverse cdf function: {inverse_cdf}")

    # ======================
    # HouseHolder Transform
    # ======================
    init_hh_f = InitHouseHolder(n_reflections=n_reflections, method=init_rotation)

    block_rngs = jax.random.split(rng, num=n_blocks)
    itercount = itertools.count()
    bijectors = []
    X_g = X.copy()

    pbar = tqdm.tqdm(block_rngs)
    with pbar:
        for iblock, irng in enumerate(pbar):
            # ======================
            # MIXTURE CDF
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock + 1} | Layer {next(itercount)}"
            )
            # create key for this init
            irng, icdf_rng = jax.random.split(irng, 2)
            # initialize bijector and transform the data
            X_g, layer = init_mixcdf_f.transform_and_bijector(
                inputs=X_g, rng=icdf_rng, n_features=n_features
            )
            bijectors.append(layer)

            # ======================
            # LOGIT
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock + 1} | Layer {next(itercount)}"
            )
            # initialize bijector and transform the data
            X_g, layer = init_icdf_f.transform_and_bijector(inputs=X_g)
            bijectors.append(layer)

            # ======================
            # HOUSEHOLDER
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock + 1} | Layer {next(itercount)}"
            )
            # create key for this init
            irng, hh_rng = jax.random.split(irng, 2)
            # initialize bijector and transform the data
            X_g, layer = init_hh_f.transform_and_bijector(
                inputs=X_g, rng=hh_rng, n_features=n_features
            )
            bijectors.append(layer)

    # create base distribution
    base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

    # create flow model
    gf_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

    if return_transform:
        return X_g, gf_model
    else:
        return gf_model
def init_gf_composite_spline_model(
    shape: tuple,
    X: Array = None,
    n_blocks: int = 4,
    n_bins: int = 20,
    range_min: float = 0.0,
    range_max: float = 1.0,
    init_rotation: str = "random",
    n_reflections: int = 10,
    squeeze: str = "sigmoid",
    return_transform: bool = False,
    **kwargs,
):
    n_features = shape[0]
    rng = jax.random.PRNGKey(42)

    # ======================
    # Composite Transform
    # ======================
    init_nl_forward_f = InitGaussCDF()
    init_nl_inverse_f = InitInverseGaussCDF()

    # ======================
    # RQ Spline
    # ======================
    init_rq_f = InitPiecewiseRationalQuadraticCDF(
        n_bins=n_bins, range_min=range_min, range_max=range_max, **kwargs
    )

    # ======================
    # HouseHolder Transform
    # ======================
    init_hh_f = InitHouseHolder(n_reflections=n_reflections, method=init_rotation)

    block_rngs = jax.random.split(rng, num=n_blocks)
    itercount = itertools.count()
    bijectors = []
    X_g = X.copy()

    pbar = tqdm.tqdm(block_rngs)
    with pbar:
        for iblock, irng in enumerate(pbar):
            # ======================
            # Forward Squeezing Transform
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock + 1} | Layer {next(itercount)}"
            )
            # initialize bijector and transform the data
            X_g, layer = init_nl_forward_f.transform_and_bijector(inputs=X_g)
            bijectors.append(layer)

            # ======================
            # RQ Spline
            # ======================
            # create key for this init
            irng, irq_rng = jax.random.split(irng, 2)
            # initialize bijector and transform the data
            X_g, layer = init_rq_f.transform_and_bijector(
                inputs=X_g, rng=irq_rng, shape=X.shape[1:]
            )
            bijectors.append(layer)

            # ======================
            # Inverse Squeezing Transform
            # ======================
            # initialize bijector and transform the data
            X_g, layer = init_nl_inverse_f.transform_and_bijector(inputs=X_g)
            bijectors.append(layer)

            # ======================
            # HOUSEHOLDER
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock + 1} | Layer {next(itercount)}"
            )
            # create key for this init
            irng, hh_rng = jax.random.split(irng, 2)
            # initialize bijector and transform the data
            X_g, layer = init_hh_f.transform_and_bijector(
                inputs=X_g, rng=hh_rng, n_features=n_features
            )
            bijectors.append(layer)

    # create base distribution
    base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

    # create flow model
    gf_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

    if return_transform:
        return X_g, gf_model
    else:
        return gf_model
class TFPCompatibleDistributionNormal(parameterized.TestCase):
  """Tests for the TFP-compatible wrapper around a Normal distribution."""

  def setUp(self):
    super().setUp()
    self._sample_shape = (np.int32(10),)
    self._seed = 42
    self._key = jax.random.PRNGKey(self._seed)
    self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)
    self.base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    self.values = jnp.array([1., -1.])
    self.distrax_second_dist = Normal(loc=-1., scale=0.8)
    self.tfp_second_dist = tfd.Normal(loc=-1., scale=0.8)

  @property
  def wrapped_dist(self):
    return tfp_compatible_distribution(self.base_dist)

  def test_event_shape(self):
    chex.assert_equal(self.wrapped_dist.event_shape, self.base_dist.event_shape)

  def test_batch_shape(self):
    chex.assert_equal(self.wrapped_dist.batch_shape, self.base_dist.batch_shape)

  @chex.all_variants
  def test_sample(self):
    def sample_fn(key):
      return self.wrapped_dist.sample(seed=key, sample_shape=self._sample_shape)
    sample_fn = self.variant(sample_fn)
    self.assertion_fn(
        sample_fn(self._key),
        self.base_dist.sample(sample_shape=self._sample_shape, seed=self._key))

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      ('mean', 'mean'),
      ('mode', 'mode'),
      ('median', 'median'),
      ('stddev', 'stddev'),
      ('variance', 'variance'),
      ('entropy', 'entropy'),
  )
  def test_method(self, method):
    try:
      expected_result = self.variant(getattr(self.base_dist, method))()
    except NotImplementedError:
      return
    except AttributeError:
      return
    result = self.variant(getattr(self.wrapped_dist, method))()
    self.assertion_fn(result, expected_result)

  @chex.all_variants
  @parameterized.named_parameters(
      ('log_prob', 'log_prob'),
      ('prob', 'prob'),
      ('log_cdf', 'log_cdf'),
      ('cdf', 'cdf'),
  )
  def test_method_with_value(self, method):
    try:
      expected_result = self.variant(
          getattr(self.base_dist, method))(self.values)
    except NotImplementedError:
      return
    except AttributeError:
      return
    result = self.variant(getattr(self.wrapped_dist, method))(self.values)
    self.assertion_fn(result, expected_result)

  @chex.all_variants
  @parameterized.named_parameters(
      ('kl_divergence', 'kl_divergence'),
      ('cross_entropy', 'cross_entropy'),
  )
  def test_with_two_distributions(self, method):
    """Tests methods of the form listed below.

      D(distrax_distrib || wrapped_distrib),
      D(wrapped_distrib || distrax_distrib),
      D(tfp_distrib || wrapped_distrib),
      D(wrapped_distrib || tfp_distrib).

    Args:
      method: the method name to be tested.
    """
    try:
      expected_result1 = self.variant(
          getattr(self.distrax_second_dist, method))(self.base_dist)
      expected_result2 = self.variant(
          getattr(self.base_dist, method))(self.distrax_second_dist)
    except NotImplementedError:
      return
    except AttributeError:
      return
    distrax_result1 = self.variant(getattr(self.distrax_second_dist, method))(
        self.wrapped_dist)
    distrax_result2 = self.variant(getattr(self.wrapped_dist, method))(
        self.distrax_second_dist)
    tfp_result1 = self.variant(getattr(self.tfp_second_dist, method))(
        self.wrapped_dist)
    tfp_result2 = self.variant(getattr(self.wrapped_dist, method))(
        self.tfp_second_dist)
    self.assertion_fn(distrax_result1, expected_result1)
    self.assertion_fn(distrax_result2, expected_result2)
    self.assertion_fn(tfp_result1, expected_result1)
    self.assertion_fn(tfp_result2, expected_result2)
def test_with_sample(self):
  base_dist = Normal(0., 1.)
  wrapped_dist = tfp_compatible_distribution(base_dist)
  meta_dist = tfd.Sample(
      wrapped_dist, sample_shape=[1, 3], validate_args=True)
  meta_dist.log_prob(meta_dist.sample(2, seed=self._key))
def _make_components(self, key_loc, key_scale):
  components_shape = self.batch_shape + (self.num_components,)
  return Normal(
      loc=jax.random.normal(key=key_loc, shape=components_shape),
      scale=jax.random.uniform(key=key_scale, shape=components_shape) + 0.5)
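# For orientation: a self-contained sketch (assuming only `jax` and
# `distrax`) of how components shaped like the ones above plug into a
# mixture. `num_components` and the uniform mixing logits are illustrative.
import jax
import jax.numpy as jnp
from distrax import Categorical, MixtureSameFamily, Normal

key_loc, key_scale = jax.random.split(jax.random.PRNGKey(0))
num_components = 3
components = Normal(
    loc=jax.random.normal(key=key_loc, shape=(num_components,)),
    scale=jax.random.uniform(key=key_scale, shape=(num_components,)) + 0.5)
mixture = MixtureSameFamily(
    mixture_distribution=Categorical(logits=jnp.zeros(num_components)),
    components_distribution=components)
print(mixture.log_prob(jnp.array(0.5)))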