Esempio n. 1
0
def trace(
    state: State,
    fn: TransitionOperator,
    num_steps: IntTensor,
    trace_fn: Callable[[State, TensorNest], TensorNest],
    parallel_iterations: int = 10,
) -> Tuple[State, TensorNest]:
    """`TransitionOperator` that runs `fn` repeatedly and traces its outputs.

    Every step calls `fn` on the current state (through `call_fn`), then calls
    `trace_fn` on the resulting `(state, extra)` pair. The outputs of
    `trace_fn` are written into `TensorArray`s and stacked along a new leading
    axis once the loop finishes.

    Args:
      state: A nest of `Tensor`s or None.
      fn: A `TransitionOperator`.
      num_steps: Number of steps to run the function for. Must be greater than
        1. Note that when executing eagerly, or when any leaf of `state` is
        `None`, one step is always executed before the loop starts.
      trace_fn: Callable that takes the unpacked outputs of `fn` and returns a
        nest of `Tensor`s. These will be stacked and returned.
      parallel_iterations: Number of iterations of the while loop to run in
        parallel.

    Returns:
      state: The final state returned by `fn`.
      traces: Stacked outputs of `trace_fn`.
    """
    # Convert every non-`None` leaf of the state to a `Tensor`; `None` leaves
    # are passed through untouched.
    state = tf.nest.map_structure(
        lambda t: t if t is None else tf.convert_to_tensor(t), state)

    def wrapper(state):
        # A single step: advance the state with `fn`, then compute the value to
        # record. Both results are converted leaf-by-leaf to `Tensor`s.
        state, extra = tf.nest.map_structure(tf.convert_to_tensor,
                                             call_fn(fn, state))
        trace_element = tf.nest.map_structure(tf.convert_to_tensor,
                                              trace_fn(state, extra))
        return state, trace_element

    if any(e is None
           for e in tf.nest.flatten(state)) or tf.executing_eagerly():
        # Run the first step outside the loop: its outputs supply the dtypes
        # and shapes for the `TensorArray`s, and are stored at index 0.
        state, first_trace = wrapper(state)
        trace_arrays = tf.nest.map_structure(
            lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
                v.dtype,
                size=num_steps,
                element_shape=v.shape).write(0, v),
            first_trace)
        # Index 0 is already filled, so the loop resumes at 1.
        start_idx = 1
    else:
        state_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
        # We need the shapes and dtypes of the outputs of `wrapper` function to
        # create the `TensorArray`s, we can get it by pre-compiling the wrapper
        # function.
        wrapper = tf.function(autograph=False)(wrapper)
        concrete_wrapper = wrapper.get_concrete_function(state_spec)
        # Only the trace half of `(state, trace_element)` is needed here.
        _, trace_dtypes = concrete_wrapper.output_dtypes
        _, trace_shapes = concrete_wrapper.output_shapes
        trace_arrays = tf.nest.map_structure(
            lambda dtype, shape: tf.TensorArray(  # pylint: disable=g-long-lambda
                dtype,
                size=num_steps,
                element_shape=shape),
            trace_dtypes,
            trace_shapes)
        # From here on `wrapper` calls the concrete function, which is fed the
        # flattened state.
        wrapper = lambda state: concrete_wrapper(*tf.nest.flatten(state))
        # No step has been executed yet, so the loop starts from 0.
        start_idx = 0

    def body(i, state, trace_arrays):
        # Take one step and write its trace element at position `i`.
        state, trace_element = wrapper(state)
        trace_arrays = tf.nest.map_structure(lambda a, v: a.write(i, v),
                                             trace_arrays, trace_element)
        return i + 1, state, trace_arrays

    def cond(i, *_):
        # Keep iterating until `num_steps` entries have been written.
        return i < num_steps

    _, state, trace_arrays = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(start_idx, state, trace_arrays),
        parallel_iterations=parallel_iterations)

    # Turn each `TensorArray` into a single `Tensor` with a new leading axis.
    stacked_trace = tf.nest.map_structure(lambda x: x.stack(), trace_arrays)

    # Whatever is statically known about `num_steps` (possibly nothing, i.e.
    # `None`) is merged into the static shape of the stacked traces below.
    static_length = tf.get_static_value(num_steps)

    def _merge_static_length(x):
        # Replace the leading dimension of the static shape, keeping the rest.
        x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
        return x

    stacked_trace = tf.nest.map_structure(_merge_static_length, stacked_trace)

    return state, stacked_trace
Esempio n. 2
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 max_tree_depth=10,
                 max_energy_diff=1000.,
                 unrolled_leapfrog_steps=1,
                 seed=None,
                 name=None):
        """Sets up the kernel and precomputes its U-turn instruction tables.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s representing the
            step size for the leapfrog integrator. Must broadcast with the
            shape of `current_state`. Larger step sizes lead to faster
            progress, but too-large step sizes make rejection exponentially
            more likely. When possible, it's often helpful to match
            per-variable step sizes to the standard deviations of the target
            distribution in each variable.
          max_tree_depth: Maximum depth of the tree implicitly built by NUTS.
            The maximum number of leapfrog steps is bounded by
            `2**max_tree_depth`, i.e. the number of nodes in a binary tree
            `max_tree_depth` nodes deep. The default setting of 10 takes up to
            1024 leapfrog steps. Must be statically known and at least 1.
          max_energy_diff: Scalar threshold of energy differences at each
            leapfrog; divergent samples are defined as leapfrog steps that
            exceed this threshold. Defaults to 1000.
          unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
            expansion step. Applies a direct linear multiplier to the maximum
            trajectory length implied by `max_tree_depth`. Defaults to 1.
          seed: Python integer to seed the random number generator.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'NoUTurnSampler').

        Raises:
          ValueError: if `max_tree_depth` cannot be resolved statically or is
            smaller than 1.
        """
        with tf.name_scope(name or 'NoUTurnSampler') as name:
            # The depth is consumed by Python-side table construction below, so
            # it has to be available without running any graph.
            static_depth = tf.get_static_value(max_tree_depth)
            if static_depth is None or static_depth < 1:
                raise ValueError(
                    'max_tree_depth must be known statically and >= 1 but was '
                    '{}'.format(static_depth))

            # Instruction tables derived from the tree depth.
            uturn_instructions = build_tree_uturn_instruction(
                static_depth, init_memory=-1)
            write_table, read_table = (
                generate_efficient_write_read_instruction(uturn_instructions))

            # Normalise `step_size` to a list of tensors.
            step_sizes = (step_size
                          if tf.nest.is_nested(step_size) else [step_size])
            step_sizes = [
                tf.convert_to_tensor(part, dtype_hint=tf.float32)
                for part in step_sizes
            ]

            self._max_tree_depth = static_depth
            # The tables are kept as numpy values; the TensorArray versions
            # need to be created within the function call to be compatible
            # with XLA, so the conversion happens later.
            self._write_instruction = write_table
            self._read_instruction = read_table

            self._target_log_prob_fn = target_log_prob_fn
            self._step_size = step_sizes
            self._parameters = {
                'target_log_prob_fn': target_log_prob_fn,
                'step_size': step_sizes,
                'max_tree_depth': static_depth,
                'max_energy_diff': max_energy_diff,
                'unrolled_leapfrog_steps': unrolled_leapfrog_steps,
                'seed': seed,
                'name': name,
            }
            self._seed_stream = SeedStream(seed, salt='nuts_one_step')
            self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
            self._name = name
            self._max_energy_diff = max_energy_diff
Esempio n. 3
0
def canonicalize_observed_time_series_with_mask(
        maybe_masked_observed_time_series):
    """Converts a series-like input into a canonical `MaskedTimeSeries`.

    Args:
      maybe_masked_observed_time_series: a `Tensor`-like object with shape
        `[..., num_timesteps]` or `[..., num_timesteps, 1]`, or a
        `tfp.sts.MaskedTimeSeries` containing such an object, or a Pandas
        Series or DataFrame instance with set frequency
        (i.e., `.index.freq is not None`).

    Returns:
      masked_time_series: a `tfp.sts.MaskedTimeSeries` namedtuple, in which
        the `observed_time_series` is converted to `Tensor` with canonical
        shape `[..., num_timesteps, 1]`, and `is_missing` is either `None` or
        a boolean `Tensor`.

    Raises:
      ValueError: if the input looks like a Pandas object whose index exposes
        a `freq` attribute that is `None`.
    """
    series_input = maybe_masked_observed_time_series

    with tf.name_scope('canonicalize_observed_time_series_with_mask'):
        # Anything exposing `is_missing` is treated as a MaskedTimeSeries.
        has_explicit_mask = hasattr(series_input, 'is_missing')
        # Pandas Series / DataFrame are recognised by duck typing.
        looks_like_pandas = (not has_explicit_mask
                             and hasattr(series_input, 'index')
                             and hasattr(series_input, 'to_numpy'))

        mask = None
        if has_explicit_mask:
            values = series_input.time_series
            mask = series_input.is_missing
        elif looks_like_pandas:
            pandas_index = series_input.index
            if hasattr(pandas_index, 'freq') and pandas_index.freq is None:
                raise ValueError(
                    'Pandas DataFrame or Series has a DatetimeIndex with '
                    'no set frequency, but STS requires regularly spaced '
                    'observations. Consider using '
                    '`tfp.sts.regularize_series` to infer a frequency and '
                    'build a regularly spaced series (by marking '
                    'unobserved steps as missing observations).')
            # A multi-column DataFrame is a batch of series: transpose so the
            # layout is `[batch_size, num_steps]` instead of the reverse.
            as_array = series_input.to_numpy()
            values = np.squeeze(np.transpose(as_array))
        else:
            values = series_input

        values = tf.convert_to_tensor(value=values,
                                      name='observed_time_series')
        values = _maybe_expand_trailing_dim(values)

        # Without an explicit mask, `NaN` entries mark missing observations.
        if not has_explicit_mask:
            mask = tf.math.is_nan(values[..., 0])

        # Drop the mask entirely when it is statically known to be all-False.
        static_mask = tf.get_static_value(mask)
        nothing_missing = (static_mask is not None
                           and not np.any(static_mask))
        if nothing_missing:
            mask = None
        elif mask is not None:
            mask = tf.convert_to_tensor(value=mask,
                                        name='is_missing',
                                        dtype_hint=tf.bool)

        return missing_values_util.MaskedTimeSeries(values, is_missing=mask)
Esempio n. 4
0
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Creates a `Transpose` bijector from `perm` or a dimension count.

        Args:
          perm: `int32` vector-shaped `Tensor` representing permutation of
            rightmost dims (for forward transformation).  Note that the `0`th
            index represents the first of the rightmost dims and the largest
            value must be `rightmost_transposed_ndims - 1` and corresponds to
            `tf.rank(x) - 1`.
            Only one of `perm` and `rightmost_transposed_ndims` can (and must)
            be specified.
            Default value:
            `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`.
          rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
            representing the number of rightmost dimensions to permute.
            Only one of `perm` and `rightmost_transposed_ndims` can (and must)
            be specified.
            Default value: `tf.size(perm)`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.

        Raises:
          ValueError: if both or neither `perm` and
            `rightmost_transposed_ndims` are specified.
          NotImplementedError: if `rightmost_transposed_ndims` is not known
            prior to graph execution.
        """
        with tf.name_scope(name):
            ndims_given = rightmost_transposed_ndims is not None
            perm_given = perm is not None
            if ndims_given == perm_given:
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')

            if ndims_given:
                # Derive `perm` as the full reversal of the last `ndims` axes.
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    value=rightmost_transposed_ndims,
                    dtype=np.int32,
                    name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                last_axis = distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1
                perm = tf.range(start=last_axis,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:
                # Derive the dimension count from the size of `perm`.
                perm = tf.convert_to_tensor(value=perm,
                                            dtype=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    input=perm, name='rightmost_transposed_ndims')
                static_ndims = tf.get_static_value(rightmost_transposed_ndims)
                checks = _maybe_validate_perm(perm, validate_args)
                if checks:
                    with tf.control_dependencies(checks):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and
            # the static-value requirement below can be removed.
            if static_ndims is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            static_ndims = int(static_ndims)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=static_ndims,
                graph_parents=[perm, rightmost_transposed_ndims],
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
Esempio n. 5
0
    def bootstrap_results(self, init_state):
        """Builds the initial kernel results for `init_state`.

        Evaluates the target log-probability once to determine the batch shape
        and rank, checks that at least two chains are available, initialises
        the running statistics, and bootstraps the inner kernel returned by
        `self._make_kernel`.

        Args:
          init_state: Nest of `Tensor`-convertible values; each leaf is
            converted with `tf.convert_to_tensor`.

        Returns:
          kernel_results: A `SNAPERHamiltonianMonteCarloResults` holding the
            inner kernel's bootstrap results, the initial running statistics
            and their point counters, and a zero seed.

        Raises:
          ValueError: if the rank of the target log-probability is not
            statically known, or if the number of chains is statically known
            to be smaller than 2.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'snaper_hamiltonian_monte_carlo',
                                    'bootstrap_results')):
            init_state = tf.nest.map_structure(
                lambda x: tf.convert_to_tensor(x, name='init_state'),
                init_state)

            # It is unfortunate that we need to make this extra call to the TLP here.
            # The issue is that we need this value to even construct the PHMC, and
            # the kernel will call this one itself.
            tlp = self.target_log_prob_fn(*tf.nest.flatten(init_state))
            # The shape and rank of the log-prob are used as the batch shape
            # and batch rank throughout the rest of this method.
            batch_shape = ps.shape(tlp)
            batch_ndims = ps.rank(tlp)
            if tf.get_static_value(batch_ndims) is None:
                # The issue doesn't live in this file, rather it is the downstream
                # components that fail to work (notably, tfb.Reshape).
                raise ValueError(
                    'SNAPERHMC currently requires a statically known '
                    'rank of the target log probability.')

            # We need at least two chains to estimate the principal component.
            # Number of total chains is local batch size * distributed axis size
            reduce_chain_axis_names = distribute_lib.canonicalize_named_axis(
                self.experimental_reduce_chain_axis_names)
            # A scalar log-prob still counts as one local chain.
            local_axis_size = ps.maximum(ps.size(tlp), 1)
            distributed_axis_size = int(
                ps.reduce_prod([
                    distribute_lib.get_axis_size(a)
                    for a in reduce_chain_axis_names
                ]))
            num_chains = local_axis_size * distributed_axis_size
            num_chains_ = tf.get_static_value(num_chains)
            if num_chains_ is not None:
                # Static chain count: fail immediately in Python.
                if num_chains_ < 2:
                    raise ValueError(
                        'SNAPERHMC requires at least 2 chains. Got: {}'.format(
                            num_chains_))
            elif self.validate_args:
                # Dynamic chain count: attach a runtime assertion by routing
                # the state through identity ops that depend on it.
                with tf.control_dependencies([
                        assert_util.assert_greater_equal(
                            num_chains, 2,
                            'SNAPERHMC requires at least 2 chains.')
                ]):
                    init_state = tf.nest.map_structure(tf.identity, init_state)

            # For each state part, the non-batch axes expressed as negative
            # indices (counted from the end).
            event_axes = tf.nest.map_structure(
                lambda x: ps.range(batch_ndims, ps.rank(x)) - ps.rank(x),
                init_state)
            if self.experimental_shard_axis_names is None:
                # No sharding requested: use `None` for every state part.
                shard_axis_names = tf.nest.map_structure(
                    lambda _: None, init_state)
            else:
                shard_axis_names = self.experimental_shard_axis_names

            # The variance starts at ones and the mean at zeros; both have the
            # per-part shape with the leading batch dimensions removed.
            ema_variance = tf.nest.map_structure(
                lambda x: tf.ones(  # pylint: disable=g-long-lambda
                    ps.shape(x)[batch_ndims:],
                    dtype=x.dtype,
                    name='ema_variance'),
                init_state)
            ema_mean = tf.nest.map_structure(
                lambda x: tf.zeros_like(x, name='ema_mean'), ema_variance)
            # The initial principal component is the all-ones variance passed
            # through `_normalize`.
            ema_principal_component = _normalize(ema_variance, event_axes,
                                                 shard_axis_names)
            # These start out at 1 for a bit of smoothing.
            state_ema_points = tf.ones([], tf.int32)
            principal_component_ema_points = tf.ones([], tf.int32)

            # Build the inner kernel from the initial statistics at step 0.
            kernel = self._make_kernel(
                batch_shape=batch_shape,
                step=tf.zeros([], tf.int32),
                state_ema_points=state_ema_points,
                state=init_state,
                mean=ema_mean,
                variance=ema_variance,
                principal_component=ema_principal_component,
            )

            # The inner kernel is bootstrapped on the flattened state.
            inner_results = kernel.bootstrap_results(
                tf.nest.flatten(init_state))

            kernel_results = SNAPERHamiltonianMonteCarloResults(
                inner_results=inner_results,
                ema_mean=ema_mean,
                ema_variance=ema_variance,
                state_ema_points=state_ema_points,
                ema_principal_component=ema_principal_component,
                principal_component_ema_points=principal_component_ema_points,
                seed=samplers.zeros_seed(),
            )
            return kernel_results
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
    """Turns a (potentially nested) structure of dists into a single dist.

    Args:
      structure_of_distributions: instance of `tfd.Distribution`, or nested
        structure (tuple, list, dict, etc.) in which all leaves are
        `tfd.Distribution` instances.
      batch_ndims: Optional integer `Tensor` number of leftmost batch
        dimensions shared across all members of the input structure. If this
        is specified, the returned joint distribution will be an autobatched
        distribution with the given batch rank, and all other dimensions
        absorbed into the event.
      validate_args: Python `bool`. Whether the joint distribution should
        validate input with asserts. This imposes a runtime cost. If
        `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.

    Returns:
      distribution: instance of `tfd.Distribution` such that
        `distribution.sample()` is equivalent to
        `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
        If `structure_of_distributions` was indeed a structure (as opposed to
        a single `Distribution` instance), this will be a `JointDistribution`
        with the corresponding structure.

    Raises:
      TypeError: if any leaves of the input structure are not
        `tfd.Distribution` instances.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Imported locally so the fix does not rely on the
    # file's top-level import list.
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    # If input is already a Distribution, just return it (wrapping it in
    # `Independent` when it has more batch dimensions than requested).
    if dist_util.is_distribution_instance(structure_of_distributions):
        dist = structure_of_distributions
        if batch_ndims is not None:
            excess_ndims = ps.rank_from_shape(
                dist.batch_shape_tensor()) - batch_ndims
            # The static value may be `None` (unknown), in which case we wrap
            # unconditionally; only a static zero lets us skip the wrapper.
            if tf.get_static_value(excess_ndims) != 0:
                dist = independent.Independent(
                    dist, reinterpreted_batch_ndims=excess_ndims)
        return dist

    # If this structure contains other structures (ie, has elements at depth > 1),
    # recursively turn them into JDs.
    element_depths = nest.map_structure_with_tuple_paths(
        lambda path, x: len(path), structure_of_distributions)
    if max(tf.nest.flatten(element_depths)) > 1:
        # Traverse only the top level: children containing depth-1 leaves are
        # left as-is, deeper sub-structures are collapsed recursively.
        next_level_shallow_structure = nest.get_traverse_shallow_structure(
            traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
            structure=element_depths)
        structure_of_distributions = nest.map_structure_up_to(
            next_level_shallow_structure,
            functools.partial(independent_joint_distribution_from_structure,
                              batch_ndims=batch_ndims,
                              validate_args=validate_args),
            structure_of_distributions)

    jdnamed = joint_distribution_named.JointDistributionNamed
    jdsequential = joint_distribution_sequential.JointDistributionSequential
    # Use an autobatched JD if a specific batch rank was requested.
    if batch_ndims is not None:
        jdnamed = functools.partial(
            joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)
        jdsequential = functools.partial(
            joint_distribution_auto_batched.
            JointDistributionSequentialAutoBatched,
            batch_ndims=batch_ndims,
            use_vectorized_map=False)

    # Otherwise, build a JD from the current structure: namedtuples and
    # mappings become named JDs, everything else becomes a sequential JD.
    is_named_structure = (
        hasattr(structure_of_distributions, '_asdict')
        or isinstance(structure_of_distributions, collections_abc.Mapping))
    if is_named_structure:
        return jdnamed(structure_of_distributions, validate_args=validate_args)
    return jdsequential(structure_of_distributions,
                        validate_args=validate_args)
Esempio n. 7
0
  def __init__(self,
               distribution,
               reinterpreted_batch_ndims=None,
               validate_args=False,
               experimental_use_kahan_sum=False,
               name=None):
    """Builds an `Independent` wrapper around `distribution`.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
        which will be regarded as event dims. When `None` all but the first
        batch axis (batch axis 0) will be transferred to event dimensions
        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
        summation to aggregate independent underlying log_prob values, which
        improves against the precision of a naive float32 sum. This can be
        noticeable in particular for large dimensions in float32. See CPU caveat
        on `tfp.math.reduce_kahan_sum`.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments, before any other local name is bound.
    parameters = dict(locals())
    self._experimental_use_kahan_sum = experimental_use_kahan_sum
    with tf.name_scope(name or ('Independent' + distribution.name)) as name:
      self._distribution = distribution

      if reinterpreted_batch_ndims is not None:
        # Explicit value: keep it as a tensor plus its static int if known.
        ndims_tensor = tensor_util.convert_nonref_to_tensor(
            reinterpreted_batch_ndims,
            dtype_hint=tf.int32,
            name='reinterpreted_batch_ndims')
        static_ndims = tf.get_static_value(ndims_tensor)
        if static_ndims is not None:
          static_ndims = int(static_ndims)
      else:
        # No value given: statically infer it from the batch rank if possible.
        base_batch_rank = tensorshape_util.rank(distribution.batch_shape)
        if base_batch_rank is None:
          ndims_tensor = None
          static_ndims = None
        else:
          static_ndims = max(0, base_batch_rank - 1)
          ndims_tensor = tf.convert_to_tensor(
              static_ndims,
              dtype_hint=tf.int32,
              name='reinterpreted_batch_ndims')

      self._reinterpreted_batch_ndims = ndims_tensor
      self._static_reinterpreted_batch_ndims = static_ndims

      super(Independent, self).__init__(
          dtype=self._distribution.dtype,
          reparameterization_type=self._distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=self._distribution.allow_nan_stats,
          parameters=parameters,
          name=name)
Esempio n. 8
0
def make_convolution_transpose_fn_with_dilation(filter_shape,
                                                strides,
                                                padding,
                                                rank=2,
                                                dilations=None,
                                                dtype=tf.int32,
                                                validate_args=False,
                                                name=None):
    """Makes a 2-D transposed convolution that supports batches of kernels.

    This version tends to be fastest on GPU. It implements the transposed
    convolution as a regular convolution of an image that is dilated by
    interleaving rows and columns of zeros equal to the number of strides.

    Args:
      filter_shape: Filter size specification, canonicalized by
        `prepare_conv_args`; afterwards it unpacks into a
        `(filter_height, filter_width)` pair.
      strides: Stride specification, canonicalized by `prepare_conv_args`;
        afterwards it unpacks into a `(stride_height, stride_width)` pair.
      padding: Padding specification, canonicalized by `prepare_conv_args` and
        forwarded to `_get_transpose_conv_dilated_padding`.
        NOTE(review): the accepted values are defined by those helpers, which
        are not visible here.
      rank: Number of spatial dimensions. Only a value statically equal to `2`
        is supported.
      dilations: Dilation specification, canonicalized by `prepare_conv_args`.
        The `conv2d_transpose` fallback is only used when every dilation
        equals 1.
      dtype: Passed as `dtype` to `im2row_index` (presumably the dtype of the
        gather indices; confirm against that helper).
      validate_args: Python `bool` forwarded to `prepare_conv_args` and
        `_maybe_validate_input_shapes`.
      name: Optional name scope used while building the function.
        Default value: `None` (i.e.,
        'make_convolution_transpose_fn_with_dilation').

    Returns:
      convolution_transpose_fn: A callable that takes an input `Tensor` and kernel
        and applies the transpose convolution operation.

    Raises:
      NotImplementedError: if `rank` is not statically equal to `2`.
    """
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        [
            filter_shape,
            rank,
            strides,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        sh, sw = strides
        fh, fw = filter_shape

        # One padding entry per spatial dimension; each entry is summed below,
        # so it is expected to hold the padding amounts for both sides.
        pad_values = [
            _get_transpose_conv_dilated_padding(k,
                                                stride=s,
                                                dilation=d,
                                                padding=padding)
            for (k, s, d) in zip(filter_shape, strides, dilations)
        ]

        def op(x, kernel):
            # Applies the transposed convolution to `x`, whose trailing three
            # dimensions are read as `(height, width, channels_in)`.
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            # Split the shape into leading batch dims and the last three dims.
            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                # Gather indices for the zero-interleaved and padded input,
                # plus the shape the gathered values are reshaped to.
                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                # Pad only the two spatial dims: `n` leading rows of zeros for
                # the batch dims and one trailing row for the channel dim.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with rows and columns of
                # zeros equal to the number of strides.
                # First along the width: `sw - 1` zeros are placed before each
                # element.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                # Then along the height: `sh - 1` zero rows are placed before
                # each row.
                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                # Flatten everything after the batch dims into one axis so the
                # gather indices can address it directly.
                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                # The convolution itself is a matmul; the new axis inserted
                # into `kernel` lets it broadcast against `im_x`.
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])

        return op
Esempio n. 9
0
        def op(x, kernel):
            """Applies a transposed 2-D convolution of `kernel` to `x`.

            The trailing three dimensions of `x` are read as
            `(height, width, channels_in)`; all leading dimensions are batch
            dimensions. `fh`, `fw`, `sh`, `sw`, `pad_values`, `filter_shape`,
            `strides`, `padding`, `dilations`, `dtype` and `validate_args` are
            captured from the enclosing factory.

            Args:
              x: Input `Tensor` (or convertible) with at least three dimensions.
              kernel: Kernel `Tensor` (or convertible). A rank-2 kernel with unit
                dilations is delegated to `_call_conv2d_transpose`.

            Returns:
              The result of multiplying the im2row view of the zero-dilated,
              padded input by `kernel`.
            """
            common_dtype = dtype_util.common_dtype([x, kernel],
                                                   dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=common_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=common_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            def with_batch(trailing_dims):
                # Prefix `trailing_dims` with the batch dimensions of `x`.
                return ps.concat([batch_shape, trailing_dims], axis=0)

            with tf.control_dependencies(assertions):
                # A kernel without batch shape can use `conv2d_transpose`
                # directly, except when dilations > 1 (not implemented there).
                kernel_is_unbatched = (
                    tf.get_static_value(ps.rank(kernel)) == 2)
                if kernel_is_unbatched and all(d == 1 for d in dilations):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                padded_height = xh * sh + sum(pad_values[0])
                padded_width = xw * sw + sum(pad_values[1])
                gather_idx, patches_shape = im2row_index(
                    (padded_height, padded_width, c_in),
                    block_shape=filter_shape,
                    slice_step=(1, 1),
                    dilations=dilations,
                    dtype=dtype,
                    transpose=True)

                # Pad only the two spatial axes: no padding on the leading batch
                # axes nor on the trailing channel axis.
                num_batch_dims = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[num_batch_dims, 1], [0, 0]],
                                  constant_values=0)

                # Step 1: put `sw - 1` zero columns in front of every input
                # column, stretching the width to `xw * sw`.
                zero_cols = tf.zeros(with_batch((xh * xw, sw - 1, c_in)),
                                     dtype=common_dtype)
                pixel_cols = tf.reshape(x,
                                        shape=with_batch((xh * xw, 1, c_in)))
                width_dilated = tf.reshape(
                    tf.concat([zero_cols, pixel_cols], axis=-2),
                    shape=with_batch((xh, 1, xw * sw, c_in)))

                # Step 2: put `sh - 1` zero rows in front of every input row,
                # stretching the height to `xh * sh`.
                zero_rows = tf.zeros(with_batch((xh, sh - 1, xw * sw, c_in)),
                                     dtype=common_dtype)
                dilated = tf.reshape(
                    tf.concat([zero_rows, width_dilated], axis=-3),
                    shape=with_batch((xh * sh, xw * sw, c_in)))

                x_pad = tf.pad(dilated, paddings=paddings, constant_values=0)

                # Collapse the event dimensions so the im2row indices can be
                # applied with a single gather on the last axis.
                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)
                patches = tf.reshape(
                    tf.gather(flat_x, indices=gather_idx, axis=-1),
                    shape=with_batch(patches_shape))
                return tf.matmul(patches, kernel[..., tf.newaxis, :, :])
Esempio n. 10
0
def make_convolution_fn(filter_shape,
                        rank,
                        strides,
                        padding,
                        dilations=None,
                        dtype=tf.int32,
                        validate_args=False,
                        name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.

    Args:
      filter_shape: Filter size specification; after `prepare_conv_args` it is
        unpacked as `(filter_height, filter_width)`.
      rank: Number of spatial dimensions. Only `2` is supported.
      strides: Convolution strides, normalized by `prepare_conv_args`.
      padding: Padding specification, normalized by `prepare_conv_args`; the
        value `'SAME'` triggers explicit zero padding of the input.
      dilations: Optional dilation rates, normalized by `prepare_conv_args`.
      dtype: Dtype forwarded to `im2row_index` for the gather indices.
      validate_args: Forwarded to the argument and shape validation helpers.
      name: Optional name scope used while preparing the arguments.
        Default: `'conv2d'`.

    Returns:
      op: A callable `op(x, kernel)`. The last three dimensions of `x` are read
        as `(height, width, channels_in)` and any leading dimensions as batch
        dimensions. A rank-2 `kernel` is applied with `tf.nn.conv2d`; otherwise
        the input is expanded with `im2row_index` and multiplied by `kernel`.

    Raises:
      NotImplementedError: If `rank` is not statically equal to `2`.
    """
    with tf.name_scope(name or 'conv2d'):
        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        [
            filter_shape,
            rank,
            strides,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              validate_args=validate_args)

    def op(x, kernel):
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                # Kernel has no batch shape: use the native op. Collapse all
                # batch dimensions of `x` into one so that the input is exactly
                # rank 4 (this also covers an input with no batch dimensions).
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                # Fix: convolve the flattened input; previously `flat_x` was
                # computed but the unflattened `x` was passed instead.
                flat_y = tf.nn.conv2d(flat_x,
                                      filters=tf.reshape(
                                          kernel, shape=[fh, fw, c_in, -1]),
                                      strides=strides,
                                      padding=padding,
                                      data_format='NHWC',
                                      dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                # Restore the original batch dimensions.
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

            # Batched kernel: compute per-spatial-axis padding amounts.
            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                # Pad only the two spatial axes: no padding on the leading
                # batch axes nor on the trailing channel axis.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            # Collapse the event dimensions so that the im2row indices can be
            # applied with one gather along the last axis.
            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])

    return op
Esempio n. 11
0
    def op(x, kernel):
        """Convolves `x` with `kernel`, supporting batches of kernels.

        The last three dimensions of `x` are read as
        `(height, width, channels_in)`; any leading dimensions are batch
        dimensions. `filter_shape`, `strides`, `padding`, `dilations`, `dtype`
        and `validate_args` are captured from the enclosing scope.

        Args:
          x: Input `Tensor` (or convertible) with at least three dimensions.
          kernel: Kernel `Tensor` (or convertible). A rank-2 kernel is applied
            with `tf.nn.conv2d`; otherwise the input is expanded with
            `im2row_index` and multiplied by `kernel`.

        Returns:
          The convolution output, with the batch dimensions of `x` restored in
          the rank-2 kernel case.
        """
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                # Kernel has no batch shape: use the native op. Collapse all
                # batch dimensions of `x` into one so that the input is exactly
                # rank 4 (this also covers an input with no batch dimensions).
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                # Fix: convolve the flattened input; previously `flat_x` was
                # computed but the unflattened `x` was passed instead.
                flat_y = tf.nn.conv2d(flat_x,
                                      filters=tf.reshape(
                                          kernel, shape=[fh, fw, c_in, -1]),
                                      strides=strides,
                                      padding=padding,
                                      data_format='NHWC',
                                      dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                # Restore the original batch dimensions.
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

            # Batched kernel: compute per-spatial-axis padding amounts.
            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                # Pad only the two spatial axes: no padding on the leading
                # batch axes nor on the trailing channel axis.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            # Collapse the event dimensions so that the im2row indices can be
            # applied with one gather along the last axis.
            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
Esempio n. 12
0
 def _batch_shape(self):
     """Returns the batch shape(s) with the particle dimension prepended.

     The static value of `self.num_particles` is concatenated in front of
     every entry of the (possibly nested) `self.distribution.batch_shape`.
     """

     def _prepend_particles(inner_batch_shape):
         particle_dim = [tf.get_static_value(self.num_particles)]
         return tensorshape_util.concatenate(particle_dim, inner_batch_shape)

     return tf.nest.map_structure(_prepend_particles,
                                  self.distribution.batch_shape)
Esempio n. 13
0
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Estimates the auto correlation of `x` along a single axis.

  For a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` is (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  Here `x` is treated (along `axis`) as a finite sub-sequence of one
  realization of `X`. With `x` of length `L` extended to infinite length by
  zero padding, the estimate `rxx[m]` for `m = 0, 1, ..., max_lags` is

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error of the estimate is proportional to `1 / sqrt(len(x) - m)`, which
  is why `max_lags` is usually chosen small enough for the whole output to be
  meaningful.

  Because `mu` only approximates `E{ X[0] }` and the sum is divided by
  `len(x) - m` instead of `len(x) - m - 1`, the estimate carries a slight bias
  that vanishes as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. Axis along which the correlation is computed. All
      other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  Largest `m` to consider (in the equation
      above).  If `max_lags >= x.shape[axis]`, it is effectively re-set to
      `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, the mean estimate `mu` is not
      subtracted from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, there is no division by the
      variance estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # How it works:
    # A length N / 2 1-D array x is zero padded at the end to length N. With
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }
    # one has
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}),
    # and R_m / (N / 2 - m) is an unbiased estimate of RXX[m].
    #
    # F[x] is the DFT of x, so RXX can be estimated by zero padding followed by
    # an FFT and an IFFT (a special case of the Wiener-Khinchin Theorem).
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # The FFT op works on the rightmost dimension, so rotate `axis` there.
        ndims = prefer_static.rank(x)
        axis = ndims + axis if axis < 0 else axis
        shift = ndims - 1 - axis
        # With x.shape[axis] = T 'time' steps, series.shape = B + [T], where B
        # is the batch shape of the rotated tensor.
        series = distribution_util.rotate_transpose(x, shift)

        if center:
            series = series - tf.reduce_mean(series, axis=-1, keepdims=True)

        # Length of x along `axis` (N / 2 in the explanation above), static
        # whenever possible.
        num_steps = prefer_static.shape(series)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment is necessary so that all FFT implementations work.
        # Pad up to the next power of 2 greater than 2 * num_steps, i.e.
        # 2**(ceil(Log_2(2 * num_steps))), using Log_2(X) = Log_e(X) / Log_e(2).
        num_steps_f64 = tf.cast(num_steps, np.float64)
        log2_doubled = tf.math.log(num_steps_f64 * 2) / np.log(2.)
        padded_length = tf.pow(np.float64(2.), tf.math.ceil(log2_doubled))
        pad_count = tf.cast(padded_length - num_steps_f64, np.int32)

        # series_padded.shape = series.shape[:-1] + [T + pad_count]
        #                     = B + [T + pad_count]
        series_padded = distribution_util.pad(series,
                                              axis=-1,
                                              back=True,
                                              count=pad_count)

        dtype = x.dtype
        is_real = not dtype_util.is_complex(dtype)
        if is_real and not dtype_util.is_floating(dtype):
            raise TypeError(
                'Argument x must have either float or complex dtype'
                ' found: {}'.format(dtype))
        if is_real:
            # The FFT needs complex input: attach a zero imaginary part.
            zero_imag = dtype_util.as_numpy_dtype(
                dtype_util.real_dtype(dtype))(0.)
            series_padded = tf.complex(series_padded, zero_imag)

        # The autocorrelation is, up to scaling, the IFFT of the power
        # spectral density.
        spectrum = tf.signal.fft(series_padded)
        power_spectrum = spectrum * tf.math.conj(spectrum)
        # R[m] from the explanation above: sum_n X[n] * Conj(X[n - m]),
        # cast back to the input dtype (real again if x was real).
        lag_products = tf.cast(tf.signal.ifft(power_spectrum), dtype)

        # Work out max_lags and whether the output's static shape is
        # deducible. `series` is the reference for static shapes: its time
        # dimension is rightmost and it predates the padding.
        static_shape_known = tensorshape_util.is_fully_defined(series.shape)
        if max_lags is None:
            max_lags = num_steps - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            static_max_lags = tf.get_static_value(max_lags)
            if static_shape_known and static_max_lags is not None:
                max_lags = min(num_steps - 1, static_max_lags)
            else:
                static_shape_known = False
                max_lags = tf.minimum(num_steps - 1, max_lags)

        # Drop the padding. A user-supplied huge max_lags was clipped above.
        # chopped.shape = series.shape[:-1] + [max_lags + 1]
        chopped = lag_products[..., :max_lags + 1]

        if static_shape_known:
            static_dims = tensorshape_util.as_list(series.shape)
            static_dims[-1] = min(num_steps, max_lags + 1)
            chopped.set_shape(static_dims)

        # R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]); the
        # remaining terms are zeros that exist only because of the padding.
        # Dividing by (N / 2 - m) therefore yields an unbiased estimate of
        # E[X[n] Conj(X[n - m])].
        real_dtype = dtype_util.real_dtype(dtype)
        num_steps_real = tf.cast(num_steps, real_dtype)
        max_lags_real = tf.cast(max_lags, real_dtype)
        overlap_counts = tf.cast(
            num_steps_real - tf.range(0., max_lags_real + 1.), dtype)
        estimate = chopped / overlap_counts

        if normalize:
            estimate = estimate / estimate[..., :1]

        # Undo the initial rotation so dimensions line up with those of x.
        return distribution_util.rotate_transpose(estimate, -shift)
    def testCopy(self):
        """Checks that `copy` overrides the given arguments and keeps the rest."""

        def random_points(num_points):
            # `num_points` random index points in R^2.
            return np.random.uniform(-4., 4.,
                                     (num_points, 2)).astype(np.float32)

        def random_observations(num_observations):
            return np.random.uniform(-1., 1.,
                                     num_observations).astype(np.float32)

        index_points_1 = random_points(5)
        index_points_2 = random_points(10)
        observation_index_points_1 = random_points(7)
        observation_index_points_2 = random_points(9)
        observations_1 = random_observations(7)
        observations_2 = random_observations(9)

        if not self.is_static:
            # Hide all static shape information behind placeholders.
            (
                index_points_1,
                index_points_2,
                observation_index_points_1,
                observation_index_points_2,
                observations_1,
                observations_2,
            ) = [
                tf1.placeholder_with_default(value, shape=None)
                for value in (
                    index_points_1,
                    index_points_2,
                    observation_index_points_1,
                    observation_index_points_2,
                    observations_1,
                    observations_2,
                )
            ]

        def mean_fn(x):
            return np.array([0.], np.float32)

        kernel_1 = psd_kernels.ExponentiatedQuadratic()
        kernel_2 = psd_kernels.ExpSinSquared()

        # Both the regular and the precomputed model start from the same
        # arguments.
        base_kwargs = dict(
            kernel=kernel_1,
            index_points=index_points_1,
            observation_index_points=observation_index_points_1,
            observations=observations_1,
            mean_fn=mean_fn,
            jitter=1e-5,
            validate_args=True)

        gprm1 = tfd.GaussianProcessRegressionModel(**base_kwargs)
        gprm2 = gprm1.copy(
            kernel=kernel_2,
            index_points=index_points_2,
            observation_index_points=observation_index_points_2,
            observations=observations_2)

        precompute = (
            tfd.GaussianProcessRegressionModel.precompute_regression_model)
        precomputed_gprm1 = precompute(**base_kwargs)
        precomputed_gprm2 = precomputed_gprm1.copy(index_points=index_points_2)
        # Arguments that were not overridden must be the very same objects.
        self.assertIs(precomputed_gprm1.mean_fn, precomputed_gprm2.mean_fn)
        self.assertIs(precomputed_gprm1.kernel, precomputed_gprm2.kernel)

        event_shape_1 = [5]
        event_shape_2 = [10]

        self.assertIsInstance(gprm1.kernel.base_kernel,
                              psd_kernels.ExponentiatedQuadratic)
        self.assertIsInstance(gprm2.kernel.base_kernel,
                              psd_kernels.ExpSinSquared)

        if self.is_static or tf.executing_eagerly():
            # Static shapes and values can be compared directly.
            self.assertAllEqual(gprm1.batch_shape, gprm2.batch_shape)
            self.assertAllEqual(gprm1.event_shape, event_shape_1)
            self.assertAllEqual(gprm2.event_shape, event_shape_2)
            self.assertAllEqual(gprm1.index_points, index_points_1)
            self.assertAllEqual(gprm2.index_points, index_points_2)
            self.assertAllEqual(tf.get_static_value(gprm1.jitter),
                                tf.get_static_value(gprm2.jitter))
            return

        # Dynamic shapes in graph mode: everything has to be evaluated first.
        batch_shape_1 = self.evaluate(gprm1.batch_shape_tensor())
        batch_shape_2 = self.evaluate(gprm2.batch_shape_tensor())
        self.assertAllEqual(batch_shape_1, batch_shape_2)
        self.assertAllEqual(self.evaluate(gprm1.event_shape_tensor()),
                            event_shape_1)
        self.assertAllEqual(self.evaluate(gprm2.event_shape_tensor()),
                            event_shape_2)
        jitter_1 = self.evaluate(gprm1.jitter)
        jitter_2 = self.evaluate(gprm2.jitter)
        self.assertEqual(jitter_1, jitter_2)
        self.assertAllEqual(self.evaluate(gprm1.index_points),
                            index_points_1)
        self.assertAllEqual(self.evaluate(gprm2.index_points),
                            index_points_2)
Esempio n. 15
0
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
    """Returns zeros with the shape of `input`, statically when possible.

    Args:
      input: Object whose shape (via `_shape`) and `dtype` attribute are used.
      dtype: Optional dtype of the result. Defaults to `input.dtype`.
      name: Optional name for the created op (used only on the `tf.zeros`
        path).

    Returns:
      A NumPy array of zeros when the shape of `input` is statically known,
      otherwise the result of `tf.zeros`.
    """
    shape = _shape(input)
    static_shape = tf.get_static_value(shape)
    if static_shape is not None:
        # Build the array from the resolved static value rather than from
        # `shape`, which may still be a `Tensor`.
        return np.zeros(static_shape, _numpy_dtype(dtype or input.dtype))
    # Fix: default to the dtype of `input`; the previous fallback used the
    # dtype of the shape vector.
    return tf.zeros(shape, dtype or input.dtype, name)
Esempio n. 16
0
def make_convolution_transpose_fn_with_subkernels_matrix(
        filter_shape,
        strides,
        padding,
        rank=2,
        dilations=None,
        dtype=tf.int32,
        validate_args=False,
        name=None):
    """Builds a transposed 2-D convolution that supports batches of kernels.

    The returned callable rearranges the kernel into `strides**2` sub-kernels
    stored side by side in a single matrix, multiplies an im2row view of the
    zero-padded input by that matrix with one `tf.matmul`, and (when
    `strides > 1`) interleaves the per-sub-kernel outputs with
    `tf.nn.depth_to_space`. A rank-2 kernel with unit dilations is delegated
    to `_call_conv2d_transpose` instead.

    Args:
      filter_shape: Filter size specification; after `prepare_conv_args` it is
        unpacked as `(filter_height, filter_width)`.
      strides: Stride shared by both spatial dimensions. Must resolve to a
        Python `int` through `tf.get_static_value`.
      padding: Padding specification, normalized by `prepare_conv_args`; the
        output size is computed for the values `'VALID'` and `'SAME'`.
      rank: Number of spatial dimensions. Only `2` is supported.
      dilations: Optional dilation rates, normalized by `prepare_conv_args` and
        unpacked as `(dilation_height, dilation_width)`.
      dtype: Dtype of the index tensors built here.
      validate_args: Forwarded to the argument and shape validation helpers.
      name: Optional name scope. Default:
        `'make_convolution_transpose_fn_with_dilation'`.

    Returns:
      op: A callable `op(x, kernel)`. The last three dimensions of `x` are read
        as `(height, width, channels_in)` and any leading dimensions as batch
        dimensions.

    Raises:
      NotImplementedError: If `rank` is not statically equal to `2`.
      ValueError: If `strides` is not a statically known Python `int`.
    """
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))

        strides = tf.get_static_value(strides)
        if not isinstance(strides, int):
            # NOTE(review): the two message fragments are joined without a
            # separating space ("...integer.Saw: ...").
            raise ValueError(
                'Argument `strides` must be a statically known integer.'
                'Saw: {}'.format(strides))

        # The normalized strides are discarded (`_`): the scalar `strides` is
        # used throughout the rest of this function.
        [
            filter_shape,
            rank,
            _,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        fh, fw = filter_shape
        dh, dw = dilations

        # Determine maximum filter height and filter width of sub-kernels.
        sub_fh = (fh - 1) // strides + 1
        sub_fw = (fw - 1) // strides + 1

        def loop_body(i_, event_ind):
            # `i_` enumerates the `strides**2` sub-kernels; `(i, j)` is the
            # (row, column) offset of the current one.
            i = i_ // strides
            j = i_ % strides

            # Flattened filter positions of every `strides`-th row starting at
            # row `i`, and of every `strides`-th column starting at column `j`.
            i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype)
            j_ind = ps.range(j, fw, delta=strides, dtype=dtype)

            # Flattened filter indices belonging to this sub-kernel, in
            # reversed order.
            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            # Positions of those taps inside a `sub_fw`-wide sub-kernel grid.
            k = ps.reshape(cartesian_add([
                ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                ps.range(ps.shape(nc)[1], dtype=dtype)
            ]),
                           shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            # Pair each sub-kernel position with the index
            # `last_i * strides + last_j`, which addresses the `strides**2`
            # axis of the kernel matrix built in `op`.
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(event_ind, ind[...,
                                                                   tf.newaxis],
                                                    kernel_ind)

            return i_ + 1, event_ind

        # One `(sub-kernel position, sub-kernel index)` row per filter tap,
        # filled in by `loop_body` for each of the `strides**2` sub-kernels.
        event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
        _, event_ind = tf.while_loop(lambda i, _: i < strides**2, loop_body,
                                     [tf.zeros([], dtype=dtype), event_ind])

        tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
            fh, stride=strides, dilation=dh, padding=padding)
        tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
            fw, stride=strides, dilation=dw, padding=padding)

        # Express the total padding in units of input pixels, rounding up.
        pad_bottom = (tot_pad_bottom - 1) // strides + 1
        pad_top = (tot_pad_top - 1) // strides + 1
        pad_right = (tot_pad_right - 1) // strides + 1
        pad_left = (tot_pad_left - 1) // strides + 1
        padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

        # Rounding up may pad more than requested; the surplus rows/columns are
        # cut from the top/left of the output at the end of `op`.
        truncate_top = pad_top * strides - tot_pad_top
        truncate_left = pad_left * strides - tot_pad_left

        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            # Split the shape of `x` into batch dimensions and the trailing
            # three (height, width, channels) dimensions.
            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                # Pad only the two spatial axes: no padding on the `n` leading
                # batch axes nor on the trailing channel axis.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                # Collapse the event dimensions so the im2row indices can be
                # applied with a single gather along the last axis.
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                # NOTE(review): `dtype` is not forwarded to `im2row_index`
                # here, so its default index dtype applies -- confirm intended.
                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                # Interleave the per-channel index rows into a single leading
                # axis of length `c_in * fh * fw`.
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                # Repeat the batch indices for every event index row, then
                # append the event indices to form full scatter coordinates.
                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                # Scatter the kernel entries into a zero-initialized matrix
                # with one column group per sub-kernel.
                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                # Merge the sub-kernel and output-channel axes so that one
                # matmul evaluates every sub-kernel.
                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    # `tf.nn.depth_to_space` is applied to a rank-4 tensor, so
                    # collapse all batch dimensions into one first.
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                # NOTE(review): `out_height`/`out_width` are only assigned for
                # 'VALID' and 'SAME'; presumably `prepare_conv_args` guarantees
                # one of the two -- confirm.
                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                # Drop the surplus rows/columns introduced by the rounded-up
                # padding and keep exactly the expected output window.
                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                # Restore the (broadcast) batch dimensions.
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out

        return op
Esempio n. 17
0
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
  """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

  ```python
  # alignment is:                       v axis
  # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
  # indices.shape   ==         [i[0], i[1], i[2]]
  #                                     ^ indices_axis
  result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above.
      Number of dimensions must be known statically.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met, or if the rank of
      `params`, the rank of `indices`, `axis` or `indices_axis` cannot be
      determined statically.
  """
  with tf.name_scope(name):
    params = tf.convert_to_tensor(params, name='params')
    indices = tf.convert_to_tensor(indices, name='indices')

    params_ndims = tensorshape_util.rank(params.shape)
    indices_ndims = tensorshape_util.rank(indices.shape)
    # `axis` dtype must match ndims, which are 64-bit Python ints.
    axis = tf.get_static_value(ps.convert_to_shape_tensor(axis, dtype=tf.int64))
    indices_axis = tf.get_static_value(
        ps.convert_to_shape_tensor(indices_axis, dtype=tf.int64))

    if params_ndims is None:
      raise ValueError(
          'Rank of `params`, must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    # The number of batch dims handed to `tf.gather` below is derived from the
    # rank of `indices`, so it has to be a Python int as well. Without this
    # check an unknown rank would surface later as an opaque `TypeError` from
    # `None + int`.
    if indices_ndims is None:
      raise ValueError(
          'Rank of `indices` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if axis is None:
      raise ValueError(
          '`axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis is None:
      raise ValueError(
          '`indices_axis` must be known statically. This is due to '
          'tf.gather not accepting a `Tensor` for `batch_dims`.')

    if indices_axis > axis:
      raise ValueError(
          '`indices_axis` should be <= `axis`, but was {} > {}'.format(
              indices_axis, axis))

    if params_ndims < 1:
      raise ValueError(
          'Rank of params should be `> 0`, but was {}'.format(params_ndims))

    if indices_ndims < 1:
      raise ValueError(
          'Rank of indices should be `> 0`, but was {}'.format(indices_ndims))

    if indices_ndims - indices_axis > params_ndims - axis:
      raise ValueError(
          '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
          'indices_axis` ({} - {}), but was not.'.format(
              params_ndims, axis, indices_ndims, indices_axis))

    # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
    # transpose `indices_axis` to be the rightmost dimension of `indices`...
    transposed_indices = dist_util.move_dimension(indices,
                                                  source_idx=indices_axis,
                                                  dest_idx=-1)

    # ... and `axis` to be the corresponding (aligned as in the docstring)
    # dimension of `params`.
    broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
    transposed_params = dist_util.move_dimension(
        params,
        source_idx=axis,
        dest_idx=broadcast_indices_ndims - 1)

    # Next we broadcast `indices` so that its shape has the same prefix as
    # `params.shape`.
    transposed_params_shape = ps.shape(transposed_params)
    result_shape = ps.concat([
        transposed_params_shape[:broadcast_indices_ndims - 1],
        ps.shape(indices)[indices_axis:indices_axis + 1],
        transposed_params_shape[broadcast_indices_ndims:]], axis=0)
    broadcast_indices = ps.broadcast_to(
        transposed_indices,
        result_shape[:broadcast_indices_ndims])

    # Every dimension before the gathered one is treated as a batch dimension,
    # so each batch member picks its own entries along the gathered axis.
    result_t = tf.gather(transposed_params,
                         broadcast_indices,
                         batch_dims=broadcast_indices_ndims - 1,
                         axis=broadcast_indices_ndims - 1)
    # Undo the transposition so the gathered dimension sits at `axis` again.
    return dist_util.move_dimension(result_t,
                                    source_idx=broadcast_indices_ndims - 1,
                                    dest_idx=axis)
Esempio n. 18
0
        def op(x, kernel):
            """Applies a batched transposed convolution of `kernel` to `x`.

            Reads `fh`, `fw`, `sub_fh`, `sub_fw`, `strides`, `rank`,
            `dilations`, `padding`, `padding_vals`, `event_ind`, `dtype`,
            `filter_shape`, `truncate_top`, `truncate_left` and
            `validate_args` from the enclosing scope, which is not part of
            this excerpt.

            Args:
              x: Tensor-like input. Its last three dimensions are unpacked as
                `(xh, xw, c_in)`; all leading dimensions are treated as batch
                dimensions.
              kernel: Tensor-like kernel. Its last dimension is read as
                `c_out` and every dimension before the last two as the kernel
                batch shape.

            Returns:
              The result of `_call_conv2d_transpose` when `kernel` is
              statically rank 2 and every dilation equals 1. Otherwise a
              `Tensor` reshaped to the broadcast of the batch shapes of `x`
              and `kernel` followed by `[out_height, out_width, c_out]`.
            """
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            # Split the shape of `x` into leading batch dims and the trailing
            # three (height, width, input channel) dims.
            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                # Surround `padding_vals` with all-zero rows (`n` before, one
                # after) so that only the height and width dims of `x` are
                # padded; batch dims and the channel dim are left alone.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                # Collapse the three trailing dims into one so patches can be
                # gathered with flat indices.
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                # `im2row_index` is defined elsewhere. As used here it yields
                # flat gather indices (`idx`) and the trailing shape (`s`) the
                # gathered values are reshaped to.
                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                # NOTE(review): `event_ind` comes from the enclosing scope;
                # presumably it holds (row, column) index pairs into the
                # scattered kernel matrix -- confirm against its definition.
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                # Repeat the batch indices once per event index, then append
                # the event indices to form full `scatter_nd` coordinates.
                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                # Scatter the kernel values into a tensor of shape
                # `kernel_batch + [sub_fh * sub_fw * c_in, strides**2, c_out]`.
                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                # Merge the last two dims into one of size `strides**2 * c_out`.
                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                # Insert a singleton dim ahead of the two matrix dims before
                # the batched matmul with the gathered patches.
                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    # Collapse all batch dims into one, then let
                    # `depth_to_space` move the `strides**2` channel groups
                    # into `strides` x `strides` spatial blocks.
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                # NOTE(review): any `padding` other than 'VALID' or 'SAME'
                # leaves `out_height`/`out_width` unbound; presumably the
                # value is validated before this closure is built.
                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                # Crop to the expected output window and restore the
                # broadcast batch shape.
                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
Esempio n. 19
0
def concat_vectors(*args):
    """Joins the given vectors end to end, as a Python list when fully static.

    Args:
      *args: Vectors to concatenate along axis 0.

    Returns:
      A flat Python `list` of all elements when `tf.get_static_value` can
      resolve every argument; otherwise the result of
      `tf.concat(args, axis=0)`.
    """
    static_vectors = tuple(tf.get_static_value(arg) for arg in args)
    for static_vector in static_vectors:
        if static_vector is None:
            # At least one input has no static value, so defer to TensorFlow.
            return tf.concat(args, axis=0)

    flattened = []
    for static_vector in static_vectors:
        flattened.extend(static_vector)
    return flattened
Esempio n. 20
0
def make_convolution_transpose_fn_with_subkernels(filter_shape,
                                                  strides,
                                                  padding,
                                                  rank=2,
                                                  dilations=None,
                                                  dtype=tf.int32,
                                                  validate_args=False,
                                                  name=None):
    """Builds a transposed-convolution callable that accepts batched kernels.

    The filter is decomposed into `sh * sw` sub-kernels (one per offset within
    a stride block). The returned callable applies each sub-kernel to the
    padded input separately and interleaves the partial results with
    `tf.batch_to_space`.

    Args:
      filter_shape: Filter size; normalized by `prepare_conv_args` and then
        unpacked as `(fh, fw)`.
      strides: Stride specification; normalized by `prepare_conv_args` and
        then unpacked as `(sh, sw)`.
      padding: Padding mode. The returned callable handles `'VALID'` and
        `'SAME'`.
      rank: Number of spatial dimensions. Only `2` is supported.
      dilations: Dilation specification; normalized by `prepare_conv_args` and
        then unpacked as `(dh, dw)`.
      dtype: Integer dtype used for the sub-kernel index ranges and the
        `TensorArray` that stores them.
      validate_args: Forwarded to `prepare_conv_args` and
        `_maybe_validate_input_shapes`.
      name: Optional name for the enclosing `tf.name_scope`.

    Returns:
      op: A callable `op(x, kernel)` performing the transposed convolution.

    Raises:
      NotImplementedError: If the static value of `rank` is not `2`.
    """
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        [
            filter_shape,
            rank,
            strides,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        sh, sw = strides
        fh, fw = filter_shape
        dh, dw = dilations

        # Determine maximum filter height and filter width of sub-kernels.
        sub_fh = (fh - 1) // sh + 1
        sub_fw = (fw - 1) // sw + 1

        def loop_body(i_, kernels_ind):
            """Stores the filter indices of the `i_`-th sub-kernel."""
            # Decompose the flat counter into a (row, column) stride offset.
            i = i_ // sw
            j = i_ % sw
            # Every `sh`-th filter row (as flat row offsets) and every `sw`-th
            # filter column belonging to this offset.
            i_ind = ps.range((sh - i - 1) * fw,
                             fw * fh,
                             delta=sh * fw,
                             dtype=dtype)
            j_ind = ps.range((sw - j - 1), fw, delta=sw, dtype=dtype)

            # Slot in `kernels_ind` this sub-kernel is written to (mirrored
            # below via `sh * sw - pos - 1`).
            last_j = sw - (fw - j - 1) % sw - 1
            last_i = sh - (fh - i - 1) % sh - 1
            pos = last_i * sw + last_j

            # `cartesian_add` is defined elsewhere; presumably it forms all
            # pairwise sums of row offsets and column indices -- confirm.
            nc = cartesian_add([i_ind, j_ind])
            # The index grid is reversed along both axes before being stored.
            kernels_ind = kernels_ind.write(
                sh * sw - pos - 1, ps.reverse(ps.reverse(nc, [0]), [1]))

            return i_ + 1, kernels_ind

        # Sub-kernels may differ in shape, hence `infer_shape=False`.
        kernels_ind = tf.TensorArray(dtype=dtype,
                                     infer_shape=False,
                                     size=1,
                                     dynamic_size=True)

        _, kernels_ind = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                       [0, kernels_ind])

        tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
            fh, stride=sh, dilation=dh, padding=padding)
        tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
            fw, stride=sw, dilation=dw, padding=padding)

        # Total padding divided by the stride, rounded up: the padding applied
        # to the input before the sub-kernel convolutions.
        pad_bottom = (tot_pad_bottom - 1) // sh + 1
        pad_top = (tot_pad_top - 1) // sh + 1
        pad_right = (tot_pad_right - 1) // sw + 1
        pad_left = (tot_pad_left - 1) // sw + 1
        padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

        # Excess rows/columns introduced by rounding the padding up; they are
        # cropped from the top/left of the result in `op`.
        truncate_top = pad_top * sh - tot_pad_top
        truncate_left = pad_left * sw - tot_pad_left

        def op(x, kernel):
            """Applies the transposed convolution of `kernel` to `x`.

            Args:
              x: Tensor-like input. Its last three dimensions are unpacked as
                `(xh, xw, c_in)`; all leading dimensions are treated as batch
                dimensions.
              kernel: Tensor-like kernel. Its last dimension is read as
                `c_out` and every dimension before the last two as the kernel
                batch shape.

            Returns:
              The result of `_call_conv2d_transpose` when `kernel` is
              statically rank 2 and every dilation equals 1. Otherwise the
              interleaved sub-kernel outputs cropped to
              `[..., out_height, out_width, :]`.
            """
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            # Split the shape of `x` into leading batch dims and the trailing
            # three (height, width, input channel) dims.
            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                # Surround `padding_vals` with all-zero rows (`n` before, one
                # after) so that only the height and width dims of `x` are
                # padded; batch dims and the channel dim are left alone.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                # Padded spatial extent reduced by (largest sub-kernel size - 1).
                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    """Applies the `i`-th sub-kernel and records its output."""
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    # Input window size for this sub-kernel's actual shape.
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    # Expand every spatial filter index into `c_in`
                    # consecutive indices along the kernel's second-to-last
                    # axis, then flatten.
                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    # `im2row_index` is defined elsewhere. As used here it
                    # yields flat gather indices and the trailing shape the
                    # gathered values are reshaped to.
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    # Collapse the three trailing dims so patches can be
                    # gathered with flat indices.
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    # Multiply the patches by the sub-kernel reshaped into a
                    # `[fh_ * fw_ * c_in, c_out]` matrix per kernel batch
                    # member.
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype,
                                         infer_shape=False,
                                         size=1,
                                         dynamic_size=True)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                # Join the per-sub-kernel results along the leading dimension.
                y = outputs.concat()

                # Collapse all leading dims into one, then let
                # `batch_to_space` (with zero crops) move the sub-kernel
                # entries from that dim into the spatial dims.
                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                # NOTE(review): any `padding` other than 'VALID' or 'SAME'
                # leaves `out_height`/`out_width` unbound; presumably
                # `prepare_conv_args` rejects other values -- confirm.
                if padding == 'VALID':
                    out_height = fh + sh * (xh - 1)
                    out_width = fw + sw * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * sh
                    out_width = xw * sw

                # Crop away the excess introduced by rounding the padding up.
                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]

        return op
Esempio n. 21
0
def _replace_event_shape_in_tensorshape(input_tensorshape, event_shape_in,
                                        event_shape_out):
    """Swaps the trailing event dims of a `TensorShape` for a new event shape.

    Args:
      input_tensorshape: `TensorShape` whose rightmost dims are expected to
        match `event_shape_in`.
      event_shape_in: Shape `Tensor` describing the event dims currently at
        the right of `input_tensorshape`. Negative entries are treated as
        wildcards during validation.
      event_shape_out: Shape `Tensor` giving the replacement event shape.

    Returns:
      output_tensorshape: `input_tensorshape` with its rightmost event dims
        replaced by `event_shape_out`, or `TensorShape(None)` when the needed
        ranks are not statically known.
      is_validated: Python `bool`, `True` when the event dims were checked
        statically against `event_shape_in`.

    Raises:
      ValueError: If `input_tensorshape` has fewer dims than `event_shape_in`
        has entries, or if the statically known event dims disagree with the
        non-negative entries of `event_shape_in`.
    """
    # `event_shape_in` is a shape vector, so its element count is the number
    # of event dims.
    num_event_dims = tensorshape_util.num_elements(event_shape_in.shape)
    input_rank = tensorshape_util.rank(input_tensorshape)
    if input_rank is None or num_event_dims is None:
        # Nothing can be inferred or checked statically.
        return tf.TensorShape(None), False

    num_leading_dims = input_rank - num_event_dims
    if num_leading_dims < 0:
        raise ValueError(
            'Input has lower rank ({}) than `event_shape_ndims` ({}).'.format(
                input_rank, num_event_dims))

    leading_tensorshape = input_tensorshape[:num_leading_dims]
    trailing_tensorshape = input_tensorshape[num_leading_dims:]

    # Static validation needs both the actual event dims and the expected
    # ones; entries of `event_shape_in` below zero are skipped.
    expected_event_shape = tf.get_static_value(event_shape_in)
    is_validated = (tensorshape_util.is_fully_defined(trailing_tensorshape)
                    and expected_event_shape is not None)
    if is_validated:
        actual_event_shape = np.int32(trailing_tensorshape)
        explicit = expected_event_shape >= 0
        if not np.all(
                actual_event_shape[explicit] == expected_event_shape[explicit]):
            raise ValueError(
                'Input `event_shape` does not match `event_shape_in` '
                '({} vs {}).'.format(actual_event_shape, expected_event_shape))

    new_event_tensorshape = tensorshape_util.constant_value_as_shape(
        event_shape_out)
    if tensorshape_util.rank(new_event_tensorshape) is None:
        return tf.TensorShape(None), is_validated

    return (tensorshape_util.concatenate(leading_tensorshape,
                                         new_event_tensorshape),
            is_validated)
Esempio n. 22
0
        def op(x, kernel):
            """Applies a transposed convolution of `kernel` to `x` via sub-kernels.

            Reads `fh`, `fw`, `sh`, `sw`, `sub_fh`, `sub_fw`, `strides`,
            `dilations`, `padding`, `padding_vals`, `kernels_ind`,
            `filter_shape`, `truncate_top`, `truncate_left` and
            `validate_args` from the enclosing scope, which is not part of
            this excerpt.

            Args:
              x: Tensor-like input. Its last three dimensions are unpacked as
                `(xh, xw, c_in)`; all leading dimensions are treated as batch
                dimensions.
              kernel: Tensor-like kernel. Its last dimension is read as
                `c_out` and every dimension before the last two as the kernel
                batch shape.

            Returns:
              The result of `_call_conv2d_transpose` when `kernel` is
              statically rank 2 and every dilation equals 1. Otherwise the
              interleaved sub-kernel outputs cropped to
              `[..., out_height, out_width, :]`.
            """
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            # Split the shape of `x` into leading batch dims and the trailing
            # three (height, width, input channel) dims.
            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                # Surround `padding_vals` with all-zero rows (`n` before, one
                # after) so that only the height and width dims of `x` are
                # padded; batch dims and the channel dim are left alone.
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                # Padded spatial extent reduced by `sub_fh - 1` / `sub_fw - 1`
                # (presumably the largest sub-kernel size -- confirm against
                # the enclosing scope).
                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    """Applies the `i`-th sub-kernel and records its output."""
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    # Input window size for this sub-kernel's actual shape.
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    # Expand every spatial filter index into `c_in`
                    # consecutive indices along the kernel's second-to-last
                    # axis, then flatten.
                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    # `im2row_index` is defined elsewhere. As used here it
                    # yields flat gather indices and the trailing shape the
                    # gathered values are reshaped to.
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    # Collapse the three trailing dims so patches can be
                    # gathered with flat indices.
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    # Multiply the patches by the sub-kernel reshaped into a
                    # `[fh_ * fw_ * c_in, c_out]` matrix per kernel batch
                    # member.
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                # Sub-kernel outputs may differ in shape, hence
                # `infer_shape=False`.
                outputs = tf.TensorArray(dtype=input_dtype,
                                         infer_shape=False,
                                         size=1,
                                         dynamic_size=True)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                # Join the per-sub-kernel results along the leading dimension.
                y = outputs.concat()

                # Collapse all leading dims into one, then let
                # `batch_to_space` (with zero crops) move the sub-kernel
                # entries from that dim into the spatial dims.
                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                # NOTE(review): any `padding` other than 'VALID' or 'SAME'
                # leaves `out_height`/`out_width` unbound; presumably the
                # value is validated before this closure is built.
                if padding == 'VALID':
                    out_height = fh + sh * (xh - 1)
                    out_width = fw + sw * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * sh
                    out_width = xw * sw

                # Crop away the excess rows/columns at the top/left and keep
                # the expected output window.
                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
Esempio n. 23
0
  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output, of the bijectors forward pass (the image).  Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. This must divide equally into both H and W,
        otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block.
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `False`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images.  Your shape is'
                       '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known'
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times.'
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height should be divisible by 2 at least n times.')
    if h // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h,
                                       int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w,
                                       int(np.log(h) / np.log(2.))))

    # Other things we want to be true:
    # - The number of times we take must be equal to the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the number'
                       'of blocks ({1}).'.format(len(grab_after_block),
                                                 num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits
    if any([bs[0] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[0] for bs in self._blockwise_splits].index(True)
      raise ValueError('At at least one exit, you are taking out all of your '
                       'channels, and therefore have no inputs to later blocks.'
                       ' Try setting grab_after_block to a lower value at index'
                       '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
      # Special case: if specifically exiting no channels, then the exit is
      # just an identity bijector.
      pass
    elif any([bs[1] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[1] for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_at_block too small. '
                       'Try setting grab_after_block to a larger value at index'
                       '{}.'.format(first_offender))

    # Lets start to build our bijector. We assume that the distribution is 1
    # dimensional. First, lets reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):

      # This is the shape of the current tensor
      current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(self._blockwise_splits[i][0:2])
      this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   self._blockwise_splits[i],
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if self._blockwise_splits[i][2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [sum(self._blockwise_splits[i][0:2]),
                 self._blockwise_splits[i][2]]))

      # Finally, lets expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              h // 2**n * 2**i,
              w // 2**n * 2**i,
              c * 4**n // 4**i,
          ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
Esempio n. 24
0
def im2row_index(input_shape,
                 block_shape,
                 rank=2,
                 slice_step=(1, 1),
                 dilations=(1, 1),
                 dtype=tf.int32,
                 transpose=False,
                 validate_args=False,
                 name=None):
    """Computes indexes into a flattened image for building `im2row`.

    Args:
      input_shape: Shape whose last three entries are read as height, width
        and channels; any leading entries are treated as the batch shape.
      block_shape: Block extent per spatial dimension; normalized to a pair
        via `prepare_tuple_argument`.
      rank: Number of spatial dimensions. Only `2` is supported.
      slice_step: Step between consecutive blocks per spatial dimension.
      dilations: Dilation per spatial dimension.
      dtype: Dtype used for the shape arithmetic and the index ranges.
      transpose: If `True`, block start positions are enumerated from the last
        element backwards along height and width.
      validate_args: Forwarded to `prepare_tuple_argument`.
      name: Optional name scope; defaults to `'im2row_index'`.

    Returns:
      idx: Result of combining every block offset with every block start
        position.
      new_shape: `batch_shape` followed by
        `[out_height, out_width, block_h * block_w * channels]`.

    Raises:
      NotImplementedError: If the static value of `rank` is not `2`.
    """
    with tf.name_scope(name or 'im2row_index'):
        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))

        def as_pair(value, arg_name):
            # All three per-dimension arguments are normalized the same way.
            return prepare_tuple_argument(value,
                                          n=rank,
                                          arg_name=arg_name,
                                          validate_args=validate_args)

        block_h, block_w = as_pair(block_shape, 'block_shape')
        step_h, step_w = as_pair(slice_step, 'slice_step')
        dil_h, dil_w = as_pair(dilations, 'dilations')

        # 1) Process input arguments: peel height, width and channels off the
        # end of the (flattened) shape; whatever precedes them is batch shape.
        flat_shape = ps.reshape(ps.cast(input_shape, dtype=dtype), shape=[-1])
        batch_shape, height, width, channels = ps.split(
            flat_shape, num_or_size_splits=[-1, 1, 1, 1])
        height, width, channels = height[0], width[0], channels[0]

        # Extent covered by one dilated block along each spatial dimension.
        span_h = dil_h * (block_h - 1) + 1
        span_w = dil_w * (block_w - 1) + 1

        # 2) Assemble all block start positions as indexes into the flattened
        # image. start_idx.shape = [fh, fw, c]
        if transpose:

            def last_element(size, step):
                return size - (size - 1) % step - 1

            col_stride = channels * dil_w
            row_stride = channels * width * dil_h
            last_col = last_element(channels * span_w, col_stride)
            last_row = last_element(channels * width * span_h, row_stride)
            start_ranges = [
                ps.range(last_row, -1, delta=-row_stride, dtype=dtype),
                ps.range(last_col, -1, delta=-col_stride, dtype=dtype),
                ps.range(channels, delta=1, dtype=dtype),
            ]
        else:
            start_ranges = [
                ps.range(channels * width * span_h,
                         delta=channels * width * dil_h,
                         dtype=dtype),
                ps.range(channels * span_w,
                         delta=channels * dil_w,
                         dtype=dtype),
                ps.range(channels, delta=1, dtype=dtype),
            ]
        start_idx = cartesian_add(start_ranges)

        # 3) Assemble all block offsets (into flattened image).
        extent_h = height - span_h + 1
        extent_w = width - span_w + 1
        offset_idx = cartesian_add([
            ps.range(width * extent_h, delta=width * step_h, dtype=dtype),
            ps.range(extent_w, delta=step_w, dtype=dtype),
        ]) * channels

        out_h = (extent_h - 1) // step_h + 1
        out_w = (extent_w - 1) // step_w + 1

        # 4) Combine block start/offset pairs.
        # shape = [(eh // sh) * (ew // sw), fh * fw * c]
        idx = cartesian_add([offset_idx, start_idx])
        patch_shape = ps.convert_to_shape_tensor(
            [out_h, out_w, block_h * block_w * channels])
        new_shape = ps.concat([batch_shape, patch_shape], axis=0)
        return idx, new_shape
Esempio n. 25
0
    def __init__(self,
                 num_or_size_splits,
                 axis=-1,
                 validate_args=False,
                 name='split'):
        """Creates the bijector.

        Args:
          num_or_size_splits: Either a Python integer indicating the number of
            splits along `axis` or a 1-D integer `Tensor` or Python list
            containing the sizes of each output tensor along `axis`. If a
            list/`Tensor`, it may contain at most one value of `-1`, which
            indicates a split size that is unknown and determined from input.
          axis: A negative integer or scalar `int32` `Tensor`. The dimension
            along which to split. Must be negative to enable the bijector to
            support arbitrary batch dimensions. Defaults to -1 (note that this
            is different from the `tf.Split` default of `0`). Must be
            statically known.
          validate_args: Python `bool` indicating whether arguments should
            be checked for correctness.
          name: Python `str`, name given to ops managed by this object.

        Raises:
          ValueError: If `num_or_size_splits` is neither an integer nor a 1-D
            `Tensor`, if it is a vector whose length is not statically known,
            if `axis` is not statically known, or if `axis` is non-negative.
        """
        # Must stay the first statement so that only the constructor arguments
        # are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:

            if isinstance(num_or_size_splits, numbers.Integral):
                self._num_splits = num_or_size_splits
                self._split_sizes = None
            else:
                self._split_sizes = tensor_util.convert_nonref_to_tensor(
                    num_or_size_splits,
                    name='num_or_size_splits',
                    dtype=tf.int32)

                if tensorshape_util.rank(self._split_sizes.shape) != 1:
                    raise ValueError(
                        '`num_or_size_splits` must be an integer or 1-D `Tensor`.'
                    )

                num_splits = tensorshape_util.as_list(
                    self._split_sizes.shape)[0]
                if num_splits is None:
                    raise ValueError(
                        'If `num_or_size_splits` is a vector of split sizes '
                        'it must have a statically-known number of '
                        'elements.')
                self._num_splits = num_splits

            static_axis = tf.get_static_value(axis)
            if static_axis is None:
                raise ValueError('`axis` must be statically known.')
            if static_axis >= 0:
                raise ValueError(
                    '`axis` must be negative. Got {}'.format(axis))

            self._axis = tf.convert_to_tensor(axis, tf.int32)

            # `axis` may legitimately be a (constant) `Tensor`; negating it
            # directly would hand a `Tensor` to the base class as the minimum
            # event rank. Use the validated static value so the event ndims are
            # always plain Python integers.
            min_event_ndims = -int(static_axis)

            super(Split, self).__init__(
                forward_min_event_ndims=min_event_ndims,
                inverse_min_event_ndims=[min_event_ndims] * self.num_splits,
                is_constant_jacobian=True,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
Esempio n. 26
0
def prepare_conv_args(filter_shape,
                      rank,
                      strides,
                      padding,
                      dilations,
                      is_transpose=False,
                      validate_args=False):
    """Sanitizes user-provided convolution arguments.

    Args:
      filter_shape: Filter extent; normalized to a `rank`-tuple.
      rank: Number of spatial dimensions. Must be statically known and one of
        `1`, `2` or `3`.
      strides: Stride(s); normalized to a `rank`-tuple.
      padding: Padding specification; passed through `_validate_padding` and
        then `_prepare_padding_argument`.
      dilations: Dilation(s); normalized to a `rank`-tuple.
      is_transpose: If `True`, additionally apply the restrictions of the
        transposed convolution (see Raises).
      validate_args: Forwarded to `prepare_tuple_argument`; when `True` and the
        strides/dilations are not all statically known, a runtime assertion is
        built instead of the static check.

    Returns:
      Tuple `(filter_shape, rank, strides, padding, dilations)` of the
      sanitized arguments.

    Raises:
      TypeError: If `rank` is not statically known.
      ValueError: If `rank` is not in `{1, 2, 3}`.
      NotImplementedError: If `is_transpose` and, statically, some stride and
        some dilation both exceed `1`, or some stride exceeds the matching
        filter dimension.
    """
    padding = _validate_padding(padding)

    try:
        rank = int(tf.get_static_value(rank))
    except TypeError:
        raise TypeError('Argument `rank` must be statically known `int`.')
    valid_rank = {1, 2, 3}
    if rank not in valid_rank:
        raise ValueError('Argument `rank` must be in {}.'.format(valid_rank))

    def as_rank_tuple(value, arg_name):
        # Shared normalization for every per-dimension argument.
        return prepare_tuple_argument(value,
                                      n=rank,
                                      arg_name=arg_name,
                                      validate_args=validate_args)

    filter_shape = as_rank_tuple(filter_shape, 'filter_shape')
    strides = as_rank_tuple(strides, 'strides')
    padding = _prepare_padding_argument(padding)
    dilations = as_rank_tuple(dilations, 'dilations')

    static_strides = [tf.get_static_value(s) for s in strides]
    static_dilations = [tf.get_static_value(d) for d in dilations]
    assertions = []

    if is_transpose:
        fully_static = not any(
            v is None for v in static_strides + static_dilations)
        if fully_static:
            has_big_stride = any(s > 1 for s in static_strides)
            if has_big_stride and any(d > 1 for d in static_dilations):
                raise NotImplementedError(
                    'At least one of `dilations` and `strides` '
                    'must equal `1` for each dimension. Saw: '
                    '`strides={}`, `dilations={}`'.format(strides, dilations))
        elif validate_args:
            # Values are only known at runtime: defer the same check.
            unit_strides = tf.equal(tf.reduce_max(strides), 1)
            unit_dilations = tf.equal(tf.reduce_max(dilations), 1)
            assertions.append(
                assert_util.assert_equal(
                    tf.logical_or(unit_strides, unit_dilations),
                    True,
                    message=
                    'At least one of `dilations` and `strides` must equal `1` '
                    'for each dimension.'))

        # TODO(emilyaf): Remove this once strides > filter_dim is supported.
        static_filter_shape = [tf.get_static_value(s) for s in filter_shape]
        for stride, filter_dim in zip(static_strides, static_filter_shape):
            if stride is None or filter_dim is None:
                continue  # Cannot be checked statically.
            if stride > filter_dim:
                raise NotImplementedError(
                    'Stride must be less than or equal to the '
                    'filter size along each dimension.')

    # NOTE(review): nothing is created inside this scope, so the collected
    # assertions may not end up attached to any returned value -- confirm.
    with tf.control_dependencies(assertions):
        return filter_shape, rank, strides, padding, dilations
Esempio n. 27
0
  def __init__(self,
               component_ssms,
               constant_offset=0.,
               observation_noise_scale=None,
               initial_state_prior=None,
               initial_step=0,
               validate_args=False,
               name=None,
               **linear_gaussian_ssm_kwargs):
    """Build a state space model representing the sum of component models.

    Args:
      component_ssms: Python `list` containing one or more
        `tfd.LinearGaussianStateSpaceModel` instances. The components
        will in general implement different time-series models, with possibly
        different `latent_size`, but they must have the same `dtype`, event
        shape (`num_timesteps` and `observation_size`), and their batch shapes
        must broadcast to a compatible batch shape.
      constant_offset: `float` `Tensor` of shape broadcasting to
        `concat([batch_shape, [num_timesteps]])` specifying a constant value
        added to the sum of outputs from the component models. This allows the
        components to model the shifted series
        `observed_time_series - constant_offset`.
        Default value: `0.`
      observation_noise_scale: Optional scalar `float` `Tensor` indicating the
        standard deviation of the observation noise. May contain additional
        batch dimensions, which must broadcast with the batch shape of elements
        in `component_ssms`. If `observation_noise_scale` is specified for the
        `AdditiveStateSpaceModel`, the observation noise scales of component
        models are ignored. If `None`, the observation noise scale is derived
        by summing the noise variances of the component models, i.e.,
        `observation_noise_scale = sqrt(sum(
        [ssm.observation_noise_scale**2 for ssm in component_ssms]))`.
        Default value: `None`.
      initial_state_prior: Optional instance of `tfd.MultivariateNormal`
        representing a prior distribution on the latent state at time
        `initial_step`. If `None`, defaults to the independent priors from
        component models, i.e.,
        `[component.initial_state_prior for component in component_ssms]`.
        Default value: `None`.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AdditiveStateSpaceModel".
      **linear_gaussian_ssm_kwargs: Optional additional keyword arguments to
        the base `tfd.LinearGaussianStateSpaceModel` constructor.

    Raises:
      ValueError: if components have different `num_timesteps`.
    """
    # Snapshot the constructor arguments before any other local is bound, then
    # flatten the extra keyword arguments into the same dict instead of keeping
    # them nested under their own key.
    parameters = dict(locals())
    parameters.update(linear_gaussian_ssm_kwargs)
    del parameters['linear_gaussian_ssm_kwargs']
    with tf.name_scope(name or 'AdditiveStateSpaceModel') as name:
      # Check that all components have the same dtype
      dtype = tf.debugging.assert_same_float_dtype(component_ssms)

      # Convert scalar offsets to canonical shape `[..., num_timesteps]`.
      # Multiplying by a length-1 vector guarantees at least one dimension, so
      # the `[-1]` shape lookup below is valid even for a scalar offset.
      constant_offset = (tf.convert_to_tensor(value=constant_offset,
                                              name='constant_offset',
                                              dtype=dtype) *
                         tf.ones([1], dtype=dtype))
      offset_length = ps.shape(constant_offset)[-1]
      assertions = []

      # Construct an initial state prior as a block-diagonal combination
      # of the component state priors.
      if initial_state_prior is None:
        initial_state_prior = sts_util.factored_joint_mvn(
            [ssm.initial_state_prior for ssm in component_ssms])
      # NOTE: from here on `dtype` is the prior's dtype; it replaces the value
      # obtained from the components above.
      dtype = initial_state_prior.dtype

      # Keep only the `num_timesteps` values that are statically known;
      # components whose value has no static value are left out of this list.
      static_num_timesteps = [
          tf.get_static_value(ssm.num_timesteps)
          for ssm in component_ssms
          if tf.get_static_value(ssm.num_timesteps) is not None
      ]

      # If any components have a static value for `num_timesteps`, use that
      # value for the additive model. (and check that all other static values
      # match it).
      if static_num_timesteps:
        num_timesteps = static_num_timesteps[0]
        if not all([component_timesteps == num_timesteps
                    for component_timesteps in static_num_timesteps]):
          raise ValueError('Additive model components must all have the same '
                           'number of timesteps '
                           '(saw: {})'.format(static_num_timesteps))
      else:
        num_timesteps = component_ssms[0].num_timesteps
      # At least one component could not be checked statically, so the
      # equality check is deferred to runtime assertions (only when validating).
      if validate_args and len(static_num_timesteps) != len(component_ssms):
        assertions += [
            tf.debugging.assert_equal(  # pylint: disable=g-complex-comprehension
                num_timesteps,
                ssm.num_timesteps,
                message='Additive model components must all have '
                'the same number of timesteps.') for ssm in component_ssms
        ]

      # Define the transition and observation models for the additive SSM.
      # See the "mathematical details" section of the class docstring for
      # further information. Note that we define these as callables to
      # handle the fully general case in which some components have time-
      # varying dynamics.
      def transition_matrix_fn(t):
        return tfl.LinearOperatorBlockDiag(
            [ssm.get_transition_matrix_for_timestep(t)
             for ssm in component_ssms])

      def transition_noise_fn(t):
        return sts_util.factored_joint_mvn(
            [ssm.get_transition_noise_for_timestep(t)
             for ssm in component_ssms])

      # Build the observation matrix, concatenating (broadcast) observation
      # matrices from components. We also take this as an opportunity to enforce
      # any dynamic assertions we may have generated above.
      broadcast_batch_shape = ps.cast(
          sts_util.broadcast_batch_shape(
              [ssm.get_observation_matrix_for_timestep(initial_step)
               for ssm in component_ssms]), dtype=tf.int32)
      # A tensor of ones with shape `broadcast_batch_shape + [1, 1]`; each
      # component's dense observation matrix is multiplied by it below.
      broadcast_obs_matrix = tf.ones(
          ps.concat([broadcast_batch_shape, [1, 1]], axis=0), dtype=dtype)
      if assertions:
        with tf.control_dependencies(assertions):
          broadcast_obs_matrix = tf.identity(broadcast_obs_matrix)

      def observation_matrix_fn(t):
        return tfl.LinearOperatorFullMatrix(
            tf.concat([ssm.get_observation_matrix_for_timestep(t).to_dense() *
                       broadcast_obs_matrix for ssm in component_ssms],
                      axis=-1))

      # Broadcast the constant offset across timesteps.
      # A length-1 offset is used unchanged at every step; otherwise the entry
      # for step `t` is gathered, with `t` clamped to the last available index.
      offset_at_step = lambda t: (  # pylint: disable=g-long-lambda
          constant_offset if offset_length == 1
          else tf.gather(constant_offset, tf.minimum(t, offset_length - 1),
                         axis=-1)[..., tf.newaxis])

      if observation_noise_scale is not None:
        # An explicit scale was given: only the means of the component noise
        # distributions (plus the offset) are used; their scales are not.
        observation_noise_scale = tf.convert_to_tensor(
            value=observation_noise_scale,
            name='observation_noise_scale',
            dtype=dtype)
        def observation_noise_fn(t):
          return tfd.MultivariateNormalDiag(
              loc=(sum([ssm.get_observation_noise_for_timestep(t).mean()
                        for ssm in component_ssms]) + offset_at_step(t)),
              scale_diag=observation_noise_scale[..., tf.newaxis])
      else:
        # No explicit scale: sum the component noise distributions, with the
        # offset contributed as an extra zero-scale term.
        def observation_noise_fn(t):
          offset = offset_at_step(t)
          return sts_util.sum_mvns(
              [tfd.MultivariateNormalDiag(
                  loc=offset,
                  scale_diag=tf.zeros_like(offset))] +
              [ssm.get_observation_noise_for_timestep(t)
               for ssm in component_ssms])

      super(AdditiveStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=transition_matrix_fn,
          transition_noise=transition_noise_fn,
          observation_matrix=observation_matrix_fn,
          observation_noise=observation_noise_fn,
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          validate_args=validate_args,
          name=name,
          **linear_gaussian_ssm_kwargs)
      # Assigned after the base constructor has run so that the recorded
      # parameters are this class's own arguments.
      # NOTE(review): presumably the base constructor stores its own parameter
      # dict in `_parameters`, which this overwrites -- confirm.
      self._parameters = parameters
Esempio n. 28
0
def _rank(input, name=None):  # pylint: disable=redefined-builtin,unused-argument
    """Returns the rank of `input`, as a static value whenever possible.

    Args:
      input: Any object. If it has no `shape` attribute it is first converted:
        to a `Tensor` when `tf.get_static_value` yields `None`, otherwise to a
        NumPy array.
      name: Unused.

    Returns:
      `np.int32` holding the rank when the shape's rank is statically known,
      otherwise the result of `tf.rank(input)`.
    """
    if not hasattr(input, 'shape'):
        if tf.get_static_value(input) is None:
            input = tf.convert_to_tensor(input)
        else:
            input = np.array(input)
    static_rank = tensorshape_util.rank(getattr(input, 'shape', None))
    if static_rank is None:
        return tf.rank(input)
    return np.int32(static_rank)
Esempio n. 29
0
    def _sample_n(self, n, seed):
        """Draws `n` samples by masking component samples with a mixture draw.

        Samples are taken from `components_distribution` and from
        `mixture_distribution`; the mixture draw is turned into a one-hot mask
        over the component axis and the masked component samples are summed
        over that axis. If a distribution rejects the split seed with a
        seed-related `TypeError`, sampling falls back to a `SeedStream`-derived
        seed and a warning is emitted.

        Args:
          n: Number of samples to draw.
          seed: Seed; split into one seed per distribution.

        Returns:
          The selected samples, passed through `_reparameterize_sample` when
          `self._reparameterize` is set.
        """

        def is_seed_type_error(err):
            # Only these two messages trigger the stateful fallback.
            text = str(err)
            return ('Expected int for argument' in text
                    or TENSOR_SEED_MSG_PREFIX in text)

        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')

        # The stream is only consumed by the stateful fallback; if it cannot be
        # built, keep the error so it can be raised if the fallback is needed.
        fallback_stream = None
        fallback_stream_err = None
        try:
            fallback_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            fallback_stream_err = e

        try:
            component_samples = self.components_distribution.sample(
                n, seed=components_seed)  # [n, B, k, E]
            if fallback_stream is not None:
                fallback_stream()  # Advance even if unused.
        except TypeError as e:
            if not is_seed_type_error(e):
                raise
            if fallback_stream is None:
                raise fallback_stream_err
            warnings.warn(
                ('Falling back to stateful sampling for `components_distribution` '
                 '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                 'RNGs. This fallback may be removed after 20-Aug-2020. {}'
                 ).format(self.components_distribution.name,
                          type(self.components_distribution), str(e)))
            component_samples = self.components_distribution.sample(
                n, seed=fallback_stream())  # [n, B, k, E]

        # Event rank: static if available, otherwise from the shape tensor.
        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = prefer_static.rank_from_shape(event_shape)
        static_event_ndims = tf.get_static_value(event_ndims)

        # Number of components: static if available, otherwise dynamic.
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        num_components = None
        if static_event_ndims is not None:
            num_components = tf.compat.dimension_value(
                component_samples.shape[-1 - static_event_ndims])
        if num_components is None:
            num_components = tf.shape(component_samples)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        np_dtype = dtype_util.as_numpy_dtype(component_samples.dtype)
        try:
            mixture_draw = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if not is_seed_type_error(e):
                raise
            if fallback_stream is None:
                raise fallback_stream_err
            warnings.warn(
                ('Falling back to stateful sampling for `mixture_distribution` '
                 '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                 'RNGs. This fallback may be removed after 20-Aug-2020. ({})'
                 ).format(self.mixture_distribution.name,
                          type(self.mixture_distribution), str(e)))
            mixture_draw = self.mixture_distribution.sample(
                n, seed=fallback_stream())  # [n, B] or [n]

        selector = tf.one_hot(
            indices=mixture_draw,  # [n, B] or [n]
            depth=num_components,
            on_value=np_dtype(1),
            off_value=np_dtype(0))  # [n, B, k] or [n, k]

        # Pad `selector` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = prefer_static.rank(component_samples) - event_ndims - 1
        selector_batch_ndims = prefer_static.rank(selector) - 1
        pad_ndims = batch_ndims - selector_batch_ndims
        selector_shape = prefer_static.shape(selector)
        padded_shape = prefer_static.concat(
            [
                selector_shape[:-1],
                prefer_static.ones([pad_ndims], dtype=tf.int32),
                selector_shape[-1:],
                prefer_static.ones([event_ndims], dtype=tf.int32),
            ],
            axis=0)
        selector = tf.reshape(selector, shape=padded_shape)

        no_nan_dtypes = (tf.bfloat16, tf.float16, tf.float32, tf.float64,
                         tf.complex64, tf.complex128)
        if component_samples.dtype in no_nan_dtypes:
            selected = tf.math.multiply_no_nan(component_samples, selector)
        else:
            selected = component_samples * selector
        result = tf.reduce_sum(selected, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            result = self._reparameterize_sample(result,
                                                 event_shape=event_shape)

        return result
Esempio n. 30
0
 def tensor_and_const_value(v):
     """Returns `v` converted to a tensor, paired with that tensor's static value."""
     as_tensor = tf.convert_to_tensor(v)
     return (as_tensor, tf.get_static_value(as_tensor))