Example no. 1
0
    def step(bij, x, x_event_ndims, increased_dof, **kwargs):
      """Pushes `x` forward through one component bijector, collecting metadata."""
      # Map the inputs forward through this component (values only when the
      # caller asked for them).
      y = forward_fn(bij, x, **kwargs) if compute_x_values else None
      y_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      assertions = []
      if compute_x_values and self.validate_event_size:
        assertions.append(
            self._maybe_warn_increased_dof(
                component_name=bij.name, increased_dof=increased_dof))
        # Propagate the flag downstream if this component grew the event size.
        increased_dof |= (_event_size(y, y_event_ndims)
                          > _event_size(x, x_event_ndims))

      # Broadcast outputs to match the (possibly nested) event-ndims structure.
      y = nest_util.broadcast_structure(y_event_ndims, y)
      increased_dof = nest_util.broadcast_structure(
          y_event_ndims, increased_dof)
      bijectors_with_metadata.append(
          BijectorWithMetadata(
              bijector=bij,
              x=x,
              x_event_ndims=x_event_ndims,
              kwargs=kwargs,
              assertions=assertions,
          ))
      return y, y_event_ndims, increased_dof
Example no. 2
0
    def step(bij, y, y_event_ndims, increased_dof=False, **kwargs):  # pylint: disable=missing-docstring
      # Accumulates one component's inverse LDJ into the enclosing `ldj_sum`
      # and maps `y` back through this component to produce the next `x`.
      nonlocal ldj_sum

      # Compute the LDJ for this step, and add it to the rolling sum.
      component_ldj = tf.convert_to_tensor(
          bij.inverse_log_det_jacobian(y, y_event_ndims, **kwargs),
          dtype_hint=ldj_sum.dtype)

      # An LDJ must be floating point; anything else indicates a broken
      # component bijector, so fail loudly with its name.
      if not dtype_util.is_floating(component_ldj.dtype):
        raise TypeError(('Nested bijector "{}" of Composition "{}" returned '
                         'LDJ with a non-floating dtype: {}')
                        .format(bij.name, self.name, component_ldj.dtype))
      ldj_sum = _max_precision_sum(ldj_sum, component_ldj)

      # Transform inputs for the next bijector.
      x = bij.inverse(y, **kwargs)
      x_event_ndims = bij.inverse_event_ndims(y_event_ndims, **kwargs)

      # Check if the inputs to this bijector have increased degrees of freedom
      # due to some upstream bijector. We assume that the upstream bijector
      # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
      # case it doesn't matter).
      increased_dof = ps.reduce_any(nest.flatten(increased_dof))
      if self.validate_event_size:
        assertions.append(self._maybe_warn_increased_dof(
            component_name=bij.name,
            component_ldj=component_ldj,
            increased_dof=increased_dof))
        # Propagate the flag if this component grew the event size.
        increased_dof |= (_event_size(x, x_event_ndims)
                          > _event_size(y, y_event_ndims))

      # Broadcast the scalar flag back over the structure of `x` for the
      # next step.
      increased_dof = nest_util.broadcast_structure(x, increased_dof)
      return x, x_event_ndims, increased_dof
Example no. 3
0
 def rank_to_has_batch_dimensions(cls, rank: TensorflowTreeTopology):
     """Returns True if any component of `rank` exceeds its event ndims."""
     event_ndims = cls.get_event_ndims()
     # Per-component batch rank is the total rank minus the event rank.
     batch_ndims = tf.nest.map_structure(
         lambda component_rank, component_event_ndims: (
             component_rank - component_event_ndims),
         rank,
         event_ndims,
     )
     # Any strictly positive batch rank means batch dimensions are present.
     flat_batch_ndims = tf.nest.flatten(batch_ndims)
     return ps.reduce_any(ps.stack(flat_batch_ndims) > 0)
Example no. 4
0
def _update_inv_hessian(prev_state, next_state):
    """Update the BFGS state by computing the next inverse hessian estimate.

    Args:
      prev_state: BFGS optimizer state namedtuple before the current step.
      next_state: BFGS optimizer state namedtuple after the current step.

    Returns:
      `next_state` with `inverse_hessian_estimate` refreshed for batch
      members where the update is valid, and carried over from `prev_state`
      elsewhere.
    """
    # Only update the inverse Hessian if not already failed or converged.
    should_update = ~next_state.converged & ~next_state.failed

    # Compute the normalization term (y^T . s), should not update if is singular.
    gradient_delta = next_state.objective_gradient - prev_state.objective_gradient
    position_delta = next_state.position - prev_state.position
    normalization_factor = tf.reduce_sum(gradient_delta * position_delta,
                                         axis=-1)
    # A zero inner product would divide by zero in the BFGS update formula.
    should_update = should_update & ~tf.equal(normalization_factor, 0)

    def _do_update_inv_hessian():
        next_inv_hessian = _bfgs_inv_hessian_update(
            gradient_delta, position_delta, normalization_factor,
            prev_state.inverse_hessian_estimate)
        # Per-batch-member select: the two trailing new axes align the
        # boolean mask with the trailing matrix dimensions of the estimate.
        return bfgs_utils.update_fields(
            next_state,
            inverse_hessian_estimate=tf.where(
                should_update[..., tf.newaxis, tf.newaxis], next_inv_hessian,
                prev_state.inverse_hessian_estimate))

    # Skip the update computation entirely when no batch member needs it.
    return ps.cond(ps.reduce_any(should_update), _do_update_inv_hessian,
                   lambda: next_state)