  def setUp(self):
    super().setUp()
    self.base_dist = Transformed(
        distribution=Normal(loc=0., scale=1.),
        bijector=tfb.Exp())
    self.values = jnp.array([0., 1., 2.])
    self.distrax_second_dist = Transformed(
        distribution=Normal(loc=0.5, scale=0.8),
        bijector=tfb.Exp())
    self.tfp_second_dist = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=0.5, scale=0.8),
        bijector=tfb.Exp())
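Note: Exp-transforming a Normal yields a LogNormal, so the log-probs of the distributions built in this setUp can be sanity-checked against tfd.LogNormal. A minimal self-contained sketch (the imports mirror what these snippets assume):

import jax.numpy as jnp
from distrax import Normal, Transformed
from tensorflow_probability.substrates import jax as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

# Exp(Normal(0.5, 0.8)) is LogNormal(0.5, 0.8); log-probs agree on
# positive values (0. lies outside the support and gives -inf).
x = jnp.array([0.5, 1., 2.])
d1 = Transformed(distribution=Normal(loc=0.5, scale=0.8), bijector=tfb.Exp())
d2 = tfd.LogNormal(loc=0.5, scale=0.8)
print(d1.log_prob(x), d2.log_prob(x))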
Example 3
  def test_loc_attr_of_normal(self):
    dist = Normal(loc=0., scale=1.)
    wrapped_dist = conversion.as_distribution(dist)
    self.assertIsInstance(wrapped_dist, Normal)
    self.assertIs(wrapped_dist, dist)
    # Access the `loc` attribute of a wrapped Normal.
    np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
Example 4
  def test_attrs_of_transformed_distribution(self):
    dist = Transformed(Normal(loc=0., scale=1.), bijector=lambda x: x)
    wrapped_dist = conversion.as_distribution(dist)
    self.assertIsInstance(wrapped_dist, Transformed)
    self.assertIs(wrapped_dist, dist)
    # Access the `distribution` attribute of a wrapped Transformed.
    self.assertIsInstance(wrapped_dist.distribution, Normal)
    # Access the `loc` attribute of a transformed Normal within a wrapped
    # Transformed.
    np.testing.assert_almost_equal(wrapped_dist.distribution.loc, 0.)
Example 5
    def score_kl(self, inputs):

        # forward propagation
        z = self.forward(inputs)

        # KL between the base distribution and a Normal fit to the latent batch statistics
        loss = self.base_dist.kl_divergence(
            Normal(loc=z.mean(axis=0), scale=z.std(axis=0))
        )

        return jnp.mean(loss)
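For context, a self-contained sketch of what score_kl computes; the IdentityFlow class below is a hypothetical stand-in for the model this method belongs to, not part of the original code:

import jax
import jax.numpy as jnp
from distrax import Normal

class IdentityFlow:
    # Hypothetical minimal model: forward() is the identity map.
    def __init__(self, n_features):
        self.base_dist = Normal(jnp.zeros(n_features), jnp.ones(n_features))

    def forward(self, inputs):
        return inputs

    def score_kl(self, inputs):
        z = self.forward(inputs)
        loss = self.base_dist.kl_divergence(
            Normal(loc=z.mean(axis=0), scale=z.std(axis=0)))
        return jnp.mean(loss)

x = jax.random.normal(jax.random.PRNGKey(0), (1024, 2))
print(IdentityFlow(n_features=2).score_kl(x))  # near 0 for standard-normal inputs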
  def test_with_independent(self):
    base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    wrapped_dist = tfp_compatible_distribution(base_dist)

    meta_dist = tfd.Independent(wrapped_dist, 1, validate_args=True)
    samples = meta_dist.sample((), self._key)
    log_prob = meta_dist.log_prob(samples)

    distrax_meta_dist = Independent(base_dist, 1)
    expected_log_prob = distrax_meta_dist.log_prob(samples)

    self.assertion_fn(log_prob, expected_log_prob)
  def test_with_transformed_distribution(self):
    base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    wrapped_dist = tfp_compatible_distribution(base_dist)

    meta_dist = tfd.TransformedDistribution(
        distribution=wrapped_dist, bijector=tfb.Exp(), validate_args=True)
    samples = meta_dist.sample(seed=self._key)
    log_prob = meta_dist.log_prob(samples)

    distrax_meta_dist = Transformed(
        distribution=base_dist, bijector=tfb.Exp())
    expected_log_prob = distrax_meta_dist.log_prob(samples)

    self.assertion_fn(log_prob, expected_log_prob)
Example 8
  def test_on_distrax_distribution(self):
    dist = Normal(loc=0., scale=1.)
    wrapped_dist = conversion.to_tfp(dist)
    self.assertIsInstance(wrapped_dist, Normal)
    # Access the `loc` attribute of a wrapped Normal.
    np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
Example 9
def train_max_layers_model(
    X: Array,
    rbig_block: RBIGBlockInit,
    max_layers: int = 50,
    verbose: bool = False,
    interval: int = 10,
):
    """Simple training procedure using the iterative scheme.
    Uses a `max_layers` argument for the stopping criteria
    
    Parameters
    ----------
    X : Array
        the input data to be trained
    rbig_block : RBIGBlock
        a dataclass to be used
    max_layers : int, default=50
        the maximum number of layers to train the model
    verbose : bool
        whether to show
    interval : int
        how often to produce numbers
    
    Returns
    -------
    X_g : Array
        the transformed variable
    bijectors : List[Bijectors]
        a list of the bijectors which have been trained
    final_loss : Array
        an array for the loss function
    """
    # init rbig_block
    X_g = X.copy()
    n_features = X.shape[1]
    ilayer = 0

    n_bijectors = len(rbig_block.init_functions)
    bijectors = list()
    t0 = time.time()
    while ilayer < max_layers:

        # fit RBIG block
        X_g, ibijector = rbig_block.forward_and_bijector(X_g)

        # append bijectors
        bijectors += ibijector

        ilayer += 1

        if verbose:
            if ilayer % interval == 0:
                print(f"Layer {ilayer} - Elapsed Time: {time.time()-t0:.4f}")

    if verbose:
        print("Completed.")
        print(
            f"Final Number of layers: {ilayer} (Blocks: {ilayer//n_bijectors})"
        )
        print(f"Elapsed Time: {time.time()-t0:.4f}")
    # ================================
    # Create Gaussianization model
    # ================================

    # create base distribution
    base_dist = Normal(jnp.zeros((n_features, )), jnp.ones((n_features, )))

    # create gaussianization flow model
    rbig_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

    # return relevant stuff
    return X_g, rbig_model
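A usage sketch for the trainer above; the construction of rbig_block is an assumption (its initializer is not shown in this listing), so this is illustrative rather than runnable as-is:

import jax

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
# rbig_block = RBIGBlockInit(...)  # assumed; constructor args not shown here
X_g, rbig_model = train_max_layers_model(
    X, rbig_block, max_layers=50, verbose=True, interval=10)
# rbig_model is a GaussianizationFlow; X_g is the Gaussianized data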
Example 10
    final_loss, bijectors = _remove_layers(
        info_loss=final_loss,
        bijectors=bijectors,
        n_bijectors=n_bijectors,
        n_layers_remove=n_layers_remove,
    )

    t1 = time.time()

    # ================================
    # Create Gaussianization model
    # ================================

    # create base distribution
    base_dist = Normal(jnp.zeros((n_features, )), jnp.ones((n_features, )))

    # create gaussianization flow model
    rbig_model = RBIGFlow(base_dist=base_dist,
                          bijectors=bijectors,
                          info_loss=final_loss)

    if verbose:
        print(
            f"Final Number of layers: {final_loss.shape[0]} (Blocks: {final_loss.shape[0]//n_bijectors})"
        )
        print(f"Total Time: {t1-t0:.4f} secs")

    # return relevant stuff
    return X_g, rbig_model
Example 11
def init_default_gf_model(
    shape: tuple,
    X: Array = None,
    n_blocks: int = 4,
    n_components: int = 20,
    mixture: str = "logistic",
    init_mixcdf: str = "gmm",
    init_rotation: str = "pca",
    inverse_cdf: str = "logistic",
    n_reflections: int = 10,
    return_transform: bool = False,
):

    n_features = shape[0]
    rng = jax.random.PRNGKey(42)
    # rng, _ = jax.random.split(jax.random.PRNGKey(123), 2)

    if mixture == "logistic":
        init_mixcdf_f = InitMixtureLogisticCDF(
            n_components=n_components, init_method=init_mixcdf
        )

    elif mixture == "gaussian":
        init_mixcdf_f = InitMixtureGaussianCDF(
            n_components=n_components, init_method=init_mixcdf
        )
    else:
        raise ValueError(f"Unrecognized mixture dist: {mixture}")

    if inverse_cdf == "logistic":
        # Logit Transform
        init_icdf_f = InitLogitTransform()
    elif inverse_cdf == "logistictemp":
        # Logit Transform
        init_icdf_f = InitLogitTempTransform()
    elif inverse_cdf == "gaussian":
        init_icdf_f = InitInverseGaussCDF()
    else:
        raise ValueError(f"Unrecognized inverse cdf function: {inverse_cdf}")
    # =====================
    # HouseHolder Transform
    # ======================
    # initialize init function
    init_hh_f = InitHouseHolder(n_reflections=n_reflections, method=init_rotation)

    block_rngs = jax.random.split(rng, num=n_blocks)
    # rng = jax.random.split(jax.random.PRNGKey(42), n_blocks)
    # block_rngs = jax.random.split(jax.random.PRNGKey(42), n_blocks)

    itercount = itertools.count()
    bijectors = []

    X_g = X.copy()

    pbar = tqdm.tqdm(block_rngs)
    with pbar:
        for iblock, irng in enumerate(pbar):

            pbar.set_description(
                f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
            )

            # ======================
            # MIXTURECDF
            # ======================
            # create keys for all inits
            irng, icdf_rng = jax.random.split(irng, 2)

            # initialize bijector and transformation
            X_g, layer = init_mixcdf_f.transform_and_bijector(
                inputs=X_g, rng=icdf_rng, n_features=n_features
            )

            # add bijector to list
            bijectors.append(layer)

            # ======================
            # LOGIT
            # ======================

            pbar.set_description(
                f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
            )

            # initialize bijector and transformation
            X_g, layer = init_icdf_f.transform_and_bijector(inputs=X_g)

            bijectors.append(layer)

            # ======================
            # HOUSEHOLDER
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
            )
            # create keys for all inits
            irng, hh_rng = jax.random.split(irng, 2)

            # initialize bijector and transformation
            X_g, layer = init_hh_f.transform_and_bijector(
                inputs=X_g, rng=hh_rng, n_features=n_features
            )

            bijectors.append(layer)

    # create base dist
    base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

    # create flow model
    gf_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)
    if return_transform:
        return X_g, gf_model
    else:
        return gf_model
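A hypothetical call to the initializer above, with illustrative shapes (it assumes the Init* helpers imported by this module are available):

import jax

X = jax.random.normal(jax.random.PRNGKey(0), (5000, 2))
gf_model = init_default_gf_model(shape=(2,), X=X, n_blocks=4, n_components=20)
# or, to also recover the initialization-time transform of X:
X_g, gf_model = init_default_gf_model(shape=(2,), X=X, return_transform=True)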
Example 12
def init_gf_composite_spline_model(
    shape: tuple,
    X: Array = None,
    n_blocks: int = 4,
    n_bins: int = 20,
    range_min: float = 0.0,
    range_max: float = 1.0,
    init_rotation: str = "random",
    n_reflections: int = 10,
    squeeze: str = "sigmoid",
    return_transform: bool = False,
    **kwargs,
):

    n_features = shape[0]
    rng = jax.random.PRNGKey(42)
    # rng, _ = jax.random.split(jax.random.PRNGKey(123), 2)
    # =====================
    # Composite Transform
    # ======================
    init_nl_forward_f = InitGaussCDF()
    init_nl_inverse_f = InitInverseGaussCDF()

    # =====================
    # RQ Spline
    # ======================
    init_rq_f = InitPiecewiseRationalQuadraticCDF(
        n_bins=n_bins, range_min=range_min, range_max=range_max, **kwargs
    )
    # =====================
    # HouseHolder Transform
    # ======================
    # initialize init function
    init_hh_f = InitHouseHolder(n_reflections=n_reflections, method=init_rotation)

    block_rngs = jax.random.split(rng, num=n_blocks)
    # rng = jax.random.split(jax.random.PRNGKey(42), n_blocks)
    # block_rngs = jax.random.split(jax.random.PRNGKey(42), n_blocks)

    itercount = itertools.count()
    bijectors = []

    X_g = X.copy()

    pbar = tqdm.tqdm(block_rngs)
    with pbar:
        for iblock, irng in enumerate(pbar):

            pbar.set_description(
                f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
            )
            # ======================
            # Forward Squeezing Transform
            # ======================
            # initialize bijector and transformation
            X_g, layer = init_nl_forward_f.transform_and_bijector(inputs=X_g)
            # add bijector to list
            bijectors.append(layer)
            # ======================
            # RQ Spline
            # ======================
            # create keys for all inits
            irng, irq_rng = jax.random.split(irng, 2)

            # initialize bijector and transformation
            X_g, layer = init_rq_f.transform_and_bijector(
                inputs=X_g, rng=irq_rng, shape=X.shape[1:]
            )

            # add bijector to list
            bijectors.append(layer)

            # ======================
            # Inverse Squeezing Transform
            # ======================
            # initialize bijector and transformation
            X_g, layer = init_nl_inverse_f.transform_and_bijector(inputs=X_g)
            # add bijector to list
            bijectors.append(layer)

            # ======================
            # HOUSEHOLDER
            # ======================
            pbar.set_description(
                f"Initializing - Block: {iblock+1} | Layer {next(itercount)}"
            )
            # create keys for all inits
            irng, hh_rng = jax.random.split(irng, 2)

            # initialize bijector and transformation
            X_g, layer = init_hh_f.transform_and_bijector(
                inputs=X_g, rng=hh_rng, n_features=n_features
            )

            bijectors.append(layer)

    # create base dist
    base_dist = Normal(jnp.zeros((n_features,)), jnp.ones((n_features,)))

    # create flow model
    gf_model = GaussianizationFlow(base_dist=base_dist, bijectors=bijectors)

    if return_transform:
        return X_g, gf_model
    else:
        return gf_model
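Each block above sandwiches the rational-quadratic spline between a Gaussian-CDF squeeze and its inverse: inputs are first mapped into the spline's [range_min, range_max] domain (defaulting to [0, 1]), the spline is applied there, and the inverse CDF maps back before the Householder rotation. A hypothetical call, mirroring the previous initializer:

import jax

X = jax.random.normal(jax.random.PRNGKey(1), (5000, 2))
gf_model = init_gf_composite_spline_model(shape=(2,), X=X, n_blocks=4, n_bins=20)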
class TFPCompatibleDistributionNormal(parameterized.TestCase):
  """Tests for Normal distribution."""

  def setUp(self):
    super().setUp()
    self._sample_shape = (np.int32(10),)
    self._seed = 42
    self._key = jax.random.PRNGKey(self._seed)
    self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)
    self.base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    self.values = jnp.array([1., -1.])
    self.distrax_second_dist = Normal(loc=-1., scale=0.8)
    self.tfp_second_dist = tfd.Normal(loc=-1., scale=0.8)

  @property
  def wrapped_dist(self):
    return tfp_compatible_distribution(self.base_dist)

  def test_event_shape(self):
    chex.assert_equal(self.wrapped_dist.event_shape, self.base_dist.event_shape)

  def test_batch_shape(self):
    chex.assert_equal(self.wrapped_dist.batch_shape, self.base_dist.batch_shape)

  @chex.all_variants
  def test_sample(self):
    def sample_fn(key):
      return self.wrapped_dist.sample(seed=key, sample_shape=self._sample_shape)
    sample_fn = self.variant(sample_fn)
    self.assertion_fn(
        sample_fn(self._key),
        self.base_dist.sample(sample_shape=self._sample_shape, seed=self._key))

  @chex.all_variants(with_pmap=False)
  @parameterized.named_parameters(
      ('mean', 'mean'),
      ('mode', 'mode'),
      ('median', 'median'),
      ('stddev', 'stddev'),
      ('variance', 'variance'),
      ('entropy', 'entropy'),
  )
  def test_method(self, method):
    try:
      expected_result = self.variant(getattr(self.base_dist, method))()
    except NotImplementedError:
      return
    except AttributeError:
      return
    result = self.variant(getattr(self.wrapped_dist, method))()
    self.assertion_fn(result, expected_result)

  @chex.all_variants
  @parameterized.named_parameters(
      ('log_prob', 'log_prob'),
      ('prob', 'prob'),
      ('log_cdf', 'log_cdf'),
      ('cdf', 'cdf'),
  )
  def test_method_with_value(self, method):
    try:
      expected_result = self.variant(
          getattr(self.base_dist, method))(self.values)
    except NotImplementedError:
      return
    except AttributeError:
      return
    result = self.variant(getattr(self.wrapped_dist, method))(self.values)
    self.assertion_fn(result, expected_result)

  @chex.all_variants
  @parameterized.named_parameters(
      ('kl_divergence', 'kl_divergence'),
      ('cross_entropy', 'cross_entropy'),
  )
  def test_with_two_distributions(self, method):
    """Test methods of the form listed below.

      D(distrax_distrib || wrapped_distrib),
      D(wrapped_distrib || distrax_distrib),
      D(tfp_distrib || wrapped_distrib),
      D(wrapped_distrib || tfp_distrib).

    Args:
      method: the method name to be tested
    """
    try:
      expected_result1 = self.variant(
          getattr(self.distrax_second_dist, method))(self.base_dist)
      expected_result2 = self.variant(
          getattr(self.base_dist, method))(self.distrax_second_dist)
    except NotImplementedError:
      return
    except AttributeError:
      return
    distrax_result1 = self.variant(getattr(self.distrax_second_dist, method))(
        self.wrapped_dist)
    distrax_result2 = self.variant(getattr(self.wrapped_dist, method))(
        self.distrax_second_dist)
    tfp_result1 = self.variant(getattr(self.tfp_second_dist, method))(
        self.wrapped_dist)
    tfp_result2 = self.variant(getattr(self.wrapped_dist, method))(
        self.tfp_second_dist)
    self.assertion_fn(distrax_result1, expected_result1)
    self.assertion_fn(distrax_result2, expected_result2)
    self.assertion_fn(tfp_result1, expected_result1)
    self.assertion_fn(tfp_result2, expected_result2)
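A minimal standalone illustration of the four directions this test exercises, reusing the module's tfp_compatible_distribution wrapper (distrax/TFP imports assumed from the module header):

base = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
wrapped = tfp_compatible_distribution(base)
q_distrax = Normal(loc=-1., scale=0.8)
q_tfp = tfd.Normal(loc=-1., scale=0.8)
# All four should match the pure-distrax value q_distrax.kl_divergence(base):
print(q_distrax.kl_divergence(wrapped))
print(wrapped.kl_divergence(q_distrax))
print(q_tfp.kl_divergence(wrapped))
print(wrapped.kl_divergence(q_tfp))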
  def test_with_sample(self):
    base_dist = Normal(0., 1.)
    wrapped_dist = tfp_compatible_distribution(base_dist)
    meta_dist = tfd.Sample(
        wrapped_dist, sample_shape=[1, 3], validate_args=True)
    meta_dist.log_prob(meta_dist.sample(2, seed=self._key))
Example 15
  def _make_components(self, key_loc, key_scale):
    components_shape = self.batch_shape + (self.num_components,)
    return Normal(
        loc=jax.random.normal(key=key_loc, shape=components_shape),
        scale=jax.random.uniform(key=key_scale, shape=components_shape) + 0.5)
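For context, components shaped this way typically feed a mixture. A hypothetical sketch using distrax's MixtureSameFamily with a uniform Categorical (batch_shape and num_components instantiated with illustrative values):

import jax
import jax.numpy as jnp
from distrax import Categorical, MixtureSameFamily, Normal

batch_shape, num_components = (), 3
key_loc, key_scale = jax.random.split(jax.random.PRNGKey(0))
components_shape = batch_shape + (num_components,)
components = Normal(
    loc=jax.random.normal(key=key_loc, shape=components_shape),
    scale=jax.random.uniform(key=key_scale, shape=components_shape) + 0.5)
mixture = MixtureSameFamily(
    mixture_distribution=Categorical(logits=jnp.zeros(num_components)),
    components_distribution=components)
print(mixture.log_prob(jnp.array(0.5)))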