Example #1
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
    """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain
  output_to_unconst_vec(bijector(inverse(input_to_unconst_vec))) so that
  we're working with a full rank matrix.  We then adjust the resulting Jacobian
  for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial.  However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate
    x: the value for which we want to approximate the Jacobian.  x must have
      a single batch dimension for compatibility with tape.batch_jacobian.
    event_ndims: number of dimensions in an event
    input_to_unconstrained: bijector that maps the input to the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the input into
      a 1-D vector according to its event_ndims.
    output_to_unconstrained: bijector that maps the output of the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the output into
      a 1-D vector according to its event_ndims.

  Returns:
    A numerical approximation to the log det Jacobian of bijector.forward
    evaluated at x.
  """
    if input_to_unconstrained is None:
        input_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])
    if output_to_unconstrained is None:
        output_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])

    x = tf.convert_to_tensor(value=x)
    x_unconstrained = 1 * input_to_unconstrained.forward(x)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_unconstrained)
        f_x = bijector.forward(input_to_unconstrained.inverse(x_unconstrained))
        f_x_unconstrained = output_to_unconstrained.forward(f_x)
    jacobian = tape.batch_jacobian(f_x_unconstrained,
                                   x_unconstrained,
                                   experimental_use_pfor=False)

    return (tf.linalg.slogdet(jacobian).log_abs_determinant +
            input_to_unconstrained.forward_log_det_jacobian(
                x, event_ndims=event_ndims) -
            output_to_unconstrained.forward_log_det_jacobian(
                f_x, event_ndims=event_ndims))
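A minimal usage sketch (an addition, not part of the snippet above; it assumes the helper is importable as `get_fldj_theoretical` alongside `tensorflow_probability`): the numerical estimate can be checked against a bijector's analytic `forward_log_det_jacobian`.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# A full-rank linear map on R^2; its analytic FLDJ is log|det(scale_tril)|.
bij = tfb.ScaleMatvecTriL(scale_tril=[[2., 0.], [1., 3.]])
x = tf.constant([[0.5, -1.0]])  # A single batch dimension; event_ndims=1.

fldj_numerical = get_fldj_theoretical(bij, x, event_ndims=1)
fldj_analytic = bij.forward_log_det_jacobian(x, event_ndims=1)
np.testing.assert_allclose(fldj_numerical, fldj_analytic, rtol=1e-5)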
Example #2
  def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
    """Builds the Expand bijector, which trades channels for spatial extent.

    The forward pass maps events of shape [h, w, c] to
    [h * n, w * n, c // n**2], where n is `block_size`.
    """
    parameters = dict(locals())
    self._block_size = block_size
    _, h, w, c = prefer_static.split(input_shape, [-1, 1, 1, 1])
    h, w, c = h[0], w[0], c[0]
    n = self._block_size
    # The composition applies the bijectors below in reverse (last-to-first)
    # order on the forward pass: reshape to [h, w, c // n**2, n, n], move the
    # two block axes next to height and width, then merge them spatially.
    b = [
        reshape.Reshape(
            event_shape_out=[h * n, w * n, c // n**2],
            event_shape_in=[h, n, w, n, c // n**2]),
        transpose.Transpose(perm=[0, 3, 1, 4, 2]),
        reshape.Reshape(
            event_shape_in=[h, w, c],
            event_shape_out=[h, w, c // n**2, n, n]),
    ]
    super(Expand, self).__init__(b, name=name or 'Expand',
                                 parameters=parameters)
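For context, the same channels-to-space rearrangement can be sketched with the public `tfp.bijectors` API (a hedged illustration; the shapes below are made up). `Chain` applies its list right-to-left on the forward pass, exactly like the snippet above.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

h, w, c, n = 4, 4, 12, 2
expand = tfb.Chain([
    tfb.Reshape(event_shape_out=[h * n, w * n, c // n**2],
                event_shape_in=[h, n, w, n, c // n**2]),
    tfb.Transpose(perm=[0, 3, 1, 4, 2]),
    tfb.Reshape(event_shape_in=[h, w, c],
                event_shape_out=[h, w, c // n**2, n, n]),
])
x = tf.zeros([h, w, c])
print(expand.forward(x).shape)  # (8, 8, 3): channels traded for spatial size.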
Example #3
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m. See [1] for details.)

  Args:
    jd_model: A `tfd.JointDistribution` instance (or subclass) describing the
      model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    to_chain = [jd_model.experimental_default_event_space_bijector()]
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
    to_chain.append(flat_bijector)

    unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # This reshaping is required because Split can produce a tensor of shape
    # [1] when the distribution event shape is [].
    reshapers = [
        reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
        for x in unconstrained_shapes
    ]
    to_chain.append(joint_map.JointMap(bijectors=reshapers))

    size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
    to_chain.append(split.Split(num_or_size_splits=size_splits))

    return invert.Invert(chain.Chain(to_chain))
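A usage sketch (an assumption that the helper above is importable under this name): flatten and unconstrain a sample from a tiny joint model into a single rank-1 vector.

import tensorflow_probability as tfp

tfd = tfp.distributions

jd = tfd.JointDistributionSequential([
    tfd.Gamma(2., 2.),                    # Constrained to the positive reals.
    lambda scale: tfd.Normal(0., scale),  # Unconstrained scalar.
])
bij = _get_flat_unconstraining_bijector(jd)
z = bij.forward(jd.sample(seed=42))  # One flat, unconstrained vector of size 2.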
Example #4
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m. See [1] for details.)

  Args:
    jd_model: A `tfd.JointDistribution` instance (or subclass) describing the
      model.

  Returns:
    Two `tfb.Bijector`s: the first's `.forward` method flattens and
    unconstrains points, and the second may be used to initialize a step size.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    event_space_bij = jd_model.experimental_default_event_space_bijector()
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    unconstrained_shapes = event_space_bij(
        flat_bijector).inverse_event_shape_tensor(
            jd_model.event_shape_tensor())

    # This reshaping is required because Split can produce a tensor of shape
    # [1] when the distribution event shape is [].
    unsplit = joint_map.JointMap(
        tf.nest.map_structure(
            lambda x: reshape.Reshape(event_shape_out=x, event_shape_in=[-1]),
            unconstrained_shapes))

    bij = invert.Invert(chain.Chain([event_space_bij, flat_bijector, unsplit]))
    step_size_bij = invert.Invert(flat_bijector)

    return bij, step_size_bij
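A hedged sketch of the two-return variant (again assuming the helper is importable under this name): the first bijector flattens and unconstrains model samples, while the second only repacks parts without unconstraining, which is what makes it suitable for initializing per-part step sizes.

import tensorflow_probability as tfp

tfd = tfp.distributions

jd = tfd.JointDistributionSequential(
    [tfd.Gamma(2., 2.), lambda scale: tfd.Normal(0., scale)])
bij, step_size_bij = _get_flat_unconstraining_bijector(jd)
z = bij.forward(jd.sample(seed=7))  # One flat, unconstrained vector.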
Example #5
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor`
       outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults
       to ones with the same shape as state_parts.
    shard_axis_names: A structure of string names indicating how members of the
      state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of a sample has rank equal to the rank of `batch_shape`.
  """
    if running_variance_parts is None:
        running_variance_parts = tf.nest.map_structure(tf.ones_like,
                                                       state_parts)
    distributions = []
    batch_ndims = ps.rank_from_shape(batch_shape)
    use_sharded_jd = True
    if shard_axis_names is None:
        use_sharded_jd = False
        shard_axis_names = [None] * len(state_parts)
    for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                     state_parts,
                                                     shard_axis_names):
        event_shape = state_part.shape[batch_ndims:]
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = ps.shape(state_part,
                                   name='state_part_shp')[batch_ndims:]
        variance_tiled = tf.broadcast_to(
            variance_part, ps.concat([batch_shape, event_shape], axis=0))
        nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
        variance_flattened = tf.reshape(
            variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

        distribution = _CompositeTransformedDistribution(
            bijector=reshape.Reshape(event_shape_out=event_shape,
                                     name='reshape_mvnpfl'),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=tf.linalg.LinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=tf.linalg.LinearOperatorDiag(variance_flattened),
                    name='momentum')))
        if shard_axes:
            distribution = sharded.Sharded(distribution,
                                           shard_axis_name=shard_axes)
        distributions.append(distribution)
    if use_sharded_jd:
        jd = _CompositeShardedJointDistributionSequential(distributions)
    else:
        jd = _CompositeJointDistributionSequential(distributions)
    return maybe_make_list_and_batch_broadcast(jd, batch_shape)
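A usage sketch (assuming the helper above is importable): build a momentum distribution whose samples mirror a two-part chain state.

import tensorflow as tf

state_parts = [tf.zeros([8, 3]), tf.zeros([8])]  # Two parts, batch_shape [8].
momentum = make_momentum_distribution(state_parts, batch_shape=[8])
m = momentum.sample(seed=1)    # Same list structure as `state_parts`.
lp = momentum.log_prob(m)      # Shape [8]: one log-density per batch member.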
Example #6
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         inverse_event_ndims=None,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
    """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain
  output_to_unconst_vec(bijector(inverse(input_to_unconst_vec))) so that
  we're working with a full rank matrix.  We then adjust the resulting Jacobian
  for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial.  However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate
    x: the value for which we want to approximate the Jacobian. Must have rank
      at least `event_ndims`.
    event_ndims: number of dimensions in an event
    inverse_event_ndims: Integer describing the number of event dimensions for
      the bijector codomain. If None, then the value of `event_ndims` is used.
    input_to_unconstrained: bijector that maps the input to the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the input into
      a 1-D vector according to its event_ndims.
    output_to_unconstrained: bijector that maps the output of the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the output into
      a 1-D vector according to its event_ndims.

  Returns:
    fldj: A gradient-based evaluation of the log det Jacobian of
      `bijector.forward` at `x`.
  """
    if inverse_event_ndims is None:
        inverse_event_ndims = event_ndims
    if input_to_unconstrained is None:
        input_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])
    if output_to_unconstrained is None:
        output_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])

    x = tf.convert_to_tensor(x)
    x_unconstrained = 1 * input_to_unconstrained.forward(x)
    # Collapse any batch dimensions (including scalar) to a single axis.
    batch_shape = x_unconstrained.shape[:-1]
    x_unconstrained = tf.reshape(
        x_unconstrained,
        [int(np.prod(batch_shape)), x_unconstrained.shape[-1]])

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_unconstrained)
        # Unflatten any batch dimensions now under the tape.
        unflattened_x_unconstrained = tf.reshape(
            x_unconstrained,
            tensorshape_util.concatenate(batch_shape,
                                         x_unconstrained.shape[-1:]))
        f_x = bijector.forward(
            input_to_unconstrained.inverse(unflattened_x_unconstrained))
        f_x_unconstrained = output_to_unconstrained.forward(f_x)
        # Flatten any batch dimensions to a single axis.
        f_x_unconstrained = tf.reshape(
            f_x_unconstrained,
            [int(np.prod(batch_shape)), f_x_unconstrained.shape[-1]])
    try:
        jacobian = tape.batch_jacobian(f_x_unconstrained, x_unconstrained)
    except ValueError:  # Fallback to for-loop jacobian.
        jacobian = tape.batch_jacobian(f_x_unconstrained,
                                       x_unconstrained,
                                       experimental_use_pfor=False)
    jacobian = tf.reshape(
        jacobian, tensorshape_util.concatenate(batch_shape,
                                               jacobian.shape[-2:]))
    logging.vlog(1, 'Jacobian: %s', jacobian)

    log_det_jacobian = 0.5 * tf.linalg.slogdet(
        tf.matmul(jacobian, jacobian, adjoint_a=True)).log_abs_determinant

    input_correction = input_to_unconstrained.forward_log_det_jacobian(
        x, event_ndims=event_ndims)
    output_correction = output_to_unconstrained.forward_log_det_jacobian(
        f_x, event_ndims=inverse_event_ndims)
    return log_det_jacobian + input_correction - output_correction
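A minimal check sketch (assuming the helper above is importable): this extended version also collapses arbitrary leading batch dimensions, so `x` is no longer restricted to a single batch axis.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

bij = tfb.Scale(scale=[2., 3., 4.])
x = tf.random.normal([2, 5, 3], seed=0)  # Batch shape [2, 5], event shape [3].

fldj_numerical = get_fldj_theoretical(bij, x, event_ndims=1)
fldj_analytic = bij.forward_log_det_jacobian(x, event_ndims=1)
np.testing.assert_allclose(fldj_numerical, fldj_analytic, rtol=1e-5)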
Example #7
  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image).  Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. This must divide equally into both H and W,
        otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images.  Your shape is '
                       '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known '
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height should be divisible by 2 at least n times.')
    if h // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h,
                                       int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w,
                                       int(np.log(w) / np.log(2.))))

    # Other things we want to be true:
    # - The length of grab_after_block must equal the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the number '
                       'of blocks ({1}).'.format(len(grab_after_block),
                                                 num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits
    if any([bs[0] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[0] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one exit takes out all of your channels, and '
                       'therefore leaves no inputs to later blocks. Try '
                       'setting grab_after_block to a lower value at index '
                       '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
      # Special case: if specifically exiting no channels, then the exit is
      # just an identity bijector.
      pass
    elif any([bs[1] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[1] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_after_block too small. '
                       'Try setting grab_after_block to a larger value at '
                       'index {}.'.format(first_offender))

    # Let's start building our bijector. We assume the base distribution is
    # 1-dimensional. First, let's reshape it into an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):

      # This is the shape of the current tensor
      current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(self._blockwise_splits[i][0:2])
      this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   self._blockwise_splits[i],
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if self._blockwise_splits[i][2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [sum(self._blockwise_splits[i][0:2]),
                 self._blockwise_splits[i][2]]))

      # Finally, let's expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              h // 2**n * 2**i,
              w // 2**n * 2**i,
              c * 4**n // 4**i,
          ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
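For reference, a small usage sketch with the public API that this constructor underlies (`tfp.bijectors.Glow`; the default network classes named below come from the library, and the small step count is chosen only to keep the sketch cheap): map a flat latent vector to an image-shaped tensor.

import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

glow = tfb.Glow(output_shape=(32, 32, 3),
                num_glow_blocks=3,
                num_steps_per_block=1,
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)
z = tfd.Sample(tfd.Normal(0., 1.), sample_shape=32 * 32 * 3).sample(seed=1)
image = glow.forward(z)  # Event shape (32, 32, 3).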