def step(bij, x, x_event_ndims, increased_dof, **kwargs):  # pylint: disable=missing-docstring
  # Transform inputs for the next bijector.
  y = forward_fn(bij, x, **kwargs) if compute_x_values else None
  y_event_ndims = forward_event_ndims_fn(bij, x_event_ndims, **kwargs)

  # Check if the inputs to this bijector have increased degrees of freedom
  # due to some upstream bijector. We assume that the upstream bijector
  # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
  # case it doesn't matter).
  increased_dof = ps.reduce_any(nest.flatten(increased_dof))
  if compute_x_values and self.validate_event_size:
    assertions = [
        self._maybe_warn_increased_dof(
            component_name=bij.name, increased_dof=increased_dof)
    ]
    increased_dof |= (_event_size(y, y_event_ndims)
                      > _event_size(x, x_event_ndims))
  else:
    assertions = []

  y = nest_util.broadcast_structure(y_event_ndims, y)
  increased_dof = nest_util.broadcast_structure(y_event_ndims, increased_dof)
  bijectors_with_metadata.append(
      BijectorWithMetadata(
          bijector=bij,
          x=x,
          x_event_ndims=x_event_ndims,
          kwargs=kwargs,
          assertions=assertions,
      ))
  return y, y_event_ndims, increased_dof
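# A minimal standalone sketch of the degrees-of-freedom check above. We
# assume `_event_size(x, x_event_ndims)` counts the elements in the trailing
# `event_ndims` dimensions of `x`; `event_size` below is a hypothetical
# stand-in for illustration, not the TFP internal.
import tensorflow as tf

def event_size(x, event_ndims):
  # Product of the sizes of the trailing `event_ndims` dimensions.
  return tf.reduce_prod(tf.shape(x)[-event_ndims:])

x = tf.zeros([4, 3])  # event_ndims=1 -> event size 3
y = tf.zeros([4, 5])  # event_ndims=1 -> event size 5
# True: the bijector added degrees of freedom, so its LDJ cannot be
# trusted by downstream bijectors.
increased_dof = event_size(y, 1) > event_size(x, 1)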
def step(bij, y, y_event_ndims, increased_dof=False, **kwargs):  # pylint: disable=missing-docstring
  nonlocal ldj_sum

  # Compute the LDJ for this step, and add it to the rolling sum.
  component_ldj = tf.convert_to_tensor(
      bij.inverse_log_det_jacobian(y, y_event_ndims, **kwargs),
      dtype_hint=ldj_sum.dtype)

  if not dtype_util.is_floating(component_ldj.dtype):
    raise TypeError(('Nested bijector "{}" of Composition "{}" returned '
                     'LDJ with a non-floating dtype: {}')
                    .format(bij.name, self.name, component_ldj.dtype))
  ldj_sum = _max_precision_sum(ldj_sum, component_ldj)

  # Transform inputs for the next bijector.
  x = bij.inverse(y, **kwargs)
  x_event_ndims = bij.inverse_event_ndims(y_event_ndims, **kwargs)

  # Check if the inputs to this bijector have increased degrees of freedom
  # due to some upstream bijector. We assume that the upstream bijector
  # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
  # case it doesn't matter).
  increased_dof = ps.reduce_any(nest.flatten(increased_dof))
  if self.validate_event_size:
    assertions.append(self._maybe_warn_increased_dof(
        component_name=bij.name,
        component_ldj=component_ldj,
        increased_dof=increased_dof))
    increased_dof |= (_event_size(x, x_event_ndims)
                      > _event_size(y, y_event_ndims))

  increased_dof = nest_util.broadcast_structure(x, increased_dof)
  return x, x_event_ndims, increased_dof
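# `_max_precision_sum` is a TFP internal. Assuming it promotes both terms to
# the widest floating dtype before adding (so a float64 component LDJ is not
# truncated by a float32 running sum), a rough standalone sketch:
import numpy as np
import tensorflow as tf

def max_precision_sum(a, b):
  # Promote to the wider of the two floating dtypes, then add.
  dtype = np.promote_types(a.dtype.as_numpy_dtype, b.dtype.as_numpy_dtype)
  return tf.cast(a, dtype) + tf.cast(b, dtype)

ldj = max_precision_sum(tf.constant(0.5, tf.float32),
                        tf.constant(0.25, tf.float64))  # dtype: float64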
def rank_to_has_batch_dimensions(cls, rank: TensorflowTreeTopology):
  event_ndims = cls.get_event_ndims()
  batch_ndims = tf.nest.map_structure(
      lambda elem_rank, elem_event_ndims: elem_rank - elem_event_ndims,
      rank,
      event_ndims,
  )
  has_batch_dims_array = ps.stack(tf.nest.flatten(batch_ndims)) > 0
  return ps.reduce_any(has_batch_dims_array)
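# Toy illustration of the batch-dimension check above, with plain dicts
# standing in for a `TensorflowTreeTopology` (the leaf names here are
# hypothetical). A leaf has batch dimensions iff its rank exceeds its event
# ndims; the structure has batch dimensions if any leaf does.
import tensorflow as tf

rank = {'node_heights': 2, 'parent_indices': 1}
event_ndims = {'node_heights': 1, 'parent_indices': 1}
batch_ndims = tf.nest.map_structure(lambda r, e: r - e, rank, event_ndims)
has_batch_dims = any(n > 0 for n in tf.nest.flatten(batch_ndims))  # True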
def _update_inv_hessian(prev_state, next_state):
  """Update the BFGS state by computing the next inverse Hessian estimate."""
  # Only update the inverse Hessian if not already failed or converged.
  should_update = ~next_state.converged & ~next_state.failed

  # Compute the normalization term (y^T . s); do not update if it is
  # singular (i.e. zero).
  gradient_delta = (
      next_state.objective_gradient - prev_state.objective_gradient)
  position_delta = next_state.position - prev_state.position
  normalization_factor = tf.reduce_sum(
      gradient_delta * position_delta, axis=-1)
  should_update = should_update & ~tf.equal(normalization_factor, 0)

  def _do_update_inv_hessian():
    next_inv_hessian = _bfgs_inv_hessian_update(
        gradient_delta, position_delta, normalization_factor,
        prev_state.inverse_hessian_estimate)
    return bfgs_utils.update_fields(
        next_state,
        inverse_hessian_estimate=tf.where(
            should_update[..., tf.newaxis, tf.newaxis],
            next_inv_hessian,
            prev_state.inverse_hessian_estimate))

  return ps.cond(ps.reduce_any(should_update),
                 _do_update_inv_hessian,
                 lambda: next_state)
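# For reference, the standard BFGS inverse-Hessian update that
# `_bfgs_inv_hessian_update` is expected to compute (a single-problem sketch;
# the internal version additionally handles batch dimensions and numerical
# safeguards):
#
#   H' = (I - rho * s y^T) H (I - rho * y s^T) + rho * s s^T
#   with s = position_delta, y = gradient_delta, rho = 1 / (y^T s).
import tensorflow as tf

def bfgs_inv_hessian_update_sketch(gradient_delta, position_delta,
                                   normalization_factor, inv_hessian):
  rho = 1.0 / normalization_factor
  s = position_delta[:, tf.newaxis]  # column vector s
  y = gradient_delta[:, tf.newaxis]  # column vector y
  eye = tf.eye(tf.size(position_delta), dtype=position_delta.dtype)
  left = eye - rho * s @ tf.transpose(y)
  right = eye - rho * y @ tf.transpose(s)
  return left @ inv_hessian @ right + rho * s @ tf.transpose(s)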