def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
  """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain

    output_to_unconst_vec(bijector(inverse(input_to_unconst_vec)))

  so that we're working with a full-rank matrix. We then adjust the resulting
  Jacobian for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial. However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate.
    x: the value for which we want to approximate the Jacobian. `x` must have
      a single batch dimension for compatibility with `tape.batch_jacobian`.
    event_ndims: number of dimensions in an event.
    input_to_unconstrained: bijector that maps the input of the above bijector
      to an unconstrained 1-D vector. If unspecified, flatten the input into a
      1-D vector according to its `event_ndims`.
    output_to_unconstrained: bijector that maps the output of the above
      bijector to an unconstrained 1-D vector. If unspecified, flatten the
      output into a 1-D vector according to the input's event shape.

  Returns:
    A numerical approximation to the log det Jacobian of `bijector.forward`
    evaluated at `x`.
  """
  if input_to_unconstrained is None:
    input_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=x.shape[tensorshape_util.rank(x.shape) - event_ndims:],
        event_shape_out=[-1])
  if output_to_unconstrained is None:
    output_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=x.shape[tensorshape_util.rank(x.shape) - event_ndims:],
        event_shape_out=[-1])

  x = tf.convert_to_tensor(value=x)
  x_unconstrained = 1 * input_to_unconstrained.forward(x)

  with tf.GradientTape(persistent=True) as tape:
    tape.watch(x_unconstrained)
    f_x = bijector.forward(input_to_unconstrained.inverse(x_unconstrained))
    f_x_unconstrained = output_to_unconstrained.forward(f_x)
  jacobian = tape.batch_jacobian(
      f_x_unconstrained, x_unconstrained, experimental_use_pfor=False)

  return (tf.linalg.slogdet(jacobian).log_abs_determinant +
          input_to_unconstrained.forward_log_det_jacobian(
              x, event_ndims=event_ndims) -
          output_to_unconstrained.forward_log_det_jacobian(
              f_x, event_ndims=event_ndims))
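
# A minimal usage sketch (not from the source) of checking a bijector's
# analytic forward_log_det_jacobian against the numerical approximation above,
# assuming `get_fldj_theoretical` and TensorFlow Probability are importable.
import tensorflow as tf
import tensorflow_probability as tfp

bijector = tfp.bijectors.Exp()
x = tf.constant([[0.5, 1.0, 2.0],
                 [0.1, 0.2, 0.3]])  # a single batch dimension, event_ndims=1
fldj_numerical = get_fldj_theoretical(bijector, x, event_ndims=1)
fldj_analytic = bijector.forward_log_det_jacobian(x, event_ndims=1)
# The two should agree up to numerical error, e.g. via
# np.testing.assert_allclose(fldj_numerical, fldj_analytic, rtol=1e-5).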
def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
  parameters = dict(locals())
  self._block_size = block_size
  _, h, w, c = prefer_static.split(input_shape, [-1, 1, 1, 1])
  h, w, c = h[0], w[0], c[0]
  n = self._block_size
  b = [
      reshape.Reshape(
          event_shape_out=[h * n, w * n, c // n**2],
          event_shape_in=[h, n, w, n, c // n**2]),
      transpose.Transpose(perm=[0, 3, 1, 4, 2]),
      reshape.Reshape(
          event_shape_in=[h, w, c],
          event_shape_out=[h, w, c // n**2, n, n]),
  ]
  super(Expand, self).__init__(b, name=name or 'Expand', parameters=parameters)
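
# A minimal sketch (not from the source) of the shape bookkeeping above,
# assuming `Expand` is in scope. With block_size=2, forward trades channels
# for spatial resolution; since a Chain applies its bijectors in reverse
# order, forward runs
#   [4, 4, 16] -> [4, 4, 4, 2, 2] -> [4, 2, 4, 2, 4] -> [8, 8, 4].
import tensorflow as tf

expand = Expand(input_shape=[4, 4, 16], block_size=2)
x = tf.random.normal([1, 4, 4, 16])  # one batch element
y = expand.forward(x)
print(y.shape)  # (1, 8, 8, 4)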
def _get_flat_unconstraining_bijector(jd_model):
  """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

    U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then have
  support on R^m. See [1] for details.)

  Args:
    jd_model: Subclass of `tfd.JointDistribution`; a joint distribution for a
      model.

  Returns:
    A `tfb.Bijector` whose `.forward` method flattens and unconstrains points.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  to_chain = [jd_model.experimental_default_event_space_bijector()]
  flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
  to_chain.append(flat_bijector)

  unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
      jd_model.event_shape_tensor())

  # This reshaping is required because split can produce a tensor of shape [1]
  # when the distribution event shape is [].
  reshapers = [
      reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
      for x in unconstrained_shapes
  ]
  to_chain.append(joint_map.JointMap(bijectors=reshapers))

  size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
  to_chain.append(split.Split(num_or_size_splits=size_splits))

  return invert.Invert(chain.Chain(to_chain))
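
# A minimal usage sketch (not from the source) of the returned bijector,
# assuming TensorFlow Probability is importable and this helper is in scope.
import tensorflow_probability as tfp
tfd = tfp.distributions

jd_model = tfd.JointDistributionSequential([
    tfd.HalfNormal(1.),                              # supported on R+
    lambda scale: tfd.Normal(loc=0., scale=scale),   # supported on R
])
bij = _get_flat_unconstraining_bijector(jd_model)
sample = jd_model.sample(seed=42)
flat = bij.forward(sample)         # a single unconstrained vector of length 2
reconstructed = bij.inverse(flat)  # back to the original [scale, loc] structure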
def _get_flat_unconstraining_bijector(jd_model):
  """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

    U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then have
  support on R^m. See [1] for details.)

  Args:
    jd_model: Subclass of `tfd.JointDistribution`; a joint distribution for a
      model.

  Returns:
    Two `tfb.Bijector`s: the first's `.forward` method flattens and
    unconstrains points, and the second may be used to initialize a step size.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  event_space_bij = jd_model.experimental_default_event_space_bijector()
  flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

  unconstrained_shapes = event_space_bij(
      flat_bijector).inverse_event_shape_tensor(
          jd_model.event_shape_tensor())

  # This reshaping is required because split can produce a tensor of shape [1]
  # when the distribution event shape is [].
  unsplit = joint_map.JointMap(
      tf.nest.map_structure(
          lambda x: reshape.Reshape(event_shape_out=x, event_shape_in=[-1]),
          unconstrained_shapes))

  bij = invert.Invert(chain.Chain([event_space_bij, flat_bijector, unsplit]))
  step_size_bij = invert.Invert(flat_bijector)

  return bij, step_size_bij
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
  """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor` outputs of
      `tfp.experimental.stats.RunningVariance.variance()`. Defaults to ones
      with the same shape as `state_parts`.
    shard_axis_names: A structure of string names indicating how members of
      the state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as
    `state_parts`, and `.log_prob` of the sample will have the rank of
    `batch_ndims`.
  """
  if running_variance_parts is None:
    running_variance_parts = tf.nest.map_structure(tf.ones_like, state_parts)
  distributions = []
  batch_ndims = ps.rank_from_shape(batch_shape)
  use_sharded_jd = True
  if shard_axis_names is None:
    use_sharded_jd = False
    shard_axis_names = [None] * len(state_parts)
  for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                   state_parts,
                                                   shard_axis_names):
    event_shape = state_part.shape[batch_ndims:]
    if not tensorshape_util.is_fully_defined(event_shape):
      event_shape = ps.shape(state_part, name='state_part_shp')[batch_ndims:]
    variance_tiled = tf.broadcast_to(
        variance_part, ps.concat([batch_shape, event_shape], axis=0))
    nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
    variance_flattened = tf.reshape(
        variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

    distribution = _CompositeTransformedDistribution(
        bijector=reshape.Reshape(
            event_shape_out=event_shape, name='reshape_mvnpfl'),
        distribution=(
            _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                precision_factor=tf.linalg.LinearOperatorDiag(
                    tf.math.sqrt(variance_flattened)),
                precision=tf.linalg.LinearOperatorDiag(variance_flattened),
                name='momentum')))
    if shard_axes:
      distribution = sharded.Sharded(distribution, shard_axis_name=shard_axes)
    distributions.append(distribution)
  if use_sharded_jd:
    jd = _CompositeShardedJointDistributionSequential(distributions)
  else:
    jd = _CompositeJointDistributionSequential(distributions)
  return maybe_make_list_and_batch_broadcast(jd, batch_shape)
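
# A minimal sketch (not from the source) of building a preconditioning
# momentum distribution for a two-part state with batch shape [4], assuming
# this helper and TensorFlow are in scope.
import tensorflow as tf

state_parts = [tf.zeros([4, 3]), tf.zeros([4, 2, 2])]
momentum_dist = make_momentum_distribution(state_parts, batch_shape=[4])
momentum = momentum_dist.sample(seed=42)
# `momentum` mirrors `state_parts`: a list of tensors with shapes [4, 3] and
# [4, 2, 2]; `momentum_dist.log_prob(momentum)` has shape [4].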
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         inverse_event_ndims=None,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
  """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain

    output_to_unconst_vec(bijector(inverse(input_to_unconst_vec)))

  so that we're working with a full-rank matrix. We then adjust the resulting
  Jacobian for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial. However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate.
    x: the value for which we want to approximate the Jacobian. Must have rank
      at least `event_ndims`.
    event_ndims: number of dimensions in an event.
    inverse_event_ndims: Integer describing the number of event dimensions for
      the bijector codomain. If `None`, the value of `event_ndims` is used.
    input_to_unconstrained: bijector that maps the input of the above bijector
      to an unconstrained 1-D vector. If unspecified, flatten the input into a
      1-D vector according to its `event_ndims`.
    output_to_unconstrained: bijector that maps the output of the above
      bijector to an unconstrained 1-D vector. If unspecified, flatten the
      output into a 1-D vector according to the input's event shape.

  Returns:
    fldj: A gradient-based evaluation of the log det Jacobian of
      `bijector.forward` at `x`.
  """
  if inverse_event_ndims is None:
    inverse_event_ndims = event_ndims
  if input_to_unconstrained is None:
    input_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=x.shape[tensorshape_util.rank(x.shape) - event_ndims:],
        event_shape_out=[-1])
  if output_to_unconstrained is None:
    output_to_unconstrained = reshape_bijector.Reshape(
        event_shape_in=x.shape[tensorshape_util.rank(x.shape) - event_ndims:],
        event_shape_out=[-1])

  x = tf.convert_to_tensor(x)
  x_unconstrained = 1 * input_to_unconstrained.forward(x)
  # Collapse any batch dimensions (including scalar) to a single axis.
  batch_shape = x_unconstrained.shape[:-1]
  x_unconstrained = tf.reshape(
      x_unconstrained, [int(np.prod(batch_shape)), x_unconstrained.shape[-1]])

  with tf.GradientTape(persistent=True) as tape:
    tape.watch(x_unconstrained)
    # Unflatten any batch dimensions now under the tape.
    unflattened_x_unconstrained = tf.reshape(
        x_unconstrained,
        tensorshape_util.concatenate(batch_shape, x_unconstrained.shape[-1:]))
    f_x = bijector.forward(
        input_to_unconstrained.inverse(unflattened_x_unconstrained))
    f_x_unconstrained = output_to_unconstrained.forward(f_x)
    # Flatten any batch dimensions to a single axis.
    f_x_unconstrained = tf.reshape(
        f_x_unconstrained,
        [int(np.prod(batch_shape)), f_x_unconstrained.shape[-1]])
  try:
    jacobian = tape.batch_jacobian(f_x_unconstrained, x_unconstrained)
  except ValueError:
    # Fall back to the for-loop Jacobian.
    jacobian = tape.batch_jacobian(
        f_x_unconstrained, x_unconstrained, experimental_use_pfor=False)
  jacobian = tf.reshape(
      jacobian, tensorshape_util.concatenate(batch_shape, jacobian.shape[-2:]))
  logging.vlog(1, 'Jacobian: %s', jacobian)

  log_det_jacobian = 0.5 * tf.linalg.slogdet(
      tf.matmul(jacobian, jacobian, adjoint_a=True)).log_abs_determinant

  input_correction = input_to_unconstrained.forward_log_det_jacobian(
      x, event_ndims=event_ndims)
  output_correction = output_to_unconstrained.forward_log_det_jacobian(
      f_x, event_ndims=inverse_event_ndims)
  return log_det_jacobian + input_correction - output_correction
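
# A minimal sketch (not from the source) showing that this version also
# handles inputs with no batch dimensions, assuming `get_fldj_theoretical`
# and TensorFlow Probability are importable.
import tensorflow as tf
import tensorflow_probability as tfp

bijector = tfp.bijectors.Softplus()
x = tf.constant([0.5, 1.0, 2.0])  # rank 1 == event_ndims, so no batch dims
fldj_numerical = get_fldj_theoretical(bijector, x, event_ndims=1)
fldj_analytic = bijector.forward_log_det_jacobian(x, event_ndims=1)
# Both are scalars and should agree up to numerical error.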
def __init__(self,
             output_shape=(32, 32, 3),
             num_glow_blocks=3,
             num_steps_per_block=32,
             coupling_bijector_fn=None,
             exit_bijector_fn=None,
             grab_after_block=None,
             use_actnorm=True,
             seed=None,
             validate_args=False,
             name='glow'):
  """Creates the Glow bijector.

  Args:
    output_shape: A list of integers, specifying the event shape of the output
      of the bijector's forward pass (the image). Specified as [H, W, C].
      Default value: (32, 32, 3).
    num_glow_blocks: An integer, specifying how many downsampling levels to
      include in the model. Each level halves H and W, so 2**num_glow_blocks
      must divide both H and W evenly; otherwise the bijector would not be
      invertible. Default value: 3.
    num_steps_per_block: An integer specifying how many Affine Coupling and
      1x1 convolution layers to include at each level of the spatial
      hierarchy. Default value: 32 (i.e., the value used in the original Glow
      paper).
    coupling_bijector_fn: A function which takes the argument `input_shape`
      and returns a callable neural network (e.g., a keras.Sequential). The
      network should either return a tensor with the same event shape as
      `input_shape` (this will employ additive coupling), a tensor with the
      same height and width as `input_shape` but twice the number of channels
      (this will employ affine coupling), or a bijector which takes in a
      tensor with event shape `input_shape` and returns a tensor with shape
      `input_shape`.
    exit_bijector_fn: Similar to `coupling_bijector_fn`, `exit_bijector_fn` is
      a function which takes the arguments `input_shape` and `output_chan` and
      returns a callable neural network. The neural network it returns should
      take a tensor of shape `input_shape` as the input, and return one of
      three options: a tensor with `output_chan` channels, a tensor with
      `2 * output_chan` channels, or a bijector. Additional details can be
      found in the documentation for ExitBijector.
    grab_after_block: A tuple of floats, specifying what fraction of the
      remaining channels to remove following each glow block. Glow will take
      the integer floor of this number multiplied by the remaining number of
      channels. The default is half at each level of the spatial hierarchy.
      Default value: None (this will take out half of the channels after each
      block).
    use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
      initialization is used to initialize this layer.
      Default value: `True`.
    seed: A seed to control randomness in the 1x1 convolution initialization.
      Default value: `None` (i.e., non-reproducible sampling).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Default value: `False`.
    name: Python `str`, name given to ops managed by this object.
      Default value: `'glow'`.
  """
  # Make sure that the input shape is fully defined.
  if not tensorshape_util.is_fully_defined(output_shape):
    raise ValueError('Shape must be fully defined.')
  if tensorshape_util.rank(output_shape) != 3:
    raise ValueError('Shape ndims must be 3 for images. Your shape is '
                     '{}'.format(tensorshape_util.rank(output_shape)))

  num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
  if (num_glow_blocks_ is None or
      int(num_glow_blocks_) != num_glow_blocks_ or
      num_glow_blocks_ < 1):
    raise ValueError('Argument `num_glow_blocks` must be a statically known '
                     'positive `int` (saw: {}).'.format(num_glow_blocks))
  num_glow_blocks = int(num_glow_blocks_)

  output_shape = tensorshape_util.as_list(output_shape)
  h, w, c = output_shape
  n = num_glow_blocks
  nsteps = num_steps_per_block

  # Default Glow: Half of the channels are split off after each block,
  # and after the final block, no channels are split off.
  if grab_after_block is None:
    grab_after_block = tuple([0.5] * (n - 1) + [0.])

  # Thing we know must be true: h and w are evenly divisible by 2, n times.
  # Otherwise, the squeeze bijector will not work.
  if w % 2**n != 0:
    raise ValueError('Width must be divisible by 2 at least n times. '
                     'Saw: {} % {} != 0'.format(w, 2**n))
  if h % 2**n != 0:
    raise ValueError('Height must be divisible by 2 at least n times.')
  if h // 2**n < 1:
    raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                     '({1}) must be divisible by 2 no more than {2} '
                     'times.'.format(num_glow_blocks, h,
                                     int(np.log(h) / np.log(2.))))
  if w // 2**n < 1:
    raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                     '({1}) must be divisible by 2 no more than {2} '
                     'times.'.format(num_glow_blocks, w,
                                     int(np.log(w) / np.log(2.))))

  # Other things we want to be true:
  # - The number of times we grab channels must equal the number of glow
  #   blocks.
  if len(grab_after_block) != num_glow_blocks:
    raise ValueError('Length of grab_after_block ({0}) must match the number '
                     'of blocks ({1}).'.format(len(grab_after_block),
                                               num_glow_blocks))

  self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                      grab_after_block[::-1])

  # Now check on the values of blockwise splits
  if any([bs[0] < 1 for bs in self._blockwise_splits]):
    first_offender = [bs[0] < 1 for bs in self._blockwise_splits].index(True)
    raise ValueError('At least one exit takes out all of the remaining '
                     'channels, leaving no inputs to later blocks. Try '
                     'setting grab_after_block to a lower value at index '
                     '{}.'.format(first_offender))

  if any(np.isclose(gab, 0) for gab in grab_after_block):
    # Special case: if specifically exiting no channels, then the exit is
    # just an identity bijector.
    pass
  elif any([bs[1] < 1 for bs in self._blockwise_splits]):
    first_offender = [bs[1] < 1 for bs in self._blockwise_splits].index(True)
    raise ValueError('At least one of your layers has < 1 output channels. '
                     'This means you set grab_after_block too small. Try '
                     'setting grab_after_block to a larger value at index '
                     '{}.'.format(first_offender))

  # Let's start to build our bijector. We assume that the distribution is
  # 1-dimensional. First, let's reshape it to an image.
  glow_chain = [
      reshape.Reshape(
          event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
          event_shape_in=[h * w * c])
  ]

  seedstream = SeedStream(seed=seed, salt='random_beta')

  for i in range(n):
    # This is the shape of the current tensor
    current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

    # This is the shape of the input to both the glow block and exit bijector.
    this_nchan = sum(self._blockwise_splits[i][0:2])
    this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

    glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                 self._blockwise_splits[i],
                                                 exit_bijector_fn)))

    glow_block = GlowBlock(input_shape=this_input_shape,
                           num_steps=nsteps,
                           coupling_bijector_fn=coupling_bijector_fn,
                           use_actnorm=use_actnorm,
                           seedstream=seedstream)

    if self._blockwise_splits[i][2] == 0:
      # All channels are passed to the RealNVP
      glow_chain.append(glow_block)
    else:
      # Some channels are passed around the block.
      # This is done with the Blockwise bijector.
      glow_chain.append(
          blockwise.Blockwise(
              [glow_block, identity.Identity()],
              [sum(self._blockwise_splits[i][0:2]),
               self._blockwise_splits[i][2]]))

    # Finally, let's expand the channels into spatial features.
    glow_chain.append(
        Expand(input_shape=[
            h // 2**n * 2**i,
            w // 2**n * 2**i,
            c * 4**n // 4**i,
        ]))

  glow_chain = glow_chain[::-1]
  # To finish off, we initialize the bijector with the chain we've built.
  # This way, the rest of the model attributes are taken care of for us.
  super(Glow, self).__init__(
      bijectors=glow_chain,
      validate_args=validate_args,
      name=name)
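
# A minimal usage sketch (not from the source) of the finished bijector,
# assuming this is the Glow bijector exported as tfp.bijectors.Glow together
# with its default networks GlowDefaultNetwork and GlowDefaultExitNetwork.
# Glow maps a flat vector of length H*W*C to an image of shape [H, W, C].
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

glow = tfb.Glow(output_shape=(32, 32, 3),
                num_glow_blocks=3,
                num_steps_per_block=1,  # kept small for a quick sketch
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)
pz = tfd.Sample(tfd.Normal(0., 1.), sample_shape=32 * 32 * 3)
images = glow.forward(pz.sample(2))  # shape: [2, 32, 32, 3]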