Example 1
 def _should_broadcast(self, obj):
     # e.g. 'mse'.
     if not nest.is_nested(obj):
         return True
     # e.g. ['mse'] or ['mse', 'mae'].
     return (isinstance(obj, (list, tuple))
             and not any(nest.is_nested(o) for o in obj))
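A brief hedged sketch of the broadcast rule above, applied to typical `compile(loss=...)` specs and rewritten against the public `tf.nest` API (the function name here is illustrative):

import tensorflow as tf

def should_broadcast(obj):
    # A single spec, or a flat list/tuple of specs, is broadcast to every
    # output; dicts and nested lists are matched per-output instead.
    if not tf.nest.is_nested(obj):
        return True
    return (isinstance(obj, (list, tuple))
            and not any(tf.nest.is_nested(o) for o in obj))

assert should_broadcast('mse')
assert should_broadcast(['mse', 'mae'])
assert not should_broadcast({'output_1': 'mse'})
assert not should_broadcast([['mse'], ['mae']])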
Example 2
 def call(self, inputs, states, training=None):
     prev_output = states[0] if nest.is_nested(states) else states
     output = K.dot(inputs, self.kernel)
     print(f"output: {output}")
     new_state = [output] if nest.is_nested(states) else output
     return output, new_state
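For context, a hedged sketch of the kind of custom cell this `call` belongs to, wrapped in `tf.keras.layers.RNN`; the class name, sizes, and weight initialization are illustrative and not taken from the original source:

import tensorflow as tf

class MinimalRNNCell(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # required by tf.keras.layers.RNN

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='uniform', name='kernel')
        self.built = True

    def call(self, inputs, states, training=None):
        # `states` may arrive as a list or a single tensor; mirror that on output.
        output = tf.matmul(inputs, self.kernel)
        new_state = [output] if tf.nest.is_nested(states) else output
        return output, new_state

layer = tf.keras.layers.RNN(MinimalRNNCell(8))
y = layer(tf.random.normal([4, 10, 16]))  # (batch, time, features) -> (4, 8)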
Example 3
def _coerce_structure(shallow_tree, input_tree):
    """Implementation of coerce_structure."""
    if not nest.is_nested(shallow_tree):
        return input_tree

    if not nest.is_nested(input_tree):
        raise TypeError(
            nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree)))

    if len(input_tree) != len(shallow_tree):
        raise ValueError(
            nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree),
                shallow_length=len(shallow_tree)))

    # Determine whether shallow_tree should be treated as a Mapping or a Sequence.
    # Namedtuples can be interpreted either way (but keys take precedence).
    _shallow_is_namedtuple = nest._is_namedtuple(shallow_tree)  # pylint: disable=invalid-name
    _shallow_is_mapping = isinstance(shallow_tree, collections.abc.Mapping)  # pylint: disable=invalid-name
    shallow_supports_keys = _shallow_is_namedtuple or _shallow_is_mapping
    shallow_supports_iter = _shallow_is_namedtuple or not _shallow_is_mapping

    # Branch-selection depends on both shallow and input container-classes.
    input_is_mapping = isinstance(input_tree, collections.abc.Mapping)
    if nest._is_namedtuple(input_tree):
        if shallow_supports_keys:
            lookup_branch = lambda k: getattr(input_tree, k)
        else:
            input_iter = nest._yield_value(input_tree)
            lookup_branch = lambda _: next(input_iter)
    elif shallow_supports_keys and input_is_mapping:
        lookup_branch = lambda k: input_tree[k]
    elif shallow_supports_iter and not input_is_mapping:
        input_iter = nest._yield_value(input_tree)
        lookup_branch = lambda _: next(input_iter)
    else:
        raise TypeError(
            nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree),
                shallow_type=(type(shallow_tree.__wrapped__) if hasattr(
                    shallow_tree, '__wrapped__') else type(shallow_tree))))

    flat_coerced = []
    needs_wrapping = type(shallow_tree) is not type(input_tree)
    for shallow_key, shallow_branch in nest._yield_sorted_items(shallow_tree):
        try:
            input_branch = lookup_branch(shallow_key)
        except (KeyError, AttributeError):
            raise ValueError(
                nest._SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key]))
        flat_coerced.append(_coerce_structure(shallow_branch, input_branch))
        # Keep track of whether nested elements have changed.
        needs_wrapping |= input_branch is not flat_coerced[-1]

    # Only create a new instance if containers differ or contents changed.
    return (nest._sequence_like(shallow_tree, flat_coerced)
            if needs_wrapping else input_tree)
Example 4
def _prepare_args(target, event_ndims):
    """Creates a structure of `RunningCovariance`s based on inferred metadata.

  Metadata required to create a `RunningCovariance` object (`shape`, `dtype`,
  and `event_ndims` of incoming chain states) will be inferred from the
  `target`. Using that information, an identical structure of
  `RunningCovariance`s to `target` will be returned.

  Args:
    target: A (possibly nested) structure of `Tensor`s or Python
      `list`s of `Tensor`s representing the current state(s) of the Markov
      chain(s). It is used to infer the shape and dtype of future samples.
    event_ndims: A (possibly nested) structure of integers. Defines
        the number of inner-most dimensions that represent the event shape.
        Must be either a singleton or of the same shape as `target`.

  Returns:
    cov_streams: Structure of `sample_stats.RunningCovariance` matching
      the shape of `target`.
  """

    shape = tf.nest.map_structure(lambda target: target.shape, target)
    dtype = tf.nest.map_structure(lambda target: target.dtype, target)
    if event_ndims is None:
        event_ndims = tf.nest.map_structure(ps.rank, target)
    elif not nest.is_nested(event_ndims):
        event_ndims = nest_util.broadcast_structure(target, event_ndims)
    return nest.map_structure_up_to(
        target,
        sample_stats.RunningCovariance,
        shape,
        event_ndims,
        dtype,
        check_types=False,
    )
Example 5
    def one_step(self,
                 new_chain_state,
                 current_reducer_state,
                 previous_kernel_results,
                 axis=None):
        """Update the `current_reducer_state` with a new chain state.

    Chunking semantics are similar to those of batching and are specified by the
    `axis` parameter. If chunking is enabled (axis is not `None`), all elements
    along the specified `axis` will be treated as separate samples. If a
    single scalar value is provided for a non-scalar sample structure, that
    value will be used for all elements in the structure. If not, an identical
    structure must be provided.

    Args:
      new_chain_state: A (possibly nested) structure of incoming chain state(s)
        with shape and dtype compatible with those used to initialize the
        `current_reducer_state`.
      current_reducer_state: `CovarianceReducerState`s representing the current
        state of the running covariance.
      previous_kernel_results: A (possibly nested) structure of `Tensor`s
        representing internal calculations made in a related
        `TransitionKernel`.
      axis: If chunking is desired, this is a (possibly nested) structure of
        integers that specifies the axis with chunked samples. For individual
        samples, set this to `None`. By default, samples are not chunked
        (`axis` is None).

    Returns:
      new_reducer_state: `CovarianceReducerState` with updated running
        statistics. Its `cov_state` field has an identical structure to the
        results of `self.transform_fn`. Each of the individual values in that
        structure subsequently mimics the structure of `current_reducer_state`.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'covariance_reducer',
                                    'one_step')):
            cov_streams = _prepare_args(current_reducer_state.init_structure,
                                        self.event_ndims)
            new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                    new_chain_state)
            previous_kernel_results = tf.nest.map_structure(
                tf.convert_to_tensor, previous_kernel_results)
            fn_results = tf.nest.map_structure(
                lambda fn: fn(new_chain_state, previous_kernel_results),
                self.transform_fn,
            )
            if not nest.is_nested(axis):
                axis = nest_util.broadcast_structure(fn_results, axis)
            running_cov_state = nest.map_structure_up_to(
                current_reducer_state.init_structure,
                lambda strm, *args: strm.update(*args),
                cov_streams,
                current_reducer_state.cov_state,
                fn_results,
                axis,
                check_types=False,
            )
            return CovarianceReducerState(current_reducer_state.init_structure,
                                          running_cov_state)
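Such reducers are usually driven by a sampling loop; below is a hedged end-to-end sketch assuming the public `tfp.experimental.mcmc.sample_fold` and `CovarianceReducer` APIs (the argument names and the returned triple are assumptions, not taken from the source above):

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=lambda x: -0.5 * x**2)
reducer = tfp.experimental.mcmc.CovarianceReducer()
reduction, final_state, final_kernel_results = tfp.experimental.mcmc.sample_fold(
    num_steps=200,
    current_state=tf.zeros([]),
    kernel=kernel,
    reducer=reducer)
# `reduction` approximates the variance of the chain's stationary distribution.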
Example 6
    def call(self, inputs, states, training=None):
        verbose = False
        old_is_init, old_forward, old_loglik = states
        batch_size = old_forward.shape[0]
        if verbose:
            print("batch_size=", batch_size)
            print("old_is_init=", old_is_init)
            print("old_forward=\n", old_forward, " shape", old_forward.shape)
            print("old_loglik=", old_loglik)

        I0 = tf.dtypes.cast(old_is_init, tf.float32)
        R0 = tf.tensordot(I0, self.I, axes=0)
        R1 = tf.linalg.matvec(self.A, old_forward, transpose_a=True)
        R = R0 + R1
        R = tf.identity(R, name="R")
        if verbose:
            print(f"R0:{R0}\nR1:{R1}\nR:{R}")
        #   [units, n, s]     [batch_size, s]
        # E has shape [batch_size, units, n]
        E = tf.linalg.matvec(tf.expand_dims(self.B, 0),
                             tf.expand_dims(inputs, 1),
                             name="E")
        forward = tf.multiply(E, R, name="forward")
        S = tf.reduce_sum(forward, axis=-1, name="loglik")
        loglik = old_loglik + tf.math.log(S)
        forward = forward / tf.expand_dims(S, -1)
        batch_size = tf.shape(inputs)[
            0]  # 'call' can be given None as batch size
        is_init = tf.zeros(batch_size, dtype='int8', name="is_init")
        new_state = [is_init, forward, loglik]
        new_state = [new_state] if nest.is_nested(states) else new_state
        if verbose:
            print("new_state", new_state)
        return loglik, new_state
Example 7
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
    """Maps the atomic elements of a nested structure.

  Args:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
    if is_atomic_fn(nested):
        return map_fn(nested)

    # Recursively convert.
    if not nest.is_nested(nested):
        raise ValueError(
            'Received non-atomic and non-sequence element: {}'.format(nested))
    if nest.is_mapping(nested):
        values = [nested[k] for k in sorted(nested.keys())]
    elif nest.is_attrs(nested):
        values = _astuple(nested)
    else:
        values = nested
    mapped_values = [
        map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
    ]
    return nest._sequence_like(nested, mapped_values)
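A simplified standalone sketch of the same idea, treating chosen sub-structures as atomic while mapping, using only plain Python containers instead of TF's internal `nest` helpers (no namedtuple/attrs support):

def map_with_atomic(is_atomic, fn, nested):
    if is_atomic(nested):
        return fn(nested)
    if isinstance(nested, dict):
        return {k: map_with_atomic(is_atomic, fn, v) for k, v in nested.items()}
    if isinstance(nested, (list, tuple)):
        return type(nested)(map_with_atomic(is_atomic, fn, v) for v in nested)
    raise ValueError('Received non-atomic and non-sequence element: {}'.format(nested))

# Treat flat int tuples (e.g. static shapes) as atomic leaves and append a dim.
shapes = {'encoder': (32, 128), 'decoder': [(32, 64), (32, 10)]}
result = map_with_atomic(
    lambda x: isinstance(x, tuple) and all(isinstance(d, int) for d in x),
    lambda shape: shape + (1,),
    shapes)
# -> {'encoder': (32, 128, 1), 'decoder': [(32, 64, 1), (32, 10, 1)]}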
Example 8
 def _is_atomic_nested(nested):
     """Returns `True` if `nested` is a list representing node data."""
     if isinstance(nested, ListWrapper):
         return True
     if _is_serialized_node_data(nested):
         return True
     return not nest.is_nested(nested)
Example 9
def check_default_value(shape, default_value, dtype, key):
    """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is compatible,
  it casts default_value to a tuple and returns it. `key` is used only
  for error message.

  Args:
    shape: An iterable of integers specifies the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be applied
      as the default value for every item. If an iterable of values is
      provided, the shape of the `default_value` should be equal to the given
      `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    key: Column name, used only for error messages.

  Returns:
    A tuple which will be used as default value.

  Raises:
    ValueError: if `default_value` is an iterable but not compatible with `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
    if default_value is None:
        return None

    if isinstance(default_value, int):
        return _create_tuple(shape, default_value)

    if isinstance(default_value, float) and dtype.is_floating:
        return _create_tuple(shape, default_value)

    if callable(getattr(default_value, 'tolist',
                        None)):  # Handles numpy arrays
        default_value = default_value.tolist()

    if nest.is_nested(default_value):
        if not _is_shape_and_default_value_compatible(default_value, shape):
            raise ValueError(
                'The shape of default_value must be equal to given shape. '
                'default_value: {}, shape: {}, key: {}'.format(
                    default_value, shape, key))
        # Check if the values in the list are all integers or are convertible to
        # floats.
        is_list_all_int = all(
            isinstance(v, int) for v in nest.flatten(default_value))
        is_list_has_float = any(
            isinstance(v, float) for v in nest.flatten(default_value))
        if is_list_all_int:
            return _as_tuple(default_value)
        if is_list_has_float and dtype.is_floating:
            return _as_tuple(default_value)
    raise TypeError('default_value must be compatible with dtype. '
                    'default_value: {}, dtype: {}, key: {}'.format(
                        default_value, dtype, key))
Example 10
def map_to_output_names(y_pred, output_names, struct):
    """Maps a dict to a list using `output_names` as keys.

  This is a convenience feature only. When a `Model`'s outputs
  are a list, you can specify per-output losses and metrics as
  a dict, where the keys are the output names. If you specify
  per-output losses and metrics via the same structure as the
  `Model`'s outputs (recommended), no mapping is performed.

  For the Functional API, the output names are the names of the
  last layer of each output. For the Subclass API, the output names
  are determined by `create_pseudo_output_names` (For example:
  `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and
  `fit`.

  Args:
    y_pred: Sample outputs of the Model, to determine if this convenience
      feature should be applied (`struct` is returned unmodified if `y_pred`
      isn't a flat list).
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` mapped to a list in same order as `output_names`.
  """
    single_output = not nest.is_nested(y_pred)
    outputs_are_flat_list = (not single_output and isinstance(
        y_pred,
        (list, tuple)) and not any(nest.is_nested(y_p) for y_p in y_pred))

    if (single_output or outputs_are_flat_list) and isinstance(struct, dict):
        output_names = output_names or create_pseudo_output_names(y_pred)
        struct = copy.copy(struct)
        new_struct = [struct.pop(name, None) for name in output_names]
        if struct:
            raise ValueError('Found unexpected keys that do not correspond '
                             'to any Model output: {}. Expected: {}'.format(
                                 struct.keys(), output_names))
        if len(new_struct) == 1:
            return new_struct[0]
        return new_struct
    else:
        return struct
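The mapping itself reduces to a small dictionary-to-list reordering; a standalone sketch in plain Python (names are illustrative):

output_names = ['out_a', 'out_b']
loss_spec = {'out_b': 'mae', 'out_a': 'mse'}
remaining = dict(loss_spec)
mapped = [remaining.pop(name, None) for name in output_names]
assert mapped == ['mse', 'mae']
assert not remaining  # any leftover key would trigger the ValueError above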
Example 11
 def convert_fn(path, value, dtype, dtype_hint, name=None):
     if not allow_packing and nest.is_nested(value) and any(
             # Treat arrays like Tensors for full parity in JAX backend.
             tf.is_tensor(x) or isinstance(x, np.ndarray)
             for x in nest.flatten(value)):
         raise NotImplementedError(
             ('Cannot convert a structure of tensors to a '
              'single tensor. Saw {} at path {}.').format(value, path))
     return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
Example 12
    def one_step(self,
                 new_chain_state,
                 current_reducer_state,
                 previous_kernel_results=None,
                 axis=None):
        """Update the `current_reducer_state` with a new chain state.

    Chunking semantics are specified by the `axis` parameter. If chunking is
    enabled (axis is not `None`), all elements along the specified `axis` will
    be treated as separate samples. If a single scalar value is provided for a
    non-scalar sample structure, that value will be used for all elements in the
    structure. If not, an identical structure must be provided.

    Args:
      new_chain_state: A (possibly nested) structure of incoming chain state(s)
        with shape and dtype compatible with those used to initialize the
        `current_reducer_state`.
      current_reducer_state: `ExpectationsReducerState` representing the current
        reducer state.
      previous_kernel_results: A (possibly nested) structure of `Tensor`s
        representing internal calculations made in a related
        `TransitionKernel`.
      axis: If chunking is desired, this is a (possibly nested) structure of
        integers that specifies the axis with chunked samples. For individual
        samples, set this to `None`. By default, samples are not chunked
        (`axis` is None).

    Returns:
      new_reducer_state: `ExpectationsReducerState` with updated running
        statistics. It tracks a running total and the number of processed
        samples.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'expectations_reducer',
                                    'one_step')):
            new_chain_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                    new_chain_state)
            if previous_kernel_results is not None:
                previous_kernel_results = tf.nest.map_structure(
                    tf.convert_to_tensor,
                    previous_kernel_results,
                    expand_composites=True)
            fn_results = tf.nest.map_structure(
                lambda fn: fn(new_chain_state, previous_kernel_results),
                self.transform_fn)
            if not nest.is_nested(axis):
                axis = nest_util.broadcast_structure(fn_results, axis)

            def update(fn_results, state, axis):
                return state.update(fn_results, axis=axis)

            return ExpectationsReducerState(
                nest.map_structure(update,
                                   fn_results,
                                   current_reducer_state.expectation_state,
                                   axis,
                                   check_types=False))
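A hedged usage sketch of driving such a reducer directly, assuming the public `tfp.experimental.mcmc.ExpectationsReducer` API (`initialize` / `one_step` / `finalize`) behaves as documented above:

import tensorflow as tf
import tensorflow_probability as tfp

reducer = tfp.experimental.mcmc.ExpectationsReducer()
state = reducer.initialize(tf.zeros([3]))
for _ in range(100):
    sample = tf.random.normal([3])
    state = reducer.one_step(sample, state)  # previous_kernel_results defaults to None
running_mean = reducer.finalize(state)       # approaches zeros([3]) for large counts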
Example 13
def _canonicalize_event_ndims(target, event_ndims):
    """Returns `event_ndims` shaped parallel to `target`, repeating as needed."""
    # This is only here to support the possibility of different event_ndims across
    # different Tensors in the target structure.  Otherwise, event_ndims could
    # just be an integer (or None) and wouldn't need to be canonicalized to a
    # structure.
    if not nest.is_nested(event_ndims):
        return nest_util.broadcast_structure(target, event_ndims)
    else:
        return event_ndims
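A rough public-API analog of the broadcast performed above: a scalar `event_ndims` is simply repeated across the structure of `target` (a sketch only; `nest_util.broadcast_structure` is the internal equivalent):

import tensorflow as tf

target = {'x': tf.zeros([3]), 'y': (tf.zeros([2, 2]), tf.zeros([]))}
event_ndims = tf.nest.map_structure(lambda _: 1, target)
# -> {'x': 1, 'y': (1, 1)}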
Example 14
 def is_empty(x):
   """Check whether a possibly nested structure is empty."""
   if not nest.is_nested(x):
     return False
   if isinstance(x, collections.abc.Mapping):
     return is_empty(list(x.values()))
   for item in x:
     if not is_empty(item):
       return False
   return True
Example 15
 def is_empty(x):
     """Check whether a possibly nested structure is empty."""
     if not nest.is_nested(x):
         return False
     if isinstance(x, collections_abc.Mapping):
         return is_empty(list(x.values()))
     for item in x:
         if not is_empty(item):
             return False
     return True
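The same check rewritten against the public `tf.nest` API, as a minimal standalone sketch (not the library's own code):

import tensorflow as tf

def is_empty(x):
    if not tf.nest.is_nested(x):
        return False
    if isinstance(x, dict):
        return is_empty(list(x.values()))
    return all(is_empty(item) for item in x)

assert is_empty([]) and is_empty({}) and is_empty([{}, ()])
assert not is_empty(0) and not is_empty([1]) and not is_empty({'a': 1})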
Example 16
def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
    """Detect if input should be interpreted as a list of blocks."""
    # Tuples and lists of length equal to the number of operators may be
    # blockwise.
    if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
        # If the elements of the iterable are not nested, interpret the input as
        # blockwise.
        if not any(nest.is_nested(x) for x in arg):
            return True
        else:
            arg_dims = [
                ops.convert_to_tensor_v2_with_dispatch(x).shape[arg_split_dim]
                for x in arg
            ]
            self_dims = [dim.value for dim in block_dimensions]

            # If none of the operator dimensions are known, interpret the input as
            # blockwise if its matching dimensions are unequal.
            if all(self_d is None for self_d in self_dims):

                # A nested tuple/list with a single outermost element is not blockwise
                if len(arg_dims) == 1:
                    return False
                elif any(dim != arg_dims[0] for dim in arg_dims):
                    return True
                else:
                    raise ValueError(
                        "Parsing of the input structure is ambiguous. Please input "
                        "a blockwise iterable of `Tensor`s or a single `Tensor`."
                    )

            # If input dimensions equal the respective (known) blockwise operator
            # dimensions, then the input is blockwise.
            if all(self_d == arg_d or self_d is None
                   for self_d, arg_d in zip(self_dims, arg_dims)):
                return True

            # If the input dimensions are all equal and are at least the sum of
            # the known operator dimensions, interpret the input as a single
            # concatenated Tensor rather than as blockwise input.
            self_dim = sum(self_d for self_d in self_dims
                           if self_d is not None)
            if all(s == arg_dims[0]
                   for s in arg_dims) and arg_dims[0] >= self_dim:
                return False

            # If none of these conditions is met, the input shape is mismatched.
            raise ValueError(
                "Input dimension does not match operator dimension.")
    else:
        return False
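For orientation, the two calling conventions this helper distinguishes, sketched with the public blockwise operators (assuming `matvec` accepts a list of per-block operands, which is the case this detection exists to support):

import tensorflow as tf

op = tf.linalg.LinearOperatorBlockDiag([
    tf.linalg.LinearOperatorDiag(tf.ones([2])),
    tf.linalg.LinearOperatorDiag(2. * tf.ones([3])),
])
blockwise_result = op.matvec([tf.ones([2]), tf.ones([3])])  # list in, list out
dense_result = op.matvec(tf.ones([5]))                      # single concatenated vector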
Example 17
def _scan(  # pylint: disable=unused-argument
        fn,
        elems,
        initializer=None,
        parallel_iterations=10,
        back_prop=True,
        swap_memory=False,
        infer_shape=True,
        reverse=False,
        name=None):
    """Scan implementation."""

    if reverse:
        elems = nest.map_structure(lambda x: x[::-1], elems)

    if initializer is None:
        if nest.is_nested(elems):
            raise NotImplementedError
        initializer = elems[0]
        elems = elems[1:]
        prepend = [[initializer]]
    else:
        prepend = None

    def func(arg, x):
        return nest.flatten(
            fn(nest.pack_sequence_as(initializer, arg),
               nest.pack_sequence_as(elems, x)))

    arg = nest.flatten(initializer)
    if JAX_MODE:
        from jax import lax  # pylint: disable=g-import-not-at-top

        def scan_body(arg, x):
            arg = func(arg, x)
            return arg, arg

        _, out = lax.scan(scan_body, arg, nest.flatten(elems))
    else:
        out = [[] for _ in range(len(arg))]
        for x in zip(*nest.flatten(elems)):
            arg = func(arg, x)
            for i, z in enumerate(arg):
                out[i].append(z)

    if prepend is not None:
        out = [pre + list(o) for (pre, o) in zip(prepend, out)]

    ordering = (lambda x: x[::-1]) if reverse else (lambda x: x)
    return nest.pack_sequence_as(initializer,
                                 [ordering(np.array(o)) for o in out])
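For intuition, the same kind of computation expressed with the standard `tf.scan` (a hedged analog of what the numpy-backend shim above produces for a running sum, not the same code path):

import tensorflow as tf

cumulative = tf.scan(lambda acc, x: acc + x, tf.constant([1., 2., 3., 4.]))
# With no initializer the first element is carried through: [1., 3., 6., 10.]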
Example 18
    def _conform_to_outputs(self, outputs, struct):
        """Convenience method to conform `struct` to `outputs` structure.

    Mappings performed:

    (1) Map a dict to a list of outputs, using the output names.
    (2) Fill missing keys in a dict w/ `None`s.
    (3) Map a single item to all outputs.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure (e.g. of labels, sample_weights,
        losses, or metrics).

    Returns:
      Mapping of `struct` to `outputs` structure.
    """
        struct = map_to_output_names(outputs, self._output_names, struct)
        struct = map_missing_dict_keys(outputs, struct)
        # Allow passing one object that applies to all outputs.
        if not nest.is_nested(struct) and nest.is_nested(outputs):
            struct = nest.map_structure(lambda _: struct, outputs)
        return struct
Example 19
    def serialize(self, make_node_key, node_conversion_map):
        """Serializes `Node` for Functional API's `get_config`."""
        # Serialization still special-cases first argument.
        args, kwargs = self.call_args, self.call_kwargs
        inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)

        # Treat everything other than first argument as a kwarg.
        arguments = dict(zip(self.layer._call_fn_args[1:], args))
        arguments.update(kwargs)
        kwargs = arguments

        kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
        try:
            json.dumps(kwargs, default=json_utils.get_json_type)
        except TypeError:
            kwarg_types = nest.map_structure(type, kwargs)
            raise TypeError('Layer ' + self.layer.name +
                            ' was passed non-JSON-serializable arguments. ' +
                            'Arguments had types: ' + str(kwarg_types) +
                            '. They cannot be serialized out '
                            'when saving the model.')

        # `kwargs` is added to each Tensor in the first arg. This should be
        # changed in a future version of the serialization format.
        def serialize_first_arg_tensor(t):
            if is_keras_tensor(t):
                kh = t._keras_history
                node_index = kh.node_index
                node_key = make_node_key(kh.layer.name, node_index)
                new_node_index = node_conversion_map.get(node_key, 0)
                data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
            else:
                # If an element in the first call argument did not originate as a
                # keras tensor and is a constant value, we save it using the format
                # ['_CONSTANT_VALUE', -1, serialized_tensor_or_python_constant]
                # (potentially including serialized kwargs in an optional 4th argument).
                data = [
                    _CONSTANT_VALUE, -1,
                    _serialize_keras_tensor(t), kwargs
                ]
            return tf_utils.ListWrapper(data)

        data = nest.map_structure(serialize_first_arg_tensor, inputs)
        if (not nest.is_nested(data)
                and not self.layer._preserve_input_structure_in_config):
            data = [data]
        data = tf_utils.convert_inner_node_data(data)
        return data
Example 20
def _is_shape_and_default_value_compatible(default_value, shape):
    """Verifies compatibility of shape and default_value."""
    # Invalid conditions:
    #  * default_value is an iterable (nested) but shape is empty
    #  * default_value is a scalar but shape is not empty
    if nest.is_nested(default_value) != bool(shape):
        return False
    if not shape:
        return True
    if len(default_value) != shape[0]:
        return False
    for i in range(shape[0]):
        if not _is_shape_and_default_value_compatible(default_value[i],
                                                      shape[1:]):
            return False
    return True
Example 21
def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
    """Generate a zero filled tensor with shape [batch_size, state_size]."""
    if batch_size_tensor is None or dtype is None:
        raise ValueError(
            'batch_size and dtype cannot be None while constructing initial state: '
            'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))

    def create_zeros(unnested_state_size):
        flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list()
        init_state_size = [batch_size_tensor] + flat_dims
        return array_ops.zeros(init_state_size, dtype=dtype)

    if nest.is_nested(state_size):
        return nest.map_structure(create_zeros, state_size)
    else:
        return create_zeros(state_size)
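A hedged illustration of the zero-state layout for a nested state_size (for instance an LSTM-style cell with two state tensors), using public TF ops rather than the internal helpers above:

import tensorflow as tf

batch_size, state_size = 4, [8, 8]
zero_state = tf.nest.map_structure(
    lambda size: tf.zeros([batch_size] + tf.TensorShape(size).as_list()),
    state_size)
# -> [<4x8 zeros>, <4x8 zeros>]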
Example 22
def replace_composites_with_components(structure):
  """Recursively replaces CompositeTensors with their components.

  Args:
    structure: A `nest`-compatible structure, possibly containing composite
      tensors.

  Returns:
    A copy of `structure`, where each composite tensor has been replaced by
    its components.  The result will contain no composite tensors.
    Note that `nest.flatten(replace_composites_with_components(structure))`
    returns the same value as `nest.flatten(structure)`.
  """
  if isinstance(structure, CompositeTensor):
    return replace_composites_with_components(
        structure._type_spec._to_components(structure))  # pylint: disable=protected-access
  elif not nest.is_nested(structure):
    return structure
  else:
    return nest.map_structure(
        replace_composites_with_components, structure, expand_composites=False)
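A related effect is available through the public API: flattening with `expand_composites=True` likewise decomposes composite tensors into their component tensors (a hedged sketch; the exact component ordering is an implementation detail):

import tensorflow as tf

st = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.], dense_shape=[2, 2])
flat = tf.nest.flatten({'x': st, 'y': tf.ones([2])}, expand_composites=True)
# The SparseTensor is replaced by its component Tensors; `y` passes through.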
Example 23
 def convert_fn(path, value, dtype, dtype_hint, name=None):
     if not allow_packing and nest.is_nested(value) and any(
             # Treat arrays like Tensors for full parity in JAX backend.
             tf.is_tensor(x) or isinstance(x, np.ndarray)
             for x in nest.flatten(value)):
         raise NotImplementedError(
             ('Cannot convert a structure of tensors to a '
              'single tensor. Saw {} at path {}.').format(value, path))
     if as_shape_tensor:
         return ps.convert_to_shape_tensor(value,
                                           dtype,
                                           dtype_hint,
                                           name=name)
     elif 'KerasTensor' in str(type(value)):
         # This is a hack to detect symbolic Keras tensors to work around
         # b/206660667.  The issue was that symbolic Keras tensors would
         # break the Bijector cache on forward/inverse log det jacobian,
         # because tf.convert_to_tensor is not a no-op thereon.
         return value
     else:
         return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
Example 24
    def match(self, expected, actual):
        """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays and other python sequence containers
    e.g. list, dict, as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
        if isinstance(expected, np.ndarray):
            expected = expected.tolist()
        if isinstance(actual, np.ndarray):
            actual = actual.tolist()
        self.assertEqual(type(expected), type(actual))

        if nest.is_nested(expected):
            self.assertEqual(len(expected), len(actual))
            if isinstance(expected, dict):
                for key1, key2 in zip(sorted(expected), sorted(actual)):
                    self.assertEqual(key1, key2)
                    self.match(expected[key1], actual[key2])
            else:
                for item1, item2 in zip(expected, actual):
                    self.match(item1, item2)
        elif isinstance(expected, sparse_tensor.SparseTensorValue):
            self.match(
                (expected.indices, expected.values, expected.dense_shape),
                (actual.indices, actual.values, actual.dense_shape))
        elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
            self.match((expected.values, expected.row_splits),
                       (actual.values, actual.row_splits))
        else:
            self.assertEqual(expected, actual)
Example 25
    def _generate_initial_state(self, inputs, batch_size_tensor, state_size,
                                dtype):
        """Generate a zero filled tensor with shape [batch_size, state_size]."""
        if batch_size_tensor is None or dtype is None:
            raise ValueError(
                'batch_size and dtype cannot be None while constructing initial state: '
                'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))

        def create_init_values(unnested_state_size):
            flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list()
            init_state_size = [batch_size_tensor] + flat_dims
            if self.learned_init == 'dynamic':
                return inputs[:, 0, :] @ self.w + self.b
            elif self.learned_init == 'static':
                # Broadcast learned init vector to batch size
                return tf.broadcast_to(self.b, init_state_size)
            else:
                return array_ops.zeros(init_state_size, dtype=dtype)

        if nest.is_nested(state_size):
            return nest.map_structure(create_init_values, state_size)
        else:
            return create_init_values(state_size)
Example 26
 def _should_broadcast(self, obj):
     return not nest.is_nested(obj)
Example 27
def build_factored_surrogate_posterior(
        event_shape=None,
        constraining_bijectors=None,
        initial_unconstrained_loc=_sample_uniform_initial_loc,
        initial_unconstrained_scale=1e-2,
        trainable_distribution_fn=_build_trainable_normal_dist,
        seed=None,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    constraining_bijectors: Optional `tfb.Bijector` instance, or nested
      structure of such instances, defining support(s) of the posterior
      variables. The structure must match that of `event_shape` and may
      contain `None` values. A posterior variable will
      be modeled as `tfd.TransformedDistribution(underlying_dist,
      constraining_bijector)` if a corresponding constraining bijector is
      specified, otherwise it is modeled as supported on the
      unconstrained real line.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `constraining_bijectors`, which
      may optionally be prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_loc` is not specified.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or `initial_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    constraining_bijectors=[tfb.Softplus(),   # Concentration is positive.
                            tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda rate, concentration: model.log_prob([rate, concentration, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  constraining_bijectors={'concentration': tfb.Softplus(),   # Concentration is positive.
                          'rate': tfb.Softplus()}            # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x,
    constraining_bijectors, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    constraining_bijectors=constraining_bijectors,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        seed = tfp_util.SeedStream(seed,
                                   salt='build_factored_surrogate_posterior')

        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)
        flat_event_shapes = tf.nest.flatten(event_shape)

        # For simplicity, we'll work with flattened lists of state parts and
        # repack the structure at the end.
        if constraining_bijectors is not None:
            flat_bijectors = tf.nest.flatten(constraining_bijectors)
        else:
            flat_bijectors = [None for _ in flat_event_shapes]
        flat_unconstrained_event_shapes = [
            b.inverse_event_shape_tensor(s) if b is not None else s
            for s, b in zip(flat_event_shapes, flat_bijectors)
        ]

        # Construct initial locations for the internal unconstrained dists.
        if callable(
                initial_unconstrained_loc):  # Sample random initialization.
            flat_unconstrained_locs = [
                initial_unconstrained_loc(shape=s, seed=seed())
                for s in flat_unconstrained_event_shapes
            ]
        else:  # Use provided initialization.
            flat_unconstrained_locs = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_loc,
                check_types=False)

        if nest.is_nested(initial_unconstrained_scale):
            flat_unconstrained_scales = nest.flatten_up_to(
                shallow_structure,
                initial_unconstrained_scale,
                check_types=False)
        else:
            flat_unconstrained_scales = [
                initial_unconstrained_scale for _ in flat_unconstrained_locs
            ]

        # Extract the rank of each event, so that we build distributions with the
        # correct event shapes.
        flat_unconstrained_event_ndims = [
            prefer_static.rank_from_shape(s)
            for s in flat_unconstrained_event_shapes
        ]

        # Build the component surrogate posteriors.
        flat_component_dists = []
        for initial_loc, initial_scale, event_ndims, bijector in zip(
                flat_unconstrained_locs, flat_unconstrained_scales,
                flat_unconstrained_event_ndims, flat_bijectors):
            unconstrained_dist = trainable_distribution_fn(
                initial_loc=initial_loc,
                initial_scale=initial_scale,
                event_ndims=event_ndims,
                validate_args=validate_args)
            flat_component_dists.append(
                bijector(unconstrained_dist
                         ) if bijector is not None else unconstrained_dist)
        component_distributions = tf.nest.pack_sequence_as(
            event_shape, flat_component_dists)

        # Return a `Distribution` object whose events have the specified structure.
        return (joint_distribution_util.
                independent_joint_distribution_from_structure(
                    component_distributions, validate_args=validate_args))
Example 28
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to
  the input `event_shape`.

  This function is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution` that allows the
  user to pass in the desired posterior `event_shape` instead of
  pre-constructed base distributions (at the expense of full control over the
  base distribution types and parameterizations).

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution of the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint distribution
  # over `concentration` and `rate`. The posterior event shape is then equal to
  # the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """
    with tf.name_scope(name or 'build_affine_surrogate_posterior'):

        event_shape = nest.map_structure_up_to(
            _get_event_shape_shallow_structure(event_shape),
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        if bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                bijector.inverse_event_shape_tensor(event_shape))

        standard_base_distribution = nest.map_structure(
            lambda s: base_distribution(loc=tf.zeros([], dtype=dtype),
                                        scale=1.), unconstrained_event_shape)
        standard_base_distribution = nest.map_structure(
            lambda d, s: (  # pylint: disable=g-long-lambda
                sample.Sample(d, sample_shape=s, validate_args=validate_args)
                if distribution_util.shape_may_be_nontrivial(s) else d),
            standard_base_distribution,
            unconstrained_event_shape)
        if distribution_util.shape_may_be_nontrivial(batch_shape):
            standard_base_distribution = nest.map_structure(
                lambda d: batch_broadcast.BatchBroadcast(  # pylint: disable=g-long-lambda
                    d,
                    to_shape=batch_shape,
                    validate_args=validate_args),
                standard_base_distribution)

        surrogate_posterior = yield from _affine_surrogate_posterior_from_base_distribution(
            standard_base_distribution,
            operators=operators,
            bijector=bijector,
            validate_args=validate_args)
        return surrogate_posterior
Example 29
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name
                       or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(batch_shape):
            # The base is a non-autobatched joint distribution, so its batch
            # shape is a structure; reduce it to a single broadcast shape.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            # Construct trainable `LinearOperator` blocks (yielding their
            # parameters), with block dimensions given by the flat event sizes.
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        # Restructure the flat list of vector-valued parts back into the base
        # distribution's event structure, then reshape each to its event shape.
        unflatten_and_reshape = chain.Chain([
            joint_map.JointMap(nest.map_structure(reshape.Reshape,
                                                  event_shape),
                               validate_args=validate_args),
            restructure.Restructure(
                nest.pack_sequence_as(event_shape, range(num_components)))
        ],
                                            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Esempio n. 30
0
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.
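
  Concretely, each latent variable is modeled by roughly the following
  construction; this is a hand-rolled sketch under the default settings, with
  illustrative names, whereas the builder itself creates the trainable
  parameters via `tfp.experimental.util.make_trainable`:

  ```python
  # Sketch of a single factor. `event_shape_part` and `constraining_bijector`
  # stand in for one entry of the `event_shape` and `bijector` arguments.
  loc = tf.Variable(tf.zeros(event_shape_part))
  scale = tfp.util.TransformedVariable(
      1e-2 * tf.ones(event_shape_part), bijector=tfb.Softplus())
  factor = tfd.TransformedDistribution(
      tfd.Normal(loc=loc, scale=scale), bijector=constraining_bijector)
  ```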

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution_cls`
        does not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}  # Desired constrained values.
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()},           # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```
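
  Alternatively, `initial_parameters` may be a callable with the signature
  documented above. A minimal sketch, assuming the default `tfd.Normal` base
  distribution (constraint handling is left to
  `tfp.experimental.util.make_trainable`):

  ```python
  def parameter_init_fn(parameter_name, shape, dtype, seed,
                        constraining_bijector):
    del constraining_bijector  # Unused in this sketch.
    if parameter_name == 'scale':
      return tf.fill(shape, tf.constant(1e-2, dtype))
    return tf.random.stateless_normal(shape, seed=seed, dtype=dtype)

  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
      event_shape=model.event_shape_tensor()[:-1],
      bijector=[tfb.Softplus(), tfb.Softplus()],
      initial_parameters=parameter_init_fn)
  ```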

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        # Build a trainable unconstrained distribution for each event part,
        # yielding its parameters.
        unconstrained_trainable_distributions = yield from (
            nest_util.map_structure_coroutine(
                trainable._make_trainable,  # pylint: disable=protected-access
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)
Esempio n. 31
0
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint variational posterior by splitting a normalizing flow.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it is
      split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  #### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(shallow_structure,
                                               ps.convert_to_shape_tensor,
                                               event_shape)

        if nest.is_nested(constraining_bijector):
            constraining_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity()
                    if b is None else b, constraining_bijector),
                validate_args=validate_args)

        if constraining_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                constraining_bijector.inverse_event_shape_tensor(event_shape))

        flat_base_event_shape = nest.flatten(unconstrained_event_shape)
        flat_base_event_size = nest.map_structure(tf.reduce_prod,
                                                  flat_base_event_shape)
        event_size = tf.reduce_sum(flat_base_event_size)

        # A vector-valued base distribution of `event_size` iid samples
        # (loc=0., scale=1.), to be transformed by `trainable_bijector`.
        base_distribution = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [event_size])

        # After transforming base distribution samples with `trainable_bijector`,
        # split them into vector-valued components.
        split_bijector = split.Split(flat_base_event_size,
                                     validate_args=validate_args)

        # Reshape the vectors to the correct posterior event shape.
        event_reshape = joint_map.JointMap(nest.map_structure(
            reshape.Reshape, unconstrained_event_shape),
                                           validate_args=validate_args)

        # Restructure the flat list of components to the correct posterior
        # structure.
        event_unflatten = restructure.Restructure(
            nest.pack_sequence_as(unconstrained_event_shape,
                                  range(len(flat_base_event_shape))))

        bijectors = [] if constraining_bijector is None else [
            constraining_bijector
        ]
        bijectors.extend([
            event_reshape, event_unflatten, split_bijector, trainable_bijector
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            base_distribution, bijector=bijector, validate_args=validate_args)