def trace(
    state: State,
    fn: TransitionOperator,
    num_steps: IntTensor,
    trace_fn: Callable[[State, TensorNest], TensorNest],
    parallel_iterations: int = 10,
) -> Tuple[State, TensorNest]:
  """`TransitionOperator` that runs `fn` repeatedly and traces its outputs.

  Args:
    state: A nest of `Tensor`s or None. If any leaf is `None` (or when
      executing eagerly) the first step is run outside the while loop so that
      its outputs determine the dtypes/shapes of the traces.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be at least 1:
      when `state` contains `None` or when executing eagerly, the first step
      is always run and written to index 0 of the traces.
    trace_fn: Callable that takes the unpacked outputs of `fn` (the new state
      and the extra outputs) and returns a nest of `Tensor`s. These will be
      stacked and returned.
    parallel_iterations: Number of iterations of the while loop to run in
      parallel.

  Returns:
    state: The final state returned by `fn`.
    traces: Stacked outputs of `trace_fn`, with leading dimension `num_steps`.
  """
  # Convert every non-None leaf to a `Tensor`; `None` leaves are kept as-is.
  state = tf.nest.map_structure(
      lambda t: t if t is None else tf.convert_to_tensor(t), state)

  def wrapper(state):
    # One step of `fn` followed by `trace_fn` on its outputs, with all outputs
    # converted to `Tensor`s.
    state, extra = tf.nest.map_structure(tf.convert_to_tensor,
                                         call_fn(fn, state))
    trace_element = tf.nest.map_structure(tf.convert_to_tensor,
                                          trace_fn(state, extra))
    return state, trace_element

  if any(e is None for e in tf.nest.flatten(state)) or tf.executing_eagerly():
    # Run the first step outside the loop: its trace supplies the dtypes and
    # shapes for the `TensorArray`s and is stored at index 0, so the loop
    # starts at index 1.
    state, first_trace = wrapper(state)
    trace_arrays = tf.nest.map_structure(
        lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
            v.dtype, size=num_steps, element_shape=v.shape).write(0, v),
        first_trace)
    start_idx = 1
  else:
    state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
    # We need the shapes and dtypes of the outputs of `wrapper` function to
    # create the `TensorArray`s, we can get it by pre-compiling the wrapper
    # function.
    wrapper = tf.function(autograph=False)(wrapper)
    concrete_wrapper = wrapper.get_concrete_function(state_spec)
    _, trace_dtypes = concrete_wrapper.output_dtypes
    _, trace_shapes = concrete_wrapper.output_shapes
    trace_arrays = tf.nest.map_structure(
        lambda dtype, shape: tf.TensorArray(  # pylint: disable=g-long-lambda
            dtype, size=num_steps, element_shape=shape),
        trace_dtypes, trace_shapes)
    # The concrete function is invoked with the flattened state.
    wrapper = lambda state: concrete_wrapper(*tf.nest.flatten(state))
    start_idx = 0

  def body(i, state, trace_arrays):
    state, trace_element = wrapper(state)
    trace_arrays = tf.nest.map_structure(lambda a, v: a.write(i, v),
                                         trace_arrays, trace_element)
    return i + 1, state, trace_arrays

  def cond(i, *_):
    return i < num_steps

  _, state, trace_arrays = tf.while_loop(
      cond=cond,
      body=body,
      loop_vars=(start_idx, state, trace_arrays),
      parallel_iterations=parallel_iterations)

  stacked_trace = tf.nest.map_structure(lambda x: x.stack(), trace_arrays)

  # When `num_steps` is statically known, record it as the leading dimension
  # of every stacked trace. If it is not, `static_length` is None and the
  # resulting shape carries no extra information.
  static_length = tf.get_static_value(num_steps)

  def _merge_static_length(x):
    x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
    return x

  stacked_trace = tf.nest.map_structure(_merge_static_length, stacked_trace)
  return state, stacked_trace
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_tree_depth=10,
             max_energy_diff=1000.,
             unrolled_leapfrog_steps=1,
             seed=None,
             name=None):
  """Initializes this transition kernel.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
      A non-nested value is wrapped in a one-element list, and every entry is
      converted to a `Tensor` with a `tf.float32` dtype hint.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
      maximum number of leapfrog steps is bounded by `2**max_tree_depth`, i.e.
      the number of nodes in a binary tree `max_tree_depth` nodes deep. Must
      be statically known and >= 1. The default setting of 10 takes up to
      1024 leapfrog steps.
    max_energy_diff: Scalar threshold on the energy difference at each
      leapfrog step; leapfrog steps that exceed it are treated as divergent.
      Default value: 1000.
    unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
      expansion step. Applies a direct linear multiplier to the maximum
      trajectory length implied by `max_tree_depth`. Default value: 1.
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'NoUTurnSampler').

  Raises:
    ValueError: if `max_tree_depth` is not statically known or is < 1.
  """
  with tf.name_scope(name or 'NoUTurnSampler') as scope_name:
    # The tree depth drives Python-side (numpy) precomputation, so it must be
    # a concrete value.
    static_depth = tf.get_static_value(max_tree_depth)
    if static_depth is None or static_depth < 1:
      raise ValueError(
          'max_tree_depth must be known statically and >= 1 but was '
          '{}'.format(static_depth))
    self._max_tree_depth = static_depth

    # Precompute the U-turn read/write instructions for a tree of this depth.
    # Only the numpy arrays are stored here; their TensorArray counterparts
    # have to be built inside the function call to stay XLA-compatible.
    uturn_instructions = build_tree_uturn_instruction(
        static_depth, init_memory=-1)
    write_np, read_np = generate_efficient_write_read_instruction(
        uturn_instructions)
    self._write_instruction = write_np
    self._read_instruction = read_np

    self._target_log_prob_fn = target_log_prob_fn

    step_size_parts = (
        step_size if tf.nest.is_nested(step_size) else [step_size])
    step_size_parts = [
        tf.convert_to_tensor(part, dtype_hint=tf.float32)
        for part in step_size_parts
    ]
    self._step_size = step_size_parts

    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size_parts,
        max_tree_depth=static_depth,
        max_energy_diff=max_energy_diff,
        unrolled_leapfrog_steps=unrolled_leapfrog_steps,
        seed=seed,
        name=scope_name,
    )
    self._seed_stream = SeedStream(seed, salt='nuts_one_step')
    self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
    self._name = scope_name
    self._max_energy_diff = max_energy_diff
def canonicalize_observed_time_series_with_mask(
    maybe_masked_observed_time_series):
  """Extract a Tensor with canonical shape and optional mask.

  Args:
    maybe_masked_observed_time_series: a `Tensor`-like object with shape
      `[..., num_timesteps]` or `[..., num_timesteps, 1]`, or a
      `tfp.sts.MaskedTimeSeries` containing such an object, or a Pandas
      Series or DataFrame instance with set frequency
      (i.e., `.index.freq is not None`). The columns of a DataFrame are
      treated as a batch of series.

  Returns:
    masked_time_series: a `tfp.sts.MaskedTimeSeries` namedtuple, in which
      the `observed_time_series` is converted to `Tensor` with canonical shape
      `[..., num_timesteps, 1]`, and `is_missing` is either `None` or a boolean
      `Tensor`. When no mask is supplied, `NaN` observations are treated as
      missing, and the mask is replaced by `None` if it is statically known to
      contain no missing values.

  Raises:
    ValueError: if a Pandas input has an index whose `freq` is `None`.
  """
  with tf.name_scope('canonicalize_observed_time_series_with_mask'):
    is_missing_is_specified = hasattr(maybe_masked_observed_time_series,
                                      'is_missing')
    if is_missing_is_specified:
      # Input is a MaskedTimeSeries.
      observed_time_series = (
          maybe_masked_observed_time_series.time_series)
      is_missing = maybe_masked_observed_time_series.is_missing
    elif (hasattr(maybe_masked_observed_time_series, 'index') and
          hasattr(maybe_masked_observed_time_series, 'to_numpy')):
      # Input is a Pandas Series or DataFrame.
      index = maybe_masked_observed_time_series.index
      if hasattr(index, 'freq') and index.freq is None:
        raise ValueError(
            'Pandas DataFrame or Series has a DatetimeIndex with '
            'no set frequency, but STS requires regularly spaced '
            'observations. Consider using '
            '`tfp.sts.regularize_series` to infer a frequency and '
            'build a regularly spaced series (by marking '
            'unobserved steps as missing observations).')
      # When a DataFrame has multiple columns representing a batch of series,
      # we want shape `[batch_size, num_steps]` rather than vice versa.
      values = np.transpose(maybe_masked_observed_time_series.to_numpy())
      # Drop only the column axis of a single-column DataFrame. A blanket
      # `np.squeeze` would also drop a length-1 time axis: a one-row,
      # multi-column DataFrame would then be read as one series whose time
      # axis is really the batch, and a length-1 Series would become a scalar.
      if np.ndim(values) == 2 and np.shape(values)[0] == 1:
        values = values[0]
      observed_time_series = values
    else:
      observed_time_series = maybe_masked_observed_time_series

    observed_time_series = tf.convert_to_tensor(
        value=observed_time_series, name='observed_time_series')
    observed_time_series = _maybe_expand_trailing_dim(observed_time_series)

    # Treat `NaN` values as missing.
    if not is_missing_is_specified:
      is_missing = tf.math.is_nan(observed_time_series[..., 0])
    # A mask that is statically known to be all-False carries no information.
    is_missing_static = tf.get_static_value(is_missing)
    if is_missing_static is not None and not np.any(is_missing_static):
      is_missing = None
    if is_missing is not None:
      is_missing = tf.convert_to_tensor(
          value=is_missing, name='is_missing', dtype_hint=tf.bool)

    return missing_values_util.MaskedTimeSeries(observed_time_series,
                                                is_missing=is_missing)
def __init__(self,
             perm=None,
             rightmost_transposed_ndims=None,
             validate_args=False,
             name='transpose'):
  """Instantiates the `Transpose` bijector.

  Args:
    perm: Positive `int32` vector-shaped `Tensor` representing permutation of
      rightmost dims (for forward transformation). Note that the `0`th index
      represents the first of the rightmost dims and the largest value must be
      `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value:
      `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`,
      i.e. the rightmost `rightmost_transposed_ndims` dims are reversed.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
      specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
      graph execution.
  """
  with tf.name_scope(name):
    if (rightmost_transposed_ndims is None) == (perm is None):
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')
    if rightmost_transposed_ndims is not None:
      rightmost_transposed_ndims = tf.convert_to_tensor(
          value=rightmost_transposed_ndims,
          dtype=np.int32,
          name='rightmost_transposed_ndims')
      # Static (Python/numpy) value, or None if only known at runtime.
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      assertions = _maybe_validate_rightmost_transposed_ndims(
          rightmost_transposed_ndims, validate_args)
      if assertions:
        with tf.control_dependencies(assertions):
          rightmost_transposed_ndims = tf.identity(
              rightmost_transposed_ndims)
      # Default permutation reverses the rightmost dims:
      # [ndims - 1, ndims - 2, ..., 0].
      perm_start = (
          distribution_util.prefer_static_value(rightmost_transposed_ndims) -
          1)
      perm = tf.range(start=perm_start, limit=-1, delta=-1, name='perm')
    else:  # perm is not None:
      perm = tf.convert_to_tensor(value=perm, dtype=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          input=perm, name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      assertions = _maybe_validate_perm(perm, validate_args)
      if assertions:
        with tf.control_dependencies(assertions):
          perm = tf.identity(perm)

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # following five lines can be removed.
    if rightmost_transposed_ndims_ is None:
      raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                'known prior to graph execution.')
    else:
      rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims

    super(Transpose, self).__init__(
        forward_min_event_ndims=rightmost_transposed_ndims_,
        graph_parents=[perm, rightmost_transposed_ndims],
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
def bootstrap_results(self, init_state):
  """Returns initial `SNAPERHamiltonianMonteCarloResults` for `init_state`.

  Args:
    init_state: Nest of `Tensor`-like initial chain states. The leading
      dimensions matching the rank of `target_log_prob_fn`'s output are
      treated as batch (chain) dimensions; the remaining ones are event
      dimensions.

  Returns:
    kernel_results: A `SNAPERHamiltonianMonteCarloResults` whose EMA mean is
      zero and EMA variance is one (both over the event shape), whose
      principal component is `_normalize` applied to the EMA variance, whose
      EMA point counters start at 1, and whose `inner_results` come from the
      kernel built by `self._make_kernel`.

  Raises:
    ValueError: if the rank of the target log probability is not statically
      known, or if the total number of chains is statically known to be less
      than 2.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'snaper_hamiltonian_monte_carlo',
                          'bootstrap_results')):
    init_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='init_state'), init_state)

    # It is unfortunate that we need to make this extra call to the TLP here.
    # The issue is that we need this value to even construct the PHMC, and
    # the kernel will call this one itself.
    tlp = self.target_log_prob_fn(*tf.nest.flatten(init_state))
    batch_shape = ps.shape(tlp)
    batch_ndims = ps.rank(tlp)
    if tf.get_static_value(batch_ndims) is None:
      # The issue doesn't live in this file, rather it is the downstream
      # components that fail to work (notably, tfb.Reshape).
      raise ValueError(
          'SNAPERHMC currently requires a statically known '
          'rank of the target log probability.')

    # We need at least two chains to estimate the principal component.
    # Number of total chains is local batch size * distributed axis size
    reduce_chain_axis_names = distribute_lib.canonicalize_named_axis(
        self.experimental_reduce_chain_axis_names)
    local_axis_size = ps.maximum(ps.size(tlp), 1)
    distributed_axis_size = int(
        ps.reduce_prod([
            distribute_lib.get_axis_size(a) for a in reduce_chain_axis_names
        ]))
    num_chains = local_axis_size * distributed_axis_size
    num_chains_ = tf.get_static_value(num_chains)
    if num_chains_ is not None:
      if num_chains_ < 2:
        raise ValueError(
            'SNAPERHMC requires at least 2 chains. Got: {}'.format(
                num_chains_))
    elif self.validate_args:
      # Chain count only known at runtime: attach the check to `init_state`.
      with tf.control_dependencies([
          assert_util.assert_greater_equal(
              num_chains, 2, 'SNAPERHMC requires at least 2 chains.')
      ]):
        init_state = tf.nest.map_structure(tf.identity, init_state)

    # Negative indices of the event (non-batch) axes of each state part.
    event_axes = tf.nest.map_structure(
        lambda x: ps.range(batch_ndims, ps.rank(x)) - ps.rank(x), init_state)
    if self.experimental_shard_axis_names is None:
      shard_axis_names = tf.nest.map_structure(lambda _: None, init_state)
    else:
      shard_axis_names = self.experimental_shard_axis_names

    # EMA statistics live on the event shape (batch dims are dropped).
    ema_variance = tf.nest.map_structure(
        lambda x: tf.ones(  # pylint: disable=g-long-lambda
            ps.shape(x)[batch_ndims:],
            dtype=x.dtype,
            name='ema_variance'),
        init_state)
    ema_mean = tf.nest.map_structure(
        lambda x: tf.zeros_like(x, name='ema_mean'), ema_variance)
    # NOTE(review): presumably `_normalize` rescales to unit norm over
    # `event_axes` (and shards) -- confirm against its definition.
    ema_principal_component = _normalize(ema_variance, event_axes,
                                         shard_axis_names)
    # These start out at 1 for a bit of smoothing.
    state_ema_points = tf.ones([], tf.int32)
    principal_component_ema_points = tf.ones([], tf.int32)

    kernel = self._make_kernel(
        batch_shape=batch_shape,
        step=tf.zeros([], tf.int32),
        state_ema_points=state_ema_points,
        state=init_state,
        mean=ema_mean,
        variance=ema_variance,
        principal_component=ema_principal_component,
    )

    inner_results = kernel.bootstrap_results(tf.nest.flatten(init_state))

    kernel_results = SNAPERHamiltonianMonteCarloResults(
        inner_results=inner_results,
        ema_mean=ema_mean,
        ema_variance=ema_variance,
        state_ema_points=state_ema_points,
        ema_principal_component=ema_principal_component,
        principal_component_ema_points=principal_component_ema_points,
        seed=samplers.zeros_seed(),
    )
    return kernel_results
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
  """Turns a (potentially nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
      shared across all members of the input structure. If this is specified,
      the returned joint distribution will be an autobatched distribution with
      the given batch rank, and all other dimensions absorbed into the event.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure. Mappings and namedtuples (anything
      with `_asdict`) become named joint distributions; other structures
      become sequential ones.

  Raises:
    TypeError: if any leaves of the input structure are not
      `tfd.Distribution` instances.
  """
  # `collections.Mapping` was removed in Python 3.10; the ABC lives in
  # `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  # If input is already a Distribution, just return it.
  if dist_util.is_distribution_instance(structure_of_distributions):
    dist = structure_of_distributions
    if batch_ndims is not None:
      excess_ndims = ps.rank_from_shape(
          dist.batch_shape_tensor()) - batch_ndims
      if tf.get_static_value(
          excess_ndims) != 0:  # Static value may be None.
        dist = independent.Independent(
            dist, reinterpreted_batch_ndims=excess_ndims)
    return dist

  # If this structure contains other structures (ie, has elements at depth > 1),
  # recursively turn them into JDs.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  if max(tf.nest.flatten(element_depths)) > 1:
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure,
        functools.partial(independent_joint_distribution_from_structure,
                          batch_ndims=batch_ndims,
                          validate_args=validate_args),
        structure_of_distributions)

  jdnamed = joint_distribution_named.JointDistributionNamed
  jdsequential = joint_distribution_sequential.JointDistributionSequential
  # Use an autobatched JD if a specific batch rank was requested.
  if batch_ndims is not None:
    jdnamed = functools.partial(
        joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
        batch_ndims=batch_ndims, use_vectorized_map=False)
    jdsequential = functools.partial(
        joint_distribution_auto_batched.JointDistributionSequentialAutoBatched,
        batch_ndims=batch_ndims, use_vectorized_map=False)

  # Otherwise, build a JD from the current structure.
  if (hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections_abc.Mapping)):
    return jdnamed(structure_of_distributions, validate_args=validate_args)
  return jdsequential(structure_of_distributions, validate_args=validate_args)
def __init__(self,
             distribution,
             reinterpreted_batch_ndims=None,
             validate_args=False,
             experimental_use_kahan_sum=False,
             name=None):
  """Construct an `Independent` distribution.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
      which will be regarded as event dims. When `None` all but the first
      batch axis (batch axis 0) will be transferred to event dimensions
      (analogous to `tf.layers.flatten`). If it is `None` and the batch rank
      of `distribution` is not statically known, both the tensor and the
      static value of `reinterpreted_batch_ndims` are left as `None` here.
    validate_args: Python `bool`. Whether to validate input with asserts.
      If `validate_args` is `False`, and the inputs are invalid,
      correct behavior is not guaranteed.
    experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
      summation to aggregate independent underlying log_prob values, which
      improves against the precision of a naive float32 sum. This can be
      noticeable in particular for large dimensions in float32. See CPU caveat
      on `tfp.math.reduce_kahan_sum`.
    name: The name for ops managed by the distribution.
      Default value: `Independent + distribution.name`.

  Raises:
    ValueError: if `reinterpreted_batch_ndims` exceeds
      `distribution.batch_ndims`. NOTE(review): this constructor performs no
      such check itself; presumably it happens during parameter validation
      elsewhere -- confirm.
  """
  # `locals()` must be captured before any other local name is bound, so that
  # `parameters` holds exactly the constructor arguments (and `self`).
  parameters = dict(locals())
  self._experimental_use_kahan_sum = experimental_use_kahan_sum
  with tf.name_scope(name or ('Independent' + distribution.name)) as name:
    self._distribution = distribution

    if reinterpreted_batch_ndims is None:
      # If possible, statically infer reinterpreted_batch_ndims.
      batch_ndims = tensorshape_util.rank(distribution.batch_shape)
      if batch_ndims is not None:
        # All batch dims except the leftmost one become event dims.
        self._static_reinterpreted_batch_ndims = max(0, batch_ndims - 1)
        self._reinterpreted_batch_ndims = tf.convert_to_tensor(
            self._static_reinterpreted_batch_ndims,
            dtype_hint=tf.int32,
            name='reinterpreted_batch_ndims')
      else:
        # Batch rank unknown statically; nothing can be inferred here.
        self._reinterpreted_batch_ndims = None
        self._static_reinterpreted_batch_ndims = None
    else:
      self._reinterpreted_batch_ndims = tensor_util.convert_nonref_to_tensor(
          reinterpreted_batch_ndims,
          dtype_hint=tf.int32,
          name='reinterpreted_batch_ndims')
      static_val = tf.get_static_value(self._reinterpreted_batch_ndims)
      self._static_reinterpreted_batch_ndims = (
          None if static_val is None else int(static_val))

    super(Independent, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        name=name)
def make_convolution_transpose_fn_with_dilation(filter_shape,
                                                strides,
                                                padding,
                                                rank=2,
                                                dilations=None,
                                                dtype=tf.int32,
                                                validate_args=False,
                                                name=None):
  """Builds a transposed 2-D convolution that applies batches of kernels.

  Like `tf.nn.conv2d_transpose` except applies batch of kernels to batch of
  `x`. This version tends to be fastest on GPU. It implements the transposed
  convolution as a regular convolution of an image that is dilated by
  inserting `stride - 1` rows/columns of zeros before each input row/column.

  Args:
    filter_shape: Spatial filter shape; after canonicalization by
      `prepare_conv_args` it unpacks as `(filter_height, filter_width)`.
    strides: Strides; after canonicalization by `prepare_conv_args` they
      unpack as `(stride_height, stride_width)` and set how many zero
      rows/columns are interleaved into the input.
    padding: Padding mode, forwarded to `prepare_conv_args` and
      `_get_transpose_conv_dilated_padding` (presumably 'SAME' or 'VALID' --
      confirm).
    rank: Number of spatial dimensions. Only a static value of `2` is
      supported.
    dilations: Optional dilation rates, canonicalized by `prepare_conv_args`.
      If any rate differs from 1, the `conv2d_transpose` fallback is not used.
    dtype: dtype of the gather indices produced by `im2row_index`.
      Default value: `tf.int32`.
    validate_args: Python `bool`; forwarded to `prepare_conv_args` and used to
      add runtime checks on the kernel shape.
    name: Name scope for the ops created while building the function.
      Default value: 'make_convolution_transpose_fn_with_dilation'.

  Returns:
    convolution_transpose_fn: A callable `op(x, kernel)` that takes an input
      `Tensor` `x` of shape `[..., height, width, channels_in]` and a kernel,
      and applies the transpose convolution operation.

  Raises:
    NotImplementedError: if `rank` is not statically `2`.
  """
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))

    [
        filter_shape,
        rank,
        strides,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape,
        rank=rank,
        strides=strides,
        padding=padding,
        dilations=dilations,
        is_transpose=True,
        validate_args=validate_args)

    sh, sw = strides
    fh, fw = filter_shape

    # One `[before, after]` padding pair per spatial dim (height, width) for
    # the zero-dilated image; they are later used as `tf.pad` paddings.
    pad_values = [
        _get_transpose_conv_dilated_padding(
            k, stride=s, dilation=d, padding=padding)
        for (k, s, d) in zip(filter_shape, strides, dilations)
    ]

    def op(x, kernel):
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      # `x` is `[..., height, width, channels_in]`.
      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x, kernel, filter_shape, strides, padding, dilations,
              kernel_shape[-1], batch_shape, event_shape)

        # Patch indices into the padded, zero-dilated image of size
        # `(xh * sh + pad_h, xw * sw + pad_w, c_in)`, with unit slice step.
        idx, shape = im2row_index(
            (xh * sh + sum(pad_values[0]),
             xw * sw + sum(pad_values[1]), c_in),
            block_shape=filter_shape,
            slice_step=(1, 1),
            dilations=dilations,
            dtype=dtype,
            transpose=True)

        # Extend the spatial padding with zero padding for the `n` batch dims
        # (before) and the channel dim (after) so it matches `x`'s rank.
        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            pad_values,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)

        # Interleave the rows and columns of the input with rows and columns of
        # zeros: `sw - 1` zero columns before each input column, then `sh - 1`
        # zero rows before each input row.
        x_half_dilated = tf.concat(
            [tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0),
                      dtype=input_dtype),
             tf.reshape(
                 x, shape=ps.concat([batch_shape, (xh * xw, 1, c_in)], axis=0))
            ], axis=-2)
        y = tf.reshape(
            x_half_dilated,
            shape=ps.concat([batch_shape, (xh, 1, xw * sw, c_in)], axis=0))

        x = tf.reshape(
            tf.concat(
                [tf.zeros(
                    ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                    dtype=input_dtype), y], axis=-3),
            shape=ps.concat([batch_shape, (xh * sh, xw * sw, c_in)], axis=0))
        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        # Flatten the spatial and channel dims into one axis, gather patches.
        flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape), indices=idx,
                           axis=-1)
        im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
        # NOTE(review): assumes `kernel` is `[..., fh * fw * c_in, c_out]` --
        # confirm against `_maybe_validate_input_shapes`.
        return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
    return op
def op(x, kernel):
  """Applies a transposed convolution of `x` with a (possibly batched) kernel.

  NOTE(review): this copy reads `fh`, `fw`, `sh`, `sw`, `filter_shape`,
  `strides`, `padding`, `dilations`, `pad_values`, `dtype` and
  `validate_args` from an enclosing scope that is not visible here --
  presumably the factory that builds it.

  Args:
    x: Input of shape `[..., height, width, channels_in]`.
    kernel: Convolution kernel; a rank-2 kernel with all dilations equal to 1
      is handled by `_call_conv2d_transpose`.

  Returns:
    The result of the transposed convolution.
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  batch_shape, event_shape = ps.split(
      ps.shape(x), num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)

  kernel_shape = ps.shape(kernel)
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):
    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2
        and all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(
          x, kernel, filter_shape, strides, padding, dilations,
          kernel_shape[-1], batch_shape, event_shape)

    # Patch indices into the padded, zero-dilated image.
    idx, shape = im2row_index(
        (xh * sh + sum(pad_values[0]),
         xw * sw + sum(pad_values[1]), c_in),
        block_shape=filter_shape,
        slice_step=(1, 1),
        dilations=dilations,
        dtype=dtype,
        transpose=True)

    # Zero padding for the `n` batch dims (before) and channel dim (after).
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(
        pad_values,
        paddings=[[n, 1], [0, 0]],
        constant_values=0)

    # Interleave the rows and columns of the input with rows and columns of
    # zeros: `sw - 1` zero columns before each input column, then `sh - 1`
    # zero rows before each input row.
    x_half_dilated = tf.concat(
        [tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0),
                  dtype=input_dtype),
         tf.reshape(
             x, shape=ps.concat([batch_shape, (xh * xw, 1, c_in)], axis=0))
        ], axis=-2)
    y = tf.reshape(
        x_half_dilated,
        shape=ps.concat([batch_shape, (xh, 1, xw * sw, c_in)], axis=0))

    x = tf.reshape(
        tf.concat(
            [tf.zeros(
                ps.concat(
                    [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                dtype=input_dtype), y], axis=-3),
        shape=ps.concat([batch_shape, (xh * sh, xw * sw, c_in)], axis=0))
    x_pad = tf.pad(x, paddings=paddings, constant_values=0)
    # Flatten spatial and channel dims into one axis, then gather patches.
    flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape), indices=idx,
                       axis=-1)
    im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
    return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def make_convolution_fn(filter_shape,
                        rank,
                        strides,
                        padding,
                        dilations=None,
                        dtype=tf.int32,
                        validate_args=False,
                        name=None):
  """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.

  Args:
    filter_shape: Spatial filter shape; after canonicalization by
      `prepare_conv_args` it unpacks as `(filter_height, filter_width)`.
    rank: Number of spatial dimensions. Only a static value of `2` is
      supported.
    strides: Convolution strides, canonicalized by `prepare_conv_args`.
    padding: Padding mode, canonicalized by `prepare_conv_args`; explicit
      padding of `x` is only applied when it equals 'SAME'.
    dilations: Optional dilation rates, canonicalized by `prepare_conv_args`.
    dtype: dtype of the gather indices produced by `im2row_index`.
      Default value: `tf.int32`.
    validate_args: Python `bool`; forwarded to `prepare_conv_args` and used to
      add runtime checks on the kernel shape.
    name: Name scope for the ops created while building the function.
      Default value: 'conv2d'.

  Returns:
    convolution_fn: A callable `op(x, kernel)` taking `x` of shape
      `[..., height, width, channels_in]` (any number of batch dims, including
      none) and a kernel, and returning the convolution with `x`'s batch
      shape preserved.

  Raises:
    NotImplementedError: if `rank` is not statically `2`.
  """
  with tf.name_scope(name or 'conv2d'):
    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))

    [
        filter_shape,
        rank,
        strides,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape,
        rank=rank,
        strides=strides,
        padding=padding,
        dilations=dilations,
        validate_args=validate_args)

    def op(x, kernel):
      """Convolves `x` (`[..., H, W, C_in]`) with `kernel`."""
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)
      fh, fw = filter_shape

      assertions = _maybe_validate_input_shapes(
          ps.shape(kernel),
          channels_in=c_in,
          filter_height=fh,
          filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        if tf.get_static_value(ps.rank(kernel)) == 2:
          # Unbatched kernel: fold every batch dim of `x` into one leading dim
          # so `tf.nn.conv2d` always receives a rank-4 NHWC input (this also
          # covers an `x` with no batch dims), then restore the batch shape.
          flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0))
          flat_y = tf.nn.conv2d(
              flat_x,
              filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]),
              strides=strides,
              padding=padding,
              data_format='NHWC',
              dilations=dilations)
          output_shape = ps.shape(flat_y)[-3:]
          return tf.reshape(
              flat_y, shape=ps.concat([batch_shape, output_shape], axis=0))

        # One `[before, after]` padding pair per spatial dim (height, width).
        pad_values = [
            _get_conv_padding(
                xdim, filter_dim=k, stride=s, dilation=d, padding=padding)
            for (xdim, k, s, d) in zip((xh, xw), filter_shape, strides,
                                       dilations)
        ]

        idx, shape = im2row_index(
            (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
            block_shape=filter_shape,
            slice_step=strides,
            dilations=dilations,
            dtype=dtype)

        if padding == 'SAME':
          # Zero padding for the `n` batch dims (before) and channel dim
          # (after) so the paddings match `x`'s rank.
          n = ps.maximum(0, ps.rank(x) - 3)
          paddings = ps.pad(
              pad_values, paddings=[[n, 1], [0, 0]], constant_values=0)
          x = tf.pad(x, paddings=paddings, constant_values=0)

        # Flatten spatial and channel dims into one axis, then gather patches.
        flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.gather(tf.reshape(x, shape=flat_shape), indices=idx, axis=-1)
        im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
        return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])

    return op
def op(x, kernel):
  """Convolves `x` (`[..., H, W, C_in]`) with a (possibly batched) kernel.

  NOTE(review): this copy reads `filter_shape`, `strides`, `padding`,
  `dilations`, `dtype` and `validate_args` from an enclosing scope that is not
  visible here -- presumably the factory that builds it.

  Args:
    x: Input of shape `[..., height, width, channels_in]`, with any number of
      batch dims (including none).
    kernel: Convolution kernel. A rank-2 kernel is reshaped to
      `[fh, fw, c_in, -1]` and applied with `tf.nn.conv2d`.

  Returns:
    The convolution output, with `x`'s batch shape preserved.
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  batch_shape, event_shape = ps.split(
      ps.shape(x), num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)
  fh, fw = filter_shape

  assertions = _maybe_validate_input_shapes(
      ps.shape(kernel),
      channels_in=c_in,
      filter_height=fh,
      filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):
    if tf.get_static_value(ps.rank(kernel)) == 2:
      # Unbatched kernel: fold every batch dim of `x` into one leading dim so
      # `tf.nn.conv2d` always receives a rank-4 NHWC input (this also covers
      # an `x` with no batch dims), then restore the batch shape.
      flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0))
      flat_y = tf.nn.conv2d(
          flat_x,
          filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]),
          strides=strides,
          padding=padding,
          data_format='NHWC',
          dilations=dilations)
      output_shape = ps.shape(flat_y)[-3:]
      return tf.reshape(
          flat_y, shape=ps.concat([batch_shape, output_shape], axis=0))

    # One `[before, after]` padding pair per spatial dim (height, width).
    pad_values = [
        _get_conv_padding(
            xdim, filter_dim=k, stride=s, dilation=d, padding=padding)
        for (xdim, k, s, d) in zip((xh, xw), filter_shape, strides, dilations)
    ]

    idx, shape = im2row_index(
        (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
        block_shape=filter_shape,
        slice_step=strides,
        dilations=dilations,
        dtype=dtype)

    if padding == 'SAME':
      # Zero padding for the `n` batch dims (before) and channel dim (after).
      n = ps.maximum(0, ps.rank(x) - 3)
      paddings = ps.pad(
          pad_values, paddings=[[n, 1], [0, 0]], constant_values=0)
      x = tf.pad(x, paddings=paddings, constant_values=0)

    # Flatten spatial and channel dims into one axis, then gather patches.
    flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.gather(tf.reshape(x, shape=flat_shape), indices=idx, axis=-1)
    im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
    return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def _batch_shape(self):
  """Batch shape of each component with the particle dimension prepended.

  The leading dimension is `tf.get_static_value(self.num_particles)`, which is
  `None` when the particle count is not statically known.
  """
  def _prepend_num_particles(component_batch_shape):
    leading_dims = [tf.get_static_value(self.num_particles)]
    return tensorshape_util.concatenate(leading_dims, component_batch_shape)

  return tf.nest.map_structure(_prepend_num_particles,
                               self.distribution.batch_shape)
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x: Floating-point or complex `Tensor` (e.g. `float32` or `complex64`).
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`. (Implemented by dividing by the
      lag-0 value of the estimate.)
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` has neither a floating-point nor a complex dtype.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = prefer_static.rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = distribution_util.rotate_transpose(x, shift)

    if center:
      x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = prefer_static.shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the smallest power of 2 that is >= 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = tf.cast(x_len, np.float64)
    target_length = tf.pow(
        np.float64(2.),
        tf.math.ceil(tf.math.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = tf.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = distribution_util.pad(
        x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype_util.is_complex(dtype):
      if not dtype_util.is_floating(dtype):
        raise TypeError(
            'Argument x must have either float or complex dtype'
            ' found: {}'.format(dtype))
      # The FFT ops need complex input: attach a zero imaginary part.
      x_rotated_pad = tf.complex(
          x_rotated_pad,
          dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not tensorshape_util.is_fully_defined(x_rotated.shape):
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        # Dynamic clip of max_lags to the available number of lags.
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        # Static clip: max_lags becomes a Python/NumPy integer.
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = tensorshape_util.as_list(x_rotated.shape)
      chopped_shape[-1] = min(x_len, max_lags + 1)
      shifted_product_chopped.set_shape(chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = tf.cast(x_len, dtype_util.real_dtype(dtype))
    max_lags = tf.cast(max_lags, dtype_util.real_dtype(dtype))
    denominator = x_len - tf.range(0., max_lags + 1.)
    denominator = tf.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide every lag by the lag-0 value (the variance estimate when
      # centered).
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
def testCopy(self):
  # Draw the inputs for two regression models. All index points live in R^2:
  # 5 and 10 query points, 7 and 9 observation points (one observation each).
  index_points_1 = np.random.uniform(-4., 4., (5, 2)).astype(np.float32)
  index_points_2 = np.random.uniform(-4., 4., (10, 2)).astype(np.float32)
  observation_index_points_1 = np.random.uniform(
      -4., 4., (7, 2)).astype(np.float32)
  observation_index_points_2 = np.random.uniform(
      -4., 4., (9, 2)).astype(np.float32)
  observations_1 = np.random.uniform(-1., 1., 7).astype(np.float32)
  observations_2 = np.random.uniform(-1., 1., 9).astype(np.float32)

  if not self.is_static:
    # Hide every static shape behind a shape-less placeholder.
    def _as_dynamic(value):
      return tf1.placeholder_with_default(value, shape=None)

    index_points_1 = _as_dynamic(index_points_1)
    index_points_2 = _as_dynamic(index_points_2)
    observation_index_points_1 = _as_dynamic(observation_index_points_1)
    observation_index_points_2 = _as_dynamic(observation_index_points_2)
    observations_1 = _as_dynamic(observations_1)
    observations_2 = _as_dynamic(observations_2)

  def mean_fn(x):  # pylint: disable=unused-argument
    return np.array([0.], np.float32)

  kernel_1 = psd_kernels.ExponentiatedQuadratic()
  kernel_2 = psd_kernels.ExpSinSquared()
  shared_kwargs = dict(mean_fn=mean_fn, jitter=1e-5, validate_args=True)

  original = tfd.GaussianProcessRegressionModel(
      kernel=kernel_1,
      index_points=index_points_1,
      observation_index_points=observation_index_points_1,
      observations=observations_1,
      **shared_kwargs)
  copied = original.copy(
      kernel=kernel_2,
      index_points=index_points_2,
      observation_index_points=observation_index_points_2,
      observations=observations_2)

  precomputed = tfd.GaussianProcessRegressionModel.precompute_regression_model(
      kernel=kernel_1,
      index_points=index_points_1,
      observation_index_points=observation_index_points_1,
      observations=observations_1,
      **shared_kwargs)
  precomputed_copy = precomputed.copy(index_points=index_points_2)

  # Copying a precomputed model must reuse (not rebuild) its mean fn / kernel.
  self.assertIs(precomputed.mean_fn, precomputed_copy.mean_fn)
  self.assertIs(precomputed.kernel, precomputed_copy.kernel)

  # The event size is the number of query index points.
  expected_event_shape_1 = [5]
  expected_event_shape_2 = [10]

  self.assertIsInstance(original.kernel.base_kernel,
                        psd_kernels.ExponentiatedQuadratic)
  self.assertIsInstance(copied.kernel.base_kernel, psd_kernels.ExpSinSquared)

  if not (self.is_static or tf.executing_eagerly()):
    self.assertAllEqual(self.evaluate(original.batch_shape_tensor()),
                        self.evaluate(copied.batch_shape_tensor()))
    self.assertAllEqual(self.evaluate(original.event_shape_tensor()),
                        expected_event_shape_1)
    self.assertAllEqual(self.evaluate(copied.event_shape_tensor()),
                        expected_event_shape_2)
    self.assertEqual(self.evaluate(original.jitter),
                     self.evaluate(copied.jitter))
    self.assertAllEqual(self.evaluate(original.index_points), index_points_1)
    self.assertAllEqual(self.evaluate(copied.index_points), index_points_2)
    return

  self.assertAllEqual(original.batch_shape, copied.batch_shape)
  self.assertAllEqual(original.event_shape, expected_event_shape_1)
  self.assertAllEqual(copied.event_shape, expected_event_shape_2)
  self.assertAllEqual(original.index_points, index_points_1)
  self.assertAllEqual(copied.index_points, index_points_2)
  self.assertAllEqual(tf.get_static_value(original.jitter),
                      tf.get_static_value(copied.jitter))
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Returns zeros with the shape of `input`, as NumPy when statically known.

  Args:
    input: Tensor-like value whose shape (via `_shape`) gives the result shape
      and whose `dtype` is the default result dtype.
    dtype: Optional dtype of the result. Defaults to `input.dtype`.
    name: Optional op name; only used when a `tf.zeros` op is created.

  Returns:
    An `np.ndarray` of zeros when the shape is statically known, otherwise a
    `tf.zeros` `Tensor` of that (dynamic) shape.
  """
  s = _shape(input)
  s_ = tf.get_static_value(s)
  if s_ is not None:
    # Hand NumPy the static value, never a (possibly symbolic) `Tensor`.
    return np.zeros(s_, _numpy_dtype(dtype or input.dtype))
  # The default dtype must come from `input`, not from the integer shape
  # tensor `s`; otherwise the result would be e.g. int32 zeros.
  return tf.zeros(s, dtype or input.dtype, name)
def make_convolution_transpose_fn_with_subkernels_matrix(
    filter_shape,
    strides,
    padding,
    rank=2,
    dilations=None,
    dtype=tf.int32,
    validate_args=False,
    name=None):
  """Like `tf.nn.conv2d_transpose` except applies batch of kernels to batch of `x`.

  An index map (`event_ind`) is built once so that each kernel can be scattered
  into a block matrix holding `strides**2` sub-kernels. The transposed
  convolution is then an im2row gather, a single `matmul` and (for
  `strides > 1`) a `tf.nn.depth_to_space`.

  Args:
    filter_shape: Filter shape, normalized by `prepare_conv_args`; unpacked as
      `(fh, fw)`.
    strides: Statically known integer stride, shared by both spatial dims. A
      Python `int`, NumPy integer scalar or constant scalar integer `Tensor`.
    padding: Padding mode; `'VALID'` and `'SAME'` determine the output size.
    rank: Must be (statically) `2`.
    dilations: Dilations, normalized by `prepare_conv_args`; unpacked as
      `(dh, dw)`.
    dtype: Integer dtype used for index computations.
    validate_args: Python `bool`; whether to add runtime shape assertions.
    name: Optional name scope.

  Returns:
    op: Callable `op(x, kernel)` computing the transposed convolution.

  Raises:
    NotImplementedError: If `rank` is not `2`.
    ValueError: If `strides` is not a statically known integer scalar.
  """
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))

    # `tf.get_static_value` returns NumPy values for constant `Tensor`s, so
    # accept any integral scalar, not only Python `int`s.
    strides_ = tf.get_static_value(strides)
    is_static_int = isinstance(strides_, int) or (
        strides_ is not None and np.ndim(strides_) == 0 and
        np.issubdtype(np.asarray(strides_).dtype, np.integer))
    if not is_static_int:
      raise ValueError(
          'Argument `strides` must be a statically known integer. '
          'Saw: {}'.format(strides_))
    strides = int(strides_)

    [
        filter_shape,
        rank,
        _,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape,
        rank=rank,
        strides=strides,
        padding=padding,
        dilations=dilations,
        is_transpose=True,
        validate_args=validate_args)

    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // strides + 1
    sub_fw = (fw - 1) // strides + 1

    def loop_body(i_, event_ind):
      # Sub-kernel `i_` collects the filter taps at row offset `i` and column
      # offset `j` (modulo `strides`).
      i = i_ // strides
      j = i_ % strides

      i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype)
      j_ind = ps.range(j, fw, delta=strides, dtype=dtype)

      nc = cartesian_add([i_ind, j_ind])
      ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

      k = ps.reshape(
          cartesian_add(
              [ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
               ps.range(ps.shape(nc)[1], dtype=dtype)]),
          shape=[-1])
      last_j = strides - (fw - j - 1) % strides - 1
      last_i = strides - (fh - i - 1) % strides - 1
      kernel_ind = ps.stack(
          [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
      event_ind = ps.tensor_scatter_nd_update(
          event_ind, ind[..., tf.newaxis], kernel_ind)

      return i_ + 1, event_ind

    # Row `r` of `event_ind` holds (position within sub-kernel, sub-kernel id)
    # for the `r`-th flattened filter tap.
    event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
    _, event_ind = tf.while_loop(
        lambda i, _: i < strides**2,
        loop_body,
        [tf.zeros([], dtype=dtype), event_ind])

    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=strides, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=strides, dilation=dw, padding=padding)

    # Padding in units of the (strided) input grid, rounded up; the surplus
    # is truncated from the output afterwards.
    pad_bottom = (tot_pad_bottom - 1) // strides + 1
    pad_top = (tot_pad_top - 1) // strides + 1
    pad_right = (tot_pad_right - 1) // strides + 1
    pad_left = (tot_pad_left - 1) // strides + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * strides - tot_pad_top
    truncate_left = pad_left * strides - tot_pad_left

    def op(x, kernel):
      """Applies the batched transposed convolution of `kernel` to `x`."""
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):

        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x,
              kernel=kernel,
              filter_shape=filter_shape,
              strides=(strides,) * rank,
              padding=padding,
              dilations=dilations,
              c_out=c_out,
              batch_shape=batch_shape,
              event_shape=event_shape)

        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)
        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        x_pad_shape = ps.shape(x_pad)[:-3]
        flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.reshape(x_pad, shape=flat_shape)

        idx, s = im2row_index(
            (xh + tf.reduce_sum(padding_vals[0]),
             xw + tf.reduce_sum(padding_vals[1]), c_in),
            block_shape=(sub_fh, sub_fw),
            slice_step=(1, 1),
            dilations=dilations)

        x_ = tf.gather(flat_x, indices=idx, axis=-1)
        im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

        # Add channels to subkernel indices
        idx_event = event_ind * [[c_in, 1]]
        idx_event_channels = (
            idx_event[tf.newaxis]
            + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)],
                       axis=-1)[:, tf.newaxis, :])
        idx_event = tf.squeeze(
            tf.batch_to_space(
                idx_event_channels, block_shape=[c_in], crops=[[0, 0]]),
            axis=0)
        idx_event_broadcast = tf.broadcast_to(
            idx_event,
            shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

        # Add cartesian product of batch indices, since scatter_nd can only be
        # applied to leading dimensions.
        idx_batch = tf.stack(
            tf.meshgrid(
                *[ps.range(b_, delta=1, dtype=dtype)
                  for b_ in tf.unstack(kernel_batch)], indexing='ij'),
            axis=ps.size(kernel_batch))

        idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

        idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
            (ps.shape(idx_event)[0], 1), dtype=dtype)
        idx_kernel = tf.concat(
            [idx_batch_broadcast, idx_event_broadcast], axis=-1)

        kernel_mat = tf.scatter_nd(
            idx_kernel,
            updates=kernel,
            shape=ps.cast(
                ps.concat([kernel_batch,
                           [sub_fh * sub_fw * c_in, strides**2, c_out]],
                          axis=0),
                dtype=dtype))

        kernel_mat = tf.reshape(
            kernel_mat,
            shape=ps.concat(
                [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0))

        kernel_mat = kernel_mat[..., tf.newaxis, :, :]
        out = tf.matmul(im_x, kernel_mat)
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)

        if strides > 1:
          # Interleave the `strides**2` sub-kernel outputs spatially.
          tot_size = tf.reduce_prod(broadcast_batch_shape)
          flat_out = tf.reshape(
              out,
              shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
          out = tf.nn.depth_to_space(flat_out, block_size=strides)

        if padding == 'VALID':
          out_height = fh + strides * (xh - 1)
          out_width = fw + strides * (xw - 1)
        elif padding == 'SAME':
          out_height = xh * strides
          out_width = xw * strides

        out = out[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
        out = tf.reshape(
            out,
            shape=ps.concat(
                [broadcast_batch_shape, [out_height, out_width, c_out]],
                axis=0))
        return out
    return op
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
  """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ..., p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

   ```python
   # alignment is:                       v axis
   # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
   # indices.shape   ==         [i[0], i[1], i[2]]
   #                                     ^ indices_axis
   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above. Number of
      dimensions must be known statically.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met, including when the
      rank of `params` or of `indices` is not known statically.
  """
  with tf.name_scope(name):
    params = tf.convert_to_tensor(params, name='params')
    indices = tf.convert_to_tensor(indices, name='indices')

    params_ndims = tensorshape_util.rank(params.shape)
    indices_ndims = tensorshape_util.rank(indices.shape)
    # `axis` dtype must match ndims, which are 64-bit Python ints.
    axis = tf.get_static_value(ps.convert_to_shape_tensor(axis, dtype=tf.int64))
    indices_axis = tf.get_static_value(
        ps.convert_to_shape_tensor(indices_axis, dtype=tf.int64))

    if params_ndims is None:
      raise ValueError(
          'Rank of `params`, must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if axis is None:
      raise ValueError(
          '`axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis is None:
      raise ValueError(
          '`indices_axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis > axis:
      raise ValueError(
          '`indices_axis` should be <= `axis`, but was {} > {}'.format(
              indices_axis, axis))

    if params_ndims < 1:
      raise ValueError(
          'Rank of params should be `> 0`, but was {}'.format(params_ndims))

    # The alignment arithmetic below (`broadcast_indices_ndims`) needs the
    # static rank of `indices`; fail clearly instead of with a `TypeError`.
    if indices_ndims is None:
      raise ValueError(
          'Rank of `indices` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_ndims < 1:
      raise ValueError(
          'Rank of indices should be `> 0`, but was {}'.format(indices_ndims))

    if indices_ndims - indices_axis > params_ndims - axis:
      raise ValueError(
          '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
          'indices_axis` ({} - {}), but was not.'.format(
              params_ndims, axis, indices_ndims, indices_axis))

    # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
    # transpose `indices_axis` to be the rightmost dimension of `indices`...
    transposed_indices = dist_util.move_dimension(indices,
                                                  source_idx=indices_axis,
                                                  dest_idx=-1)

    # ... and `axis` to be the corresponding (aligned as in the docstring)
    # dimension of `params`.
    broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
    transposed_params = dist_util.move_dimension(
        params,
        source_idx=axis,
        dest_idx=broadcast_indices_ndims - 1)

    # Next we broadcast `indices` so that its shape has the same prefix as
    # `params.shape`.
    transposed_params_shape = ps.shape(transposed_params)
    result_shape = ps.concat([
        transposed_params_shape[:broadcast_indices_ndims - 1],
        ps.shape(indices)[indices_axis:indices_axis + 1],
        transposed_params_shape[broadcast_indices_ndims:]], axis=0)
    broadcast_indices = ps.broadcast_to(
        transposed_indices,
        result_shape[:broadcast_indices_ndims])

    result_t = tf.gather(transposed_params,
                         broadcast_indices,
                         batch_dims=broadcast_indices_ndims - 1,
                         axis=broadcast_indices_ndims - 1)
    return dist_util.move_dimension(result_t,
                                    source_idx=broadcast_indices_ndims - 1,
                                    dest_idx=axis)
def op(x, kernel):
  """Applies a batched transposed 2-D convolution of `kernel` to `x`.

  Closes over `fh`, `fw`, `filter_shape`, `strides`, `rank`, `padding`,
  `dilations`, `dtype`, `padding_vals`, `event_ind`, `sub_fh`, `sub_fw`,
  `truncate_top`, `truncate_left` and `validate_args` from the enclosing scope.

  Args:
    x: `Tensor` whose rightmost three dims are `[height, width, c_in]`; any
      leading dims are batch dims.
    kernel: `Tensor` whose last dim is `c_out` and whose dims before the last
      two are kernel batch dims.

  Returns:
    `Tensor` of shape `broadcast(batch_shape, kernel_batch) +
    [out_height, out_width, c_out]` (or the result of `_call_conv2d_transpose`
    when the unbatched, undilated fallback is taken).
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  batch_shape, event_shape = ps.split(
      ps.shape(x), num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)

  kernel_shape = ps.shape(kernel)
  c_out = kernel_shape[-1]
  kernel_batch = kernel_shape[:-2]
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):

    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2
        and all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(
          x,
          kernel=kernel,
          filter_shape=filter_shape,
          strides=(strides,) * rank,
          padding=padding,
          dilations=dilations,
          c_out=c_out,
          batch_shape=batch_shape,
          event_shape=event_shape)

    # Zero-pad the two spatial dims by `padding_vals` (batch dims and
    # channels get no padding), then flatten each event into a vector.
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(
        padding_vals,
        paddings=[[n, 1], [0, 0]],
        constant_values=0)
    x_pad = tf.pad(x, paddings=paddings, constant_values=0)
    x_pad_shape = ps.shape(x_pad)[:-3]
    flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
    flat_x = tf.reshape(x_pad, shape=flat_shape)

    # im2row with the (smaller) sub-kernel block shape and unit stride.
    idx, s = im2row_index(
        (xh + tf.reduce_sum(padding_vals[0]),
         xw + tf.reduce_sum(padding_vals[1]), c_in),
        block_shape=(sub_fh, sub_fw),
        slice_step=(1, 1),
        dilations=dilations)

    x_ = tf.gather(flat_x, indices=idx, axis=-1)
    im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

    # Add channels to subkernel indices
    idx_event = event_ind * [[c_in, 1]]
    idx_event_channels = (
        idx_event[tf.newaxis]
        + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)],
                   axis=-1)[:, tf.newaxis, :])
    idx_event = tf.squeeze(
        tf.batch_to_space(
            idx_event_channels, block_shape=[c_in], crops=[[0, 0]]),
        axis=0)
    idx_event_broadcast = tf.broadcast_to(
        idx_event,
        shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

    # Add cartesian product of batch indices, since scatter_nd can only be
    # applied to leading dimensions.
    idx_batch = tf.stack(
        tf.meshgrid(
            *[ps.range(b_, delta=1, dtype=dtype)
              for b_ in tf.unstack(kernel_batch)], indexing='ij'),
        axis=ps.size(kernel_batch))

    idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

    idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
        (ps.shape(idx_event)[0], 1), dtype=dtype)
    idx_kernel = tf.concat(
        [idx_batch_broadcast, idx_event_broadcast], axis=-1)

    # Scatter each kernel into a block matrix of `strides**2` sub-kernels.
    kernel_mat = tf.scatter_nd(
        idx_kernel,
        updates=kernel,
        shape=ps.cast(
            ps.concat([kernel_batch,
                       [sub_fh * sub_fw * c_in, strides**2, c_out]],
                      axis=0),
            dtype=dtype))

    kernel_mat = tf.reshape(
        kernel_mat,
        shape=ps.concat(
            [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0))

    kernel_mat = kernel_mat[..., tf.newaxis, :, :]
    out = tf.matmul(im_x, kernel_mat)
    broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)

    if strides > 1:
      # Interleave the `strides**2` sub-kernel outputs spatially.
      tot_size = tf.reduce_prod(broadcast_batch_shape)
      flat_out = tf.reshape(
          out,
          shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
      out = tf.nn.depth_to_space(flat_out, block_size=strides)

    if padding == 'VALID':
      out_height = fh + strides * (xh - 1)
      out_width = fw + strides * (xw - 1)
    elif padding == 'SAME':
      out_height = xh * strides
      out_width = xw * strides

    # Drop the surplus rows/cols introduced by rounding the padding up.
    out = out[..., truncate_top:truncate_top + out_height,
              truncate_left:truncate_left + out_width, :]
    out = tf.reshape(
        out,
        shape=ps.concat(
            [broadcast_batch_shape, [out_height, out_width, c_out]],
            axis=0))
    return out
def concat_vectors(*args):
  """Concatenates 1-D vectors, producing a static result when possible.

  Args:
    *args: Vectors (Python sequences, NumPy arrays or `Tensor`s).

  Returns:
    A Python `list` holding every element of every vector, in order, when all
    arguments have a statically known value; otherwise the `Tensor` produced by
    `tf.concat(args, axis=0)`.
  """
  static_values = [tf.get_static_value(vec) for vec in args]
  if not all(value is not None for value in static_values):
    return tf.concat(args, axis=0)
  flattened = []
  for value in static_values:
    flattened.extend(value)
  return flattened
def make_convolution_transpose_fn_with_subkernels(
    filter_shape,
    strides,
    padding,
    rank=2,
    dilations=None,
    dtype=tf.int32,
    validate_args=False,
    name=None):
  """Like `tf.nn.conv2d_transpose` except applies batch of kernels to batch of `x`.

  The transposed convolution is split into `sh * sw` sub-kernels. Each
  sub-kernel is applied with an im2row gather + `matmul` inside a
  `tf.while_loop`, and the results are interleaved with `tf.batch_to_space`.

  Args:
    filter_shape: Filter shape, normalized by `prepare_conv_args`; unpacked as
      `(fh, fw)`.
    strides: Strides, normalized by `prepare_conv_args`; unpacked as
      `(sh, sw)`.
    padding: Padding mode; `'VALID'` and `'SAME'` determine the output size.
    rank: Must be (statically) `2`.
    dilations: Dilations, normalized by `prepare_conv_args`; unpacked as
      `(dh, dw)`.
    dtype: Integer dtype used for index computations.
    validate_args: Python `bool`; whether to add runtime shape assertions.
    name: Optional name scope.

  Returns:
    op: Callable `op(x, kernel)` computing the transposed convolution.

  Raises:
    NotImplementedError: If `rank` is not `2`.
  """
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))
    [
        filter_shape,
        rank,
        strides,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape,
        rank=rank,
        strides=strides,
        padding=padding,
        dilations=dilations,
        is_transpose=True,
        validate_args=validate_args)

    sh, sw = strides
    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // sh + 1
    sub_fw = (fw - 1) // sw + 1

    def loop_body(i_, kernels_ind):
      # Build the flat filter-tap indices of sub-kernel `i_` (row offset `i`,
      # column offset `j`) and store them at slot `sh * sw - pos - 1`.
      i = i_ // sw
      j = i_ % sw

      i_ind = ps.range((sh - i - 1) * fw, fw * fh, delta=sh * fw, dtype=dtype)
      j_ind = ps.range((sw - j - 1), fw, delta=sw, dtype=dtype)

      last_j = sw - (fw - j - 1) % sw - 1
      last_i = sh - (fh - i - 1) % sh - 1
      pos = last_i * sw + last_j

      nc = cartesian_add([i_ind, j_ind])
      kernels_ind = kernels_ind.write(
          sh * sw - pos - 1, ps.reverse(ps.reverse(nc, [0]), [1]))

      return i_ + 1, kernels_ind

    # Sub-kernels can differ in size, hence `infer_shape=False`.
    kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=1,
                                 dynamic_size=True)

    _, kernels_ind = tf.while_loop(
        lambda i, _: i < sh * sw, loop_body, [0, kernels_ind])

    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=sh, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=sw, dilation=dw, padding=padding)

    # Padding in units of the (strided) input grid, rounded up; the surplus
    # is truncated from the output afterwards.
    pad_bottom = (tot_pad_bottom - 1) // sh + 1
    pad_top = (tot_pad_top - 1) // sh + 1
    pad_right = (tot_pad_right - 1) // sw + 1
    pad_left = (tot_pad_left - 1) // sw + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * sh - tot_pad_top
    truncate_left = pad_left * sw - tot_pad_left

    def op(x, kernel):
      """Applies the batched transposed convolution of `kernel` to `x`."""
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x, kernel, filter_shape, strides, padding, dilations, c_out,
              batch_shape, event_shape)

        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)

        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        # Number of output positions per sub-kernel along each spatial dim.
        ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
        ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

        def loop_body(i, outputs):
          # Apply sub-kernel `i` to the matching crop of the padded input.
          subkernel_ind = kernels_ind.read(i)
          fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
          eh = ex_h + fh_ - 1
          ew = ex_w + fw_ - 1

          # Expand flat tap indices to per-channel rows of the kernel.
          subkernel_ind = ps.reshape(
              ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis]
              + ps.range(c_in),
              shape=[-1])

          k = tf.gather(kernel, subkernel_ind, axis=-2)
          ind, shape = im2row_index(
              [eh, ew, c_in],
              block_shape=(fh_, fw_),
              slice_step=(1, 1),
              dilations=dilations)
          x_i = x_pad[..., :eh, :ew, :]
          x_i_shape = ps.shape(x_i)
          flat_shape = ps.pad(
              x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1)
          flat_x = tf.reshape(x_i, flat_shape)
          x_ = tf.gather(flat_x, ind, axis=-1)
          im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0))
          outputs = outputs.write(
              i,
              tf.matmul(
                  im_x,
                  tf.reshape(
                      k,
                      ps.concat(
                          [kernel_batch, [1, fh_ * fw_ * c_in, c_out]],
                          axis=0))))
          return i + 1, outputs

        outputs = tf.TensorArray(dtype=input_dtype, infer_shape=False, size=1,
                                 dynamic_size=True)

        _, outputs = tf.while_loop(
            lambda i, _: i < sh * sw, loop_body, [0, outputs])

        y = outputs.concat()

        # Interleave the sub-kernel outputs spatially via `batch_to_space`.
        m = tf.reduce_prod(ps.shape(y)[:-3])
        y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0))
        y2 = tf.batch_to_space(
            y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64))
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)
        y2 = tf.reshape(
            y2,
            ps.concat([broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0))

        if padding == 'VALID':
          out_height = fh + sh * (xh - 1)
          out_width = fw + sw * (xw - 1)
        elif padding == 'SAME':
          out_height = xh * sh
          out_width = xw * sw

        # Drop the surplus rows/cols introduced by rounding the padding up.
        return y2[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
    return op
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Swaps the rightmost (event) dims of a `TensorShape` for a new event shape.

  Args:
    input_tensorshape: `TensorShape` whose rightmost dims are expected to hold
      `event_shape_in`.
    event_shape_in: `Tensor` shape describing the event shape expected in the
      rightmost dims of `input_tensorshape`. Negative entries (e.g. `-1`) act
      as wildcards during validation.
    event_shape_out: `Tensor` shape that replaces `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` made of the leading (non-event) dims of
      `input_tensorshape` followed by `event_shape_out`. It is
      `TensorShape(None)` when the rank of `input_tensorshape`, the size of
      `event_shape_in`, or the rank of `event_shape_out` is not statically
      known.
    is_validated: Python `bool`; `True` only if the static compatibility check
      between the input event dims and `event_shape_in` was performed.

  Raises:
    ValueError: if `input_tensorshape` has fewer dims than `event_shape_in` has
      entries, or if both are statically known and differ on any entry that is
      non-negative in `event_shape_in`.
  """
  num_event_dims = tensorshape_util.num_elements(event_shape_in.shape)
  input_rank = tensorshape_util.rank(input_tensorshape)
  if input_rank is None or num_event_dims is None:
    # Nothing can be inferred (or validated) statically.
    return tf.TensorShape(None), False

  num_batch_dims = input_rank - num_event_dims
  if num_batch_dims < 0:
    raise ValueError(
        'Input has lower rank ({}) than `event_shape_ndims` ({}).'.format(
            input_rank, num_event_dims))

  batch_tensorshape = input_tensorshape[:num_batch_dims]
  old_event_tensorshape = input_tensorshape[num_batch_dims:]

  # Static validation needs both the input's event dims and `event_shape_in`
  # fully known. Non-negative entries of `event_shape_in` must match exactly;
  # construction-time checks allow at most one wildcard entry.
  static_event_shape_in = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(old_event_tensorshape) and
      static_event_shape_in is not None)
  if is_validated:
    observed_event_shape = np.int32(old_event_tensorshape)
    explicit = static_event_shape_in >= 0
    if not np.all(observed_event_shape[explicit] ==
                  static_event_shape_in[explicit]):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in` '
          '({} vs {}).'.format(observed_event_shape, static_event_shape_in))

  new_event_tensorshape = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  if tensorshape_util.rank(new_event_tensorshape) is None:
    return tf.TensorShape(None), is_validated
  return (
      tensorshape_util.concatenate(batch_tensorshape, new_event_tensorshape),
      is_validated)
def op(x, kernel):
  """Applies the transposed 2-D convolution of `x` with `kernel`.

  NOTE(review): this is a closure. `fh`, `fw`, `filter_shape`, `strides`,
  `padding`, `dilations`, `padding_vals`, `sub_fh`, `sub_fw`, `kernels_ind`,
  `sh`, `sw`, `truncate_top`, `truncate_left` and `validate_args` all come
  from the enclosing scope, which is not visible here.

  Args:
    x: `Tensor`-like input. Its last three dims are split off as the event
      shape and unpacked as `(xh, xw, c_in)`; the leading dims are its batch
      shape.
    kernel: `Tensor`-like filter. Its last dim is read as `c_out` and
      `shape[:-2]` is treated as the kernel batch shape.

  Returns:
    If `kernel` is statically rank 2 and all `dilations` are 1, the result of
    `_call_conv2d_transpose`. Otherwise a `Tensor` whose leading dims are the
    broadcast of the batch shapes of `x` and `kernel`, cropped along the
    height/width axes to `out_height` x `out_width` (presumably
    `[..., out_height, out_width, c_out]` -- confirm).
  """
  input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
  x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
  kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

  # Split the shape of `x` into batch dims and the trailing 3 event dims.
  batch_shape, event_shape = ps.split(
      ps.shape(x), num_or_size_splits=[-1, 3])
  xh, xw, c_in = ps.unstack(event_shape, num=3)
  kernel_shape = ps.shape(kernel)
  c_out = kernel_shape[-1]
  kernel_batch = kernel_shape[:-2]
  assertions = _maybe_validate_input_shapes(
      kernel_shape, channels_in=c_in,
      filter_height=fh, filter_width=fw,
      validate_args=validate_args)

  with tf.control_dependencies(assertions):

    # If the kernel does not have batch shape, fall back to
    # `conv2d_transpose` (unless dilations > 1, which is not implemented in
    # `conv2d_transpose`).
    if (tf.get_static_value(ps.rank(kernel)) == 2
        and all(d == 1 for d in dilations)):
      return _call_conv2d_transpose(
          x, kernel, filter_shape, strides, padding, dilations,
          c_out, batch_shape, event_shape)

    # Pad only the height/width axes: `n` zero rows for the leading batch
    # dims, `padding_vals` for height/width, and one zero row for channels.
    n = ps.maximum(0, ps.rank(x) - 3)
    paddings = ps.pad(
        padding_vals,
        paddings=[[n, 1], [0, 0]],
        constant_values=0)

    x_pad = tf.pad(x, paddings=paddings, constant_values=0)
    ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
    ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

    def loop_body(i, outputs):
      """Computes the output for the `i`-th sub-kernel and writes it."""
      subkernel_ind = kernels_ind.read(i)
      fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
      eh = ex_h + fh_ - 1
      ew = ex_w + fw_ - 1

      # Expand each spatial index into `c_in` consecutive row indexes of the
      # kernel's axis -2 (index * c_in + [0, c_in)), then flatten.
      subkernel_ind = ps.reshape(
          ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis]
          + ps.range(c_in), shape=[-1])

      k = tf.gather(kernel, subkernel_ind, axis=-2)
      ind, shape = im2row_index(
          [eh, ew, c_in],
          block_shape=(fh_, fw_),
          slice_step=(1, 1),
          dilations=dilations)
      x_i = x_pad[..., :eh, :ew, :]
      x_i_shape = ps.shape(x_i)
      # Flatten the trailing (height, width, channels) dims into one axis so
      # the im2row indexes can gather from it.
      flat_shape = ps.pad(
          x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1)
      flat_x = tf.reshape(x_i, flat_shape)
      x_ = tf.gather(flat_x, ind, axis=-1)
      im_x = tf.reshape(
          x_, ps.concat([x_i_shape[:-3], shape], axis=0))
      outputs = outputs.write(
          i,
          tf.matmul(
              im_x,
              tf.reshape(
                  k, ps.concat([
                      kernel_batch,
                      [1, fh_ * fw_ * c_in, c_out]
                  ], axis=0))))
      return i + 1, outputs

    outputs = tf.TensorArray(dtype=input_dtype, infer_shape=False, size=1,
                             dynamic_size=True)

    # One iteration per sub-kernel; there are `sh * sw` of them.
    _, outputs = tf.while_loop(
        lambda i, _: i < sh * sw, loop_body, [0, outputs])

    y = outputs.concat()

    # Fold all leading dims into one so `batch_to_space` can interleave the
    # per-sub-kernel results into the strided output grid.
    m = tf.reduce_prod(ps.shape(y)[:-3])
    y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0))
    y2 = tf.batch_to_space(
        y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64))
    broadcast_batch_shape = ps.broadcast_shape(
        batch_shape, kernel_batch)
    y2 = tf.reshape(
        y2, ps.concat([broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0))

    # NOTE(review): only 'VALID' and 'SAME' are handled; any other value would
    # leave `out_height`/`out_width` unbound. Presumably `padding` was
    # validated by the caller -- confirm.
    if padding == 'VALID':
      out_height = fh + sh * (xh - 1)
      out_width = fw + sw * (xw - 1)
    elif padding == 'SAME':
      out_height = xh * sh
      out_width = xw * sw

    return y2[..., truncate_top:truncate_top + out_height,
              truncate_left:truncate_left + out_width, :]
def __init__(self,
             output_shape=(32, 32, 3),
             num_glow_blocks=3,
             num_steps_per_block=32,
             coupling_bijector_fn=None,
             exit_bijector_fn=None,
             grab_after_block=None,
             use_actnorm=True,
             seed=None,
             validate_args=False,
             name='glow'):
  """Creates the Glow bijector.

  Args:
    output_shape: A list of integers, specifying the event shape of the
      output, of the bijectors forward pass (the image). Specified as
      [H, W, C]. H and W must each be positive and divisible by
      `2**num_glow_blocks`.
      Default Value: (32, 32, 3)
    num_glow_blocks: An integer, specifying how many downsampling levels to
      include in the model. `2**num_glow_blocks` must divide equally into
      both H and W, otherwise the bijector would not be invertible. Must be
      statically known.
      Default Value: 3
    num_steps_per_block: An integer specifying how many Affine Coupling and
      1x1 convolution layers to include at each level of the spatial
      hierarchy.
      Default Value: 32 (i.e. the value used in the original glow paper).
    coupling_bijector_fn: A function which takes the argument `input_shape`
      and returns a callable neural network (e.g. a keras.Sequential). The
      network should either return a tensor with the same event shape as
      `input_shape` (this will employ additive coupling), a tensor with the
      same height and width as `input_shape` but twice the number of channels
      (this will employ affine coupling), or a bijector which takes in a
      tensor with event shape `input_shape`, and returns a tensor with shape
      `input_shape`.
    exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
      a function which takes the argument `input_shape` and `output_chan`
      and returns a callable neural network. The neural network it returns
      should take a tensor of shape `input_shape` as the input, and return
      one of three options: A tensor with `output_chan` channels, a tensor
      with `2 * output_chan` channels, or a bijector. Additional details can
      be found in the documentation for ExitBijector.
    grab_after_block: A tuple of floats, specifying what fraction of the
      remaining channels to remove following each glow block. Glow will take
      the integer floor of this number multiplied by the remaining number of
      channels. The default is half at each spatial hierarchy. Its length
      must equal `num_glow_blocks`.
      Default value: None (this will take out half of the channels after each
        block, and none after the last block).
    use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
      initialization is used to initialize this layer.
      Default value: `True`
    seed: A seed to control randomness in the 1x1 convolution initialization.
      Default value: `None` (i.e., non-reproducible sampling).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False`
    name: Python `str`, name given to ops managed by this object.
      Default value: `'glow'`.

  Raises:
    ValueError: if `output_shape` is not a fully defined rank-3 shape with
      positive height and width, if `num_glow_blocks` is not a statically
      known positive integer, if H or W is not divisible by
      `2**num_glow_blocks`, if `len(grab_after_block) != num_glow_blocks`, or
      if the resulting channel splits leave a block with no channels.
  """
  # Make sure that the input shape is fully defined.
  if not tensorshape_util.is_fully_defined(output_shape):
    raise ValueError('Shape must be fully defined.')
  if tensorshape_util.rank(output_shape) != 3:
    raise ValueError('Shape ndims must be 3 for images. Your shape is '
                     '{}'.format(tensorshape_util.rank(output_shape)))

  num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
  if (num_glow_blocks_ is None or
      int(num_glow_blocks_) != num_glow_blocks_ or
      num_glow_blocks_ < 1):
    raise ValueError('Argument `num_glow_blocks` must be a statically known '
                     'positive `int` (saw: {}).'.format(num_glow_blocks))
  num_glow_blocks = int(num_glow_blocks_)

  output_shape = tensorshape_util.as_list(output_shape)
  h, w, c = output_shape
  n = num_glow_blocks
  nsteps = num_steps_per_block

  # Default Glow: Half of the channels are split off after each block,
  # and after the final block, no channels are split off.
  if grab_after_block is None:
    grab_after_block = tuple([0.5] * (n - 1) + [0.])

  # A zero-sized image passes `0 % k == 0`, and `log2(0)` in the messages
  # below would then raise an OverflowError instead of a ValueError.
  if h < 1 or w < 1:
    raise ValueError('Image height and width must be positive '
                     '(saw: {}).'.format(output_shape))

  # Report "too many blocks" before the generic divisibility errors; with
  # positive h and w these are the only cases where `h // 2**n < 1`.
  if h // 2**n < 1:
    raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                     '({1}) must be divisible by 2 no more than {2} '
                     'times.'.format(num_glow_blocks, h, int(np.log2(h))))
  if w // 2**n < 1:
    raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                     '({1}) must be divisible by 2 no more than {2} '
                     'times.'.format(num_glow_blocks, w, int(np.log2(w))))

  # Thing we know must be true: h and w are evenly divisible by 2, n times.
  # Otherwise, the squeeze bijector will not work.
  if w % 2**n != 0:
    raise ValueError('Width must be divisible by 2 at least n times. '
                     'Saw: {} % {} != 0'.format(w, 2**n))
  if h % 2**n != 0:
    raise ValueError('Height should be divisible by 2 at least n times.')

  # Other things we want to be true:
  # - The number of times we take must be equal to the number of glow blocks.
  if len(grab_after_block) != num_glow_blocks:
    raise ValueError('Length of grab_after_block ({0}) must match the number '
                     'of blocks ({1}).'.format(len(grab_after_block),
                                               num_glow_blocks))

  self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                      grab_after_block[::-1])

  # Now check on the values of blockwise splits. Use boolean masks so that
  # `.index(True)` finds the first offending block; calling `.index(True)` on
  # the raw channel counts would instead search for the value 1.
  # NOTE(review): the reported index is into `self._blockwise_splits`, which
  # was built from the *reversed* `grab_after_block` -- confirm the mapping.
  no_inputs_left = [bs[0] < 1 for bs in self._blockwise_splits]
  if any(no_inputs_left):
    first_offender = no_inputs_left.index(True)
    raise ValueError('At at least one exit, you are taking out all of your '
                     'channels, and therefore have no inputs to later blocks.'
                     ' Try setting grab_after_block to a lower value at index '
                     '{}.'.format(first_offender))

  # Special case: if specifically exiting no channels, then the exit is
  # just an identity bijector, so zero exit channels are allowed.
  exits_nothing = any(np.isclose(gab, 0) for gab in grab_after_block)
  if not exits_nothing:
    too_few_outputs = [bs[1] < 1 for bs in self._blockwise_splits]
    if any(too_few_outputs):
      first_offender = too_few_outputs.index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_after_block too small. '
                       'Try setting grab_after_block to a larger value at '
                       'index {}.'.format(first_offender))

  # Lets start to build our bijector. We assume that the distribution is 1
  # dimensional. First, lets reshape it to an image.
  glow_chain = [
      reshape.Reshape(
          event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
          event_shape_in=[h * w * c])
  ]

  seedstream = SeedStream(seed=seed, salt='random_beta')

  for i in range(n):

    # This is the shape of the current tensor
    current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

    # This is the shape of the input to both the glow block and exit bijector.
    this_nchan = sum(self._blockwise_splits[i][0:2])
    this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

    glow_chain.append(
        invert.Invert(ExitBijector(current_shape,
                                   self._blockwise_splits[i],
                                   exit_bijector_fn)))

    glow_block = GlowBlock(input_shape=this_input_shape,
                           num_steps=nsteps,
                           coupling_bijector_fn=coupling_bijector_fn,
                           use_actnorm=use_actnorm,
                           seedstream=seedstream)

    if self._blockwise_splits[i][2] == 0:
      # All channels are passed to the RealNVP
      glow_chain.append(glow_block)
    else:
      # Some channels are passed around the block.
      # This is done with the Blockwise bijector.
      glow_chain.append(
          blockwise.Blockwise(
              [glow_block, identity.Identity()],
              [this_nchan, self._blockwise_splits[i][2]]))

    # Finally, lets expand the channels into spatial features.
    glow_chain.append(
        Expand(input_shape=[
            h // 2**n * 2**i,
            w // 2**n * 2**i,
            c * 4**n // 4**i,
        ]))

  glow_chain = glow_chain[::-1]
  # To finish off, we initialize the bijector with the chain we've built
  # This way, the rest of the model attributes are taken care of for us.
  super(Glow, self).__init__(
      bijectors=glow_chain, validate_args=validate_args, name=name)
def im2row_index(input_shape,
                 block_shape,
                 rank=2,
                 slice_step=(1, 1),
                 dilations=(1, 1),
                 dtype=tf.int32,
                 transpose=False,
                 validate_args=False,
                 name=None):
  """Computes indexes into a flattened image for building `im2row`.

  Args:
    input_shape: Shape-like value. Its last three entries are read as
      `(h, w, c)` and any leading entries are returned as batch shape.
    block_shape: Block (filter) size `(fh, fw)`, expanded to `rank` entries.
    rank: Must be statically `2`; anything else raises.
    slice_step: Step `(sh, sw)` between consecutive block positions.
    dilations: Spacing `(dh, dw)` between elements within a block; a block
      spans `dh * (fh - 1) + 1` rows and `dw * (fw - 1) + 1` columns.
    dtype: Integer dtype used for the shape cast and all index ranges.
    transpose: Python `bool`. If `True`, the in-block start positions along
      height and width are enumerated in descending order (the same set of
      positions, reversed).
    validate_args: Python `bool`, forwarded to `prepare_tuple_argument`.
    name: Optional name scope; defaults to `'im2row_index'`.

  Returns:
    idx: Indexes into an image flattened over its last three dims, combining
      every output position offset with every in-block start position via
      `cartesian_add` (exact layout depends on `cartesian_add` -- confirm).
    new_shape: `concat([batch_shape, [oh, ow, fh * fw * c]])`, where `oh`,
      `ow` are the number of block positions along height and width.

  Raises:
    NotImplementedError: if `rank` is not statically `2`.
  """
  with tf.name_scope(name or 'im2row_index'):
    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))
    fh, fw = prepare_tuple_argument(
        block_shape, n=rank, arg_name='block_shape',
        validate_args=validate_args)
    sh, sw = prepare_tuple_argument(
        slice_step, n=rank, arg_name='slice_step', validate_args=validate_args)
    dh, dw = prepare_tuple_argument(
        dilations, n=rank, arg_name='dilations', validate_args=validate_args)

    # 1) Process input arguments.
    batch_shape, h, w, c = ps.split(
        ps.reshape(ps.cast(input_shape, dtype=dtype), shape=[-1]),
        num_or_size_splits=[-1, 1, 1, 1])
    h, w, c = h[0], w[0], c[0]

    # Extent of a dilated block along height / width.
    tot_fh = dh * (fh - 1) + 1
    tot_fw = dw * (fw - 1) + 1

    # 2) Assemble all block start positions as indexes into the flattened image.
    # start_idx.shape = [fh, fw, c]
    if transpose:
      # Largest multiple of `step` that is <= `size - 1`, i.e. the last value
      # `range(0, size, step)` would produce.
      last_element = lambda size, step: size - (size - 1) % step - 1
      w_step = c * dw
      h_step = c * w * dh
      last_w = last_element(c * tot_fw, w_step)
      last_h = last_element(c * w * tot_fh, h_step)
      # Same positions as the non-transposed branch, in descending order.
      start_idx = cartesian_add([
          ps.range(last_h, -1, delta=-h_step, dtype=dtype),
          ps.range(last_w, -1, delta=-w_step, dtype=dtype),
          ps.range(c, delta=1, dtype=dtype),
      ])
    else:
      start_idx = cartesian_add([
          ps.range(c * w * tot_fh, delta=c * w * dh, dtype=dtype),
          ps.range(c * tot_fw, delta=c * dw, dtype=dtype),
          ps.range(c, delta=1, dtype=dtype),
      ])

    # 3) Assemble all block offsets (into flattened image).
    # eh / ew: number of valid (unstrided) block positions per axis.
    eh = h - tot_fh + 1
    ew = w - tot_fw + 1

    offset_idx = cartesian_add([
        ps.range(w * eh, delta=w * sh, dtype=dtype),
        ps.range(ew, delta=sw, dtype=dtype),
    ])

    # Offsets above are in pixels; scale to flat (pixel * channel) indexes.
    offset_idx = offset_idx * c
    oh = (eh - 1) // sh + 1  # out height
    ow = (ew - 1) // sw + 1  # out width

    # 4) Combine block start/offset pairs.
    # NOTE(review): presumably one group of fh * fw * c indexes per output
    # position (oh * ow positions); the older note said
    # shape = [(eh // sh) * (ew // sw), fh * fw * c] -- confirm.
    idx = cartesian_add([offset_idx, start_idx])
    new_shape = ps.concat(
        [batch_shape, ps.convert_to_shape_tensor([oh, ow, fh * fw * c])],
        axis=0)
    return idx, new_shape
def __init__(self,
             num_or_size_splits,
             axis=-1,
             validate_args=False,
             name='split'):
  """Creates the bijector.

  Args:
    num_or_size_splits: Either a Python integer indicating the number of
      splits along `axis` or a 1-D integer `Tensor` or Python list containing
      the sizes of each output tensor along `axis`. If a list/`Tensor`, it
      may contain at most one value of `-1`, which indicates a split size
      that is unknown and determined from input.
    axis: A negative integer or scalar `int32` `Tensor`. The dimension along
      which to split. Must be negative to enable the bijector to support
      arbitrary batch dimensions. Defaults to -1 (note that this is different
      from the `tf.split` default of `0`). Must be statically known.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object.

  Raises:
    ValueError: if `num_or_size_splits` is not an integer and is not a 1-D
      `Tensor` with a statically-known number of elements.
    ValueError: if `axis` is not statically known, or is not negative.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if isinstance(num_or_size_splits, numbers.Integral):
      self._num_splits = num_or_size_splits
      self._split_sizes = None
    else:
      self._split_sizes = tensor_util.convert_nonref_to_tensor(
          num_or_size_splits, name='num_or_size_splits', dtype=tf.int32)

      if tensorshape_util.rank(self._split_sizes.shape) != 1:
        raise ValueError(
            '`num_or_size_splits` must be an integer or 1-D `Tensor`.')

      num_splits = tensorshape_util.as_list(self._split_sizes.shape)[0]
      if num_splits is None:
        raise ValueError('If `num_or_size_splits` is a vector of split sizes '
                         'it must have a statically-known number of '
                         'elements.')
      self._num_splits = num_splits

    static_axis = tf.get_static_value(axis)
    if static_axis is None:
      raise ValueError('`axis` must be statically known.')
    if static_axis >= 0:
      raise ValueError('`axis` must be negative. Got {}'.format(axis))
    self._axis = tf.convert_to_tensor(axis, tf.int32)

    # Derive the min event ndims from the *static* axis: when `axis` is given
    # as a `Tensor` (allowed above), `-axis` would be a `Tensor` rather than a
    # Python integer.
    min_event_ndims = -int(static_axis)
    super(Split, self).__init__(
        forward_min_event_ndims=min_event_ndims,
        inverse_min_event_ndims=[min_event_ndims] * self.num_splits,
        is_constant_jacobian=True,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
def prepare_conv_args(
    filter_shape, rank, strides, padding, dilations,
    is_transpose=False, validate_args=False):
  """Sanitizes user-provided convolution arguments.

  Args:
    filter_shape: Filter size, expanded to `rank` entries by
      `prepare_tuple_argument`.
    rank: Number of spatial dims; must be statically known and in {1, 2, 3}.
    strides: Stride(s), expanded to `rank` entries.
    padding: Padding spec, checked by `_validate_padding` and then converted
      by `_prepare_padding_argument`.
    dilations: Dilation rate(s), expanded to `rank` entries.
    is_transpose: Python `bool`; if `True`, additionally enforce the
      restrictions of the transposed-convolution implementation.
    validate_args: Python `bool`; whether to add runtime assertions for
      values that are not statically known.

  Returns:
    Tuple `(filter_shape, rank, strides, padding, dilations)`, where `rank`
    is a Python `int`.

  Raises:
    TypeError: if `rank` is not statically known.
    ValueError: if `rank` is not in {1, 2, 3}.
    NotImplementedError: if `is_transpose` and, statically, some stride and
      some dilation are both > 1, or some stride exceeds the corresponding
      filter size.
  """
  padding = _validate_padding(padding)
  try:
    rank = int(tf.get_static_value(rank))
  except TypeError:
    raise TypeError('Argument `rank` must be statically known `int`.')
  valid_rank = {1, 2, 3}
  if rank not in valid_rank:
    raise ValueError('Argument `rank` must be in {}.'.format(valid_rank))

  def _to_tuple(value, arg_name):
    # Expand a scalar-or-sequence argument to exactly `rank` entries.
    return prepare_tuple_argument(
        value, n=rank, arg_name=arg_name, validate_args=validate_args)

  filter_shape = _to_tuple(filter_shape, 'filter_shape')
  strides = _to_tuple(strides, 'strides')
  padding = _prepare_padding_argument(padding)
  dilations = _to_tuple(dilations, 'dilations')

  static_strides = [tf.get_static_value(s) for s in strides]
  static_dilations = [tf.get_static_value(d) for d in dilations]
  assertions = []

  if is_transpose:
    fully_static = (all(s is not None for s in static_strides) and
                    all(d is not None for d in static_dilations))
    if fully_static:
      if (any(s > 1 for s in static_strides) and
          any(d > 1 for d in static_dilations)):
        raise NotImplementedError('At least one of `dilations` and `strides` '
                                  'must equal `1` for each dimension. Saw: '
                                  '`strides={}`, `dilations={}`'.format(
                                      strides, dilations))
    elif validate_args:
      strides_are_unit = tf.equal(tf.reduce_max(strides), 1)
      dilations_are_unit = tf.equal(tf.reduce_max(dilations), 1)
      assertions.append(
          assert_util.assert_equal(
              tf.logical_or(strides_are_unit, dilations_are_unit),
              True,
              message='At least one of `dilations` and `strides` must equal '
              '`1` for each dimension.'))

    # TODO(emilyaf): Remove this once strides > filter_dim is supported.
    static_filter_shape = [tf.get_static_value(f) for f in filter_shape]
    for s, f in zip(static_strides, static_filter_shape):
      if s is not None and f is not None and s > f:
        raise NotImplementedError('Stride must be less than or equal to the '
                                  'filter size along each dimension.')

  # NOTE(review): no ops are created inside this context, so in graph mode
  # the assertions may not actually be attached to anything -- confirm.
  with tf.control_dependencies(assertions):
    return filter_shape, rank, strides, padding, dilations
def __init__(self,
             component_ssms,
             constant_offset=0.,
             observation_noise_scale=None,
             initial_state_prior=None,
             initial_step=0,
             validate_args=False,
             name=None,
             **linear_gaussian_ssm_kwargs):
  """Build a state space model representing the sum of component models.

  Args:
    component_ssms: Python `list` containing one or more
      `tfd.LinearGaussianStateSpaceModel` instances. The components
      will in general implement different time-series models, with possibly
      different `latent_size`, but they must have the same `dtype`, event
      shape (`num_timesteps` and `observation_size`), and their batch shapes
      must broadcast to a compatible batch shape.
    constant_offset: `float` `Tensor` of shape broadcasting to
      `concat([batch_shape, [num_timesteps]]`) specifying a constant value
      added to the sum of outputs from the component models.
      This allows the components to model the shifted series
      `observed_time_series - constant_offset`. If the offset has fewer
      timesteps than the model, steps past its end reuse its last value.
      Default value: `0.`
    observation_noise_scale: Optional scalar `float` `Tensor` indicating the
      standard deviation of the observation noise. May contain additional
      batch dimensions, which must broadcast with the batch shape of elements
      in `component_ssms`. If `observation_noise_scale` is specified for the
      `AdditiveStateSpaceModel`, the observation noise scales of component
      models are ignored. If `None`, the observation noise scale is derived
      by summing the noise variances of the component models, i.e.,
      `observation_noise_scale = sqrt(sum(
      [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
    initial_state_prior: Optional instance of `tfd.MultivariateNormal`
      representing a prior distribution on the latent state at time
      `initial_step`. If `None`, defaults to the independent priors from
      component models, i.e.,
      `[component.initial_state_prior for component in component_ssms]`.
      Default value: `None`.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AdditiveStateSpaceModel".
    **linear_gaussian_ssm_kwargs: Optional additional keyword arguments to
      the base `tfd.LinearGaussianStateSpaceModel` constructor.

  Raises:
    ValueError: if components have different `num_timesteps`.
  """
  parameters = dict(locals())
  # Record the forwarded kwargs individually rather than as one dict entry.
  parameters.update(linear_gaussian_ssm_kwargs)
  del parameters['linear_gaussian_ssm_kwargs']
  with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
    # Check that all components have the same dtype
    dtype = tf.debugging.assert_same_float_dtype(component_ssms)

    # Convert scalar offsets to canonical shape `[..., num_timesteps]`.
    constant_offset = (tf.convert_to_tensor(value=constant_offset,
                                            name='constant_offset',
                                            dtype=dtype) *
                       tf.ones([1], dtype=dtype))
    offset_length = ps.shape(constant_offset)[-1]
    assertions = []

    # Construct an initial state prior as a block-diagonal combination
    # of the component state priors.
    if initial_state_prior is None:
      initial_state_prior = sts_util.factored_joint_mvn(
          [ssm.initial_state_prior for ssm in component_ssms])
    # The prior's dtype (user-supplied or built above) is used from here on.
    dtype = initial_state_prior.dtype

    static_num_timesteps = [
        tf.get_static_value(ssm.num_timesteps)
        for ssm in component_ssms
        if tf.get_static_value(ssm.num_timesteps) is not None
    ]

    # If any components have a static value for `num_timesteps`, use that
    # value for the additive model. (and check that all other static values
    # match it).
    if static_num_timesteps:
      num_timesteps = static_num_timesteps[0]
      if not all([component_timesteps == num_timesteps
                  for component_timesteps in static_num_timesteps]):
        raise ValueError('Additive model components must all have the same '
                         'number of timesteps '
                         '(saw: {})'.format(static_num_timesteps))
    else:
      num_timesteps = component_ssms[0].num_timesteps
    # Runtime checks are only needed when some component's `num_timesteps`
    # is not statically known.
    if validate_args and len(static_num_timesteps) != len(component_ssms):
      assertions += [
          tf.debugging.assert_equal(  # pylint: disable=g-complex-comprehension
              num_timesteps,
              ssm.num_timesteps,
              message='Additive model components must all have '
              'the same number of timesteps.') for ssm in component_ssms
      ]

    # Define the transition and observation models for the additive SSM.
    # See the "mathematical details" section of the class docstring for
    # further information. Note that we define these as callables to
    # handle the fully general case in which some components have time-
    # varying dynamics.
    def transition_matrix_fn(t):
      return tfl.LinearOperatorBlockDiag(
          [ssm.get_transition_matrix_for_timestep(t)
           for ssm in component_ssms])

    def transition_noise_fn(t):
      return sts_util.factored_joint_mvn(
          [ssm.get_transition_noise_for_timestep(t)
           for ssm in component_ssms])

    # Build the observation matrix, concatenating (broadcast) observation
    # matrices from components. We also take this as an opportunity to enforce
    # any dynamic assertions we may have generated above.
    broadcast_batch_shape = ps.cast(
        sts_util.broadcast_batch_shape(
            [ssm.get_observation_matrix_for_timestep(initial_step)
             for ssm in component_ssms]), dtype=tf.int32)
    # Ones of shape `broadcast_batch_shape + [1, 1]`; multiplying each
    # component's dense observation matrix by it broadcasts them all to the
    # common batch shape before concatenating along the last axis.
    broadcast_obs_matrix = tf.ones(
        ps.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)

    if assertions:
      with tf.control_dependencies(assertions):
        broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

    def observation_matrix_fn(t):
      return tfl.LinearOperatorFullMatrix(
          tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                     broadcast_obs_matrix for ssm in component_ssms],
                    axis=-1))

    # Broadcast the constant offset across timesteps. A length-1 offset is
    # used as-is; otherwise take column `t` (clamped to the last column) and
    # restore a trailing unit axis.
    offset_at_step = lambda t: (  # pylint: disable=g-long-lambda
        constant_offset if offset_length == 1
        else tf.gather(constant_offset, tf.minimum(t, offset_length - 1),
                       axis=-1)[..., tf.newaxis])

    if observation_noise_scale is not None:
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      # Explicit scale: sum the component noise means plus the offset, and
      # ignore the component noise scales.
      def observation_noise_fn(t):
        return tfd.MultivariateNormalDiag(
            loc=(sum([ssm.get_observation_noise_for_timestep(t).mean()
                      for ssm in component_ssms]) + offset_at_step(t)),
            scale_diag=observation_noise_scale[..., tf.newaxis])
    else:
      # No explicit scale: sum the component noise distributions together
      # with a zero-scale distribution located at the offset.
      def observation_noise_fn(t):
        offset = offset_at_step(t)
        return sts_util.sum_mvns(
            [tfd.MultivariateNormalDiag(
                loc=offset, scale_diag=tf.zeros_like(offset))] +
            [ssm.get_observation_noise_for_timestep(t)
             for ssm in component_ssms])

    super(AdditiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix_fn,
        transition_noise=transition_noise_fn,
        observation_matrix=observation_matrix_fn,
        observation_noise=observation_noise_fn,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        validate_args=validate_args,
        name=name,
        **linear_gaussian_ssm_kwargs)
    self._parameters = parameters
def _rank(input, name=None):  # pylint: disable=redefined-builtin,unused-argument
  """Returns the rank of `input`, as a static `np.int32` when possible.

  Objects without a `shape` attribute are first converted: to a `np.ndarray`
  when `tf.get_static_value` can resolve them, otherwise to a `Tensor`. If
  the resulting static rank is unknown, a dynamic `tf.rank` is returned.

  Args:
    input: Value whose rank is requested.
    name: Unused; accepted for signature compatibility.

  Returns:
    `np.int32` rank if statically known, else a `Tensor` from `tf.rank`.
  """
  value = input
  if not hasattr(value, 'shape'):
    if tf.get_static_value(value) is None:
      value = tf.convert_to_tensor(value)
    else:
      value = np.array(value)
  static_rank = tensorshape_util.rank(getattr(value, 'shape', None))
  if static_rank is not None:
    return np.int32(static_rank)
  return tf.rank(value)
def _sample_n(self, n, seed):
  """Draws `n` samples from the mixture.

  All components are sampled, a component index is sampled from
  `mixture_distribution`, and the selected component is picked out by
  multiplying with a one-hot mask and summing over the component axis.

  Seeds: stateless seeds from `samplers.split_seed` are tried first. If a
  sub-distribution raises a `TypeError` whose message matches a known
  stateful-seed error, sampling falls back to a stateful `SeedStream` with a
  warning.

  Args:
    n: Number of samples to draw.
    seed: Seed passed to `samplers.split_seed` and (for the fallback) to
      `SeedStream`.

  Returns:
    The selected samples, reparameterized via `_reparameterize_sample` when
    `self._reparameterize` is set.
  """
  components_seed, mix_seed = samplers.split_seed(seed,
                                                  salt='MixtureSameFamily')
  try:
    seed_stream = SeedStream(seed, salt='MixtureSameFamily')
  except TypeError as e:  # Can happen for Tensor seeds.
    seed_stream = None
    # Kept so the fallback paths below can re-raise it.
    seed_stream_err = e
  try:
    x = self.components_distribution.sample(  # [n, B, k, E]
        n, seed=components_seed)
    if seed_stream is not None:
      # Advance even if unused: consuming one value here means the mixture
      # fallback below always gets the stream's second value, whether or not
      # the components fallback ran.
      seed_stream()  # Advance even if unused.
  except TypeError as e:
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    if seed_stream is None:
      raise seed_stream_err
    msg = (
        'Falling back to stateful sampling for `components_distribution` '
        '{} of type `{}`. Please update to use `tf.random.stateless_*` '
        'RNGs. This fallback may be removed after 20-Aug-2020. {}')
    warnings.warn(msg.format(self.components_distribution.name,
                             type(self.components_distribution),
                             str(e)))
    x = self.components_distribution.sample(  # [n, B, k, E]
        n, seed=seed_stream())

  # Determine the number of event dims, statically when possible.
  event_shape = None
  event_ndims = tensorshape_util.rank(self.event_shape)
  if event_ndims is None:
    event_shape = self.components_distribution.event_shape_tensor()
    event_ndims = prefer_static.rank_from_shape(event_shape)
  event_ndims_static = tf.get_static_value(event_ndims)

  # The component axis sits just left of the event dims in `x`.
  num_components = None
  if event_ndims_static is not None:
    num_components = tf.compat.dimension_value(
        x.shape[-1 - event_ndims_static])
  # We could also check if num_components can be computed statically from
  # self.mixture_distribution's logits or probs.
  if num_components is None:
    num_components = tf.shape(x)[-1 - event_ndims]

  # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
  npdt = dtype_util.as_numpy_dtype(x.dtype)

  try:
    mix_sample = self.mixture_distribution.sample(
        n, seed=mix_seed)  # [n, B] or [n]
  except TypeError as e:
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    if seed_stream is None:
      raise seed_stream_err
    msg = (
        'Falling back to stateful sampling for `mixture_distribution` '
        '{} of type `{}`. Please update to use `tf.random.stateless_*` '
        'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
    warnings.warn(msg.format(self.mixture_distribution.name,
                             type(self.mixture_distribution),
                             str(e)))
    mix_sample = self.mixture_distribution.sample(
        n, seed=seed_stream())  # [n, B] or [n]
  mask = tf.one_hot(
      indices=mix_sample,  # [n, B] or [n]
      depth=num_components,
      on_value=npdt(1),
      off_value=npdt(0))    # [n, B, k] or [n, k]

  # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
  batch_ndims = prefer_static.rank(x) - event_ndims - 1
  mask_batch_ndims = prefer_static.rank(mask) - 1
  pad_ndims = batch_ndims - mask_batch_ndims
  mask_shape = prefer_static.shape(mask)
  mask = tf.reshape(
      mask,
      shape=prefer_static.concat([
          mask_shape[:-1],
          prefer_static.ones([pad_ndims], dtype=tf.int32),
          mask_shape[-1:],
          prefer_static.ones([event_ndims], dtype=tf.int32),
      ], axis=0))

  # For float/complex dtypes use `multiply_no_nan`, which yields 0 wherever
  # the mask is 0 even if the unselected component sample is NaN or inf.
  if x.dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64,
                 tf.complex64, tf.complex128]:
    masked = tf.math.multiply_no_nan(x, mask)
  else:
    masked = x * mask
  ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

  if self._reparameterize:
    if event_shape is None:
      event_shape = self.components_distribution.event_shape_tensor()
    ret = self._reparameterize_sample(ret, event_shape=event_shape)

  return ret
def tensor_and_const_value(v):
  """Converts `v` to a `Tensor` and pairs it with its static value.

  Args:
    v: Any value accepted by `tf.convert_to_tensor`.

  Returns:
    A 2-tuple `(tensor, static_value)` where `static_value` is the result of
    `tf.get_static_value` on the converted tensor (`None` when it cannot be
    determined statically).
  """
  as_tensor = tf.convert_to_tensor(v)
  return as_tensor, tf.get_static_value(as_tensor)