Esempio n. 1
0
def _get_flat_unconstraining_bijector(jd_model):
    """Return a bijector that flattens and unconstrains points of a joint model.

    Loosely, the aim is to move from a model joint distribution supported on

      U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

    to a model supported on R^N, with N = sum(n_j). ("Loosely" is meant in the
    sense of base measures: a distribution may be supported on an
    m-dimensional subset of R^n, and the default transform for that
    distribution may then have support on R^m.)

    The result is the inverse of a chain made of, in list order: the model's
    default event space bijector, a restructuring bijector built from the
    model's event shapes, a per-part reshape, and a split.

    Args:
      jd_model: subclass of `tfd.JointDistribution`. A JointDistribution for a
        model.

    Returns:
      A `tfb.Bijector` where the `.forward` method flattens and unconstrains
      points.
    """
    # TODO(b/180396233): This bijector is in general point-dependent.
    event_space_bij = jd_model.experimental_default_event_space_bijector()
    pack_bij = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    # Per-part shapes, obtained by pushing the model's event shapes back
    # through the restructuring bijector only.
    part_shapes = pack_bij.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # Reshaping is required because split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    reshape_bij = joint_map.JointMap(bijectors=[
        reshape.Reshape(event_shape_out=shape, event_shape_in=[-1])
        for shape in part_shapes
    ])

    # Each part contributes as many entries as the product of its shape.
    split_bij = split.Split(
        num_or_size_splits=[ps.reduce_prod(shape) for shape in part_shapes])

    return invert.Invert(
        chain.Chain([event_space_bij, pack_bij, reshape_bij, split_bij]))
def _get_flat_unconstraining_bijector(jd_model):
    """Build the flattening/unconstraining bijector and a step-size bijector.

    Loosely, the aim is to move from a model joint distribution supported on

      U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

    to a model supported on R^N, with N = sum(n_j). ("Loosely" is meant in the
    sense of base measures: a distribution may be supported on an
    m-dimensional subset of R^n, and the default transform for that
    distribution may then have support on R^m.)

    Args:
      jd_model: subclass of `tfd.JointDistribution`. A JointDistribution for a
        model.

    Returns:
      A pair `(bij, step_size_bij)` of `tfb.Bijector`s. `bij` is the inverse
      of the chain [default event space bijector, restructuring bijector,
      per-part reshape]; its `.forward` method flattens and unconstrains
      points. `step_size_bij` is the inverse of the restructuring bijector
      alone and may be used to initialize a step size.
    """
    # TODO(b/180396233): This bijector is in general point-dependent.
    event_space_bij = jd_model.experimental_default_event_space_bijector()
    pack_bij = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    # Unlike a plain restructure, the shapes here are pulled back through the
    # event space bijector as well, so they describe the unconstrained parts.
    constrain_and_pack = event_space_bij(pack_bij)
    part_shapes = constrain_and_pack.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    def _flat_to_shape(shape):
        """Reshape bijector taking a flat vector to `shape`."""
        return reshape.Reshape(event_shape_out=shape, event_shape_in=[-1])

    # Reshaping is required because split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    # NOTE(review): no split bijector is chained in this function; presumably
    # the splitting happens in a caller -- confirm.
    unsplit = joint_map.JointMap(
        tf.nest.map_structure(_flat_to_shape, part_shapes))

    forward_chain = chain.Chain([event_space_bij, pack_bij, unsplit])
    return invert.Invert(forward_chain), invert.Invert(pack_bij)