def _make_asvi_trainable_variables(prior, mean_field=False,
                                   initial_prior_weight=0.5):
  """Builds trainable ASVI parameters for each component of `prior`.

  Args:
    prior: Distribution exposing `_get_single_sample_distributions()`. Each
      returned component may be wrapped in `Root`, in which case the wrapped
      distribution is used.
    mean_field: Python `bool`. If `True`, no prior-weight variables are
      created and every `ASVIParameters.prior_weight` is `None`.
    initial_prior_weight: Value multiplied with a tensor of ones and passed
      as the initial value of each `Sigmoid`-constrained prior-weight
      variable. Ignored when `mean_field` is `True`.

  Returns:
    param_dicts: Python `list` containing one `dict` per component
      distribution, mapping each trainable parameter name to an
      `ASVIParameters` instance.
  """
  # Parameters in either of these collections are never made trainable. The
  # concatenation is loop-invariant, so build it once.
  skipped_params = _NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS

  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist
      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if it exists.
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      # Build trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)
      sample_shape = tf.concat(
          [dist.batch_shape_tensor(), dist.event_shape_tensor()], axis=0)

      for param, value in actual_dist.parameters.items():
        if param in skipped_params or value is None:
          continue

        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = tfb.Identity()

        if mean_field:
          # A pure mean-field surrogate has no prior weights, so the
          # ones-tensor used to initialize them is not built at all.
          prior_weight = None
        else:
          unconstrained_ones = tf.ones(
              shape=bijector.inverse_event_shape_tensor(
                  parameter_properties[param].shape_fn(
                      sample_shape=sample_shape)),
              dtype=actual_dist.dtype)
          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=tfb.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # The prior weight (if any) is created before the mean-field
        # parameter, so variable creation order is unchanged.
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                value,
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(dist.name, param)))

      param_dicts.append(new_params_dict)
  return param_dicts
def _build_posterior_for_one_parameter(param, batch_shape, seed):
  """Builds a transformed-normal variational dist over a parameter's support.

  Args:
    param: Parameter object providing `name`, `prior` and `bijector`.
    batch_shape: Passed to `sample_uniform_initial_state` as
      `init_sample_shape`.
    seed: Seed forwarded to `sample_uniform_initial_state`.

  Returns:
    posterior_dist: `TransformedDistribution` named `'<param.name>_posterior'`
      that pushes a trainable `Normal` (wrapped in `Independent` when the
      prior has event dimensions) through `param.bijector`.
  """
  # Initial location, requested with `return_constrained=False`.
  unconstrained_init = sample_uniform_initial_state(
      param,
      init_sample_shape=batch_shape,
      return_constrained=False,
      seed=seed)

  # Trainable location and softplus-constrained scale of the base Normal.
  loc_variable = tf.Variable(
      initial_value=unconstrained_init, name=param.name + '_loc')
  initial_scale = tf.fill(
      tf.shape(unconstrained_init),
      value=tf.constant(0.02, unconstrained_init.dtype),
      name=param.name + '_scale')
  scale_variable = tfp_util.TransformedVariable(
      initial_scale, softplus_lib.Softplus())
  surrogate = normal_lib.Normal(loc=loc_variable, scale=scale_variable)

  # Give the surrogate the same `event_shape` rank as the parameter's prior.
  # An unknown rank (`None`) is also routed through `Independent`.
  prior_event_ndims = param.prior.event_shape.ndims
  if prior_event_ndims is None or prior_event_ndims > 0:
    surrogate = independent_lib.Independent(
        surrogate, reinterpreted_batch_ndims=prior_event_ndims)

  # Map into the constrained parameter space.
  return transformed_distribution_lib.TransformedDistribution(
      surrogate, param.bijector, name='{}_posterior'.format(param.name))
def build_trainable_location_scale_distribution(initial_loc,
                                                initial_scale,
                                                event_ndims,
                                                distribution_fn=normal.Normal,
                                                validate_args=False,
                                                name=None):
  """Creates a trainable distribution from a location-scale family.

  The initial location and scale are broadcast against each other before the
  variables are created, so both variables end up with the same shape.

  Args:
    initial_loc: Float `Tensor` initial location.
    initial_scale: Float `Tensor` initial scale.
    event_ndims: Integer `Tensor` number of event dimensions in `initial_loc`.
    distribution_fn: Optional constructor for a `tfd.Distribution` instance
      in a location-scale family. This should have signature `dist =
      distribution_fn(loc, scale, validate_args)`.
      Default value: `tfd.Normal`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
        'build_trainable_location_scale_distribution').

  Returns:
    posterior_dist: A `tfd.Distribution` instance.
  """
  with tf.name_scope(name or 'build_trainable_location_scale_distribution'):
    common = dtype_util.common_dtype([initial_loc, initial_scale],
                                     dtype_hint=tf.float32)

    # Broadcast loc to the scale's shape, then scale to the resulting shape.
    scale_shaped_ones = tf.ones(tf.shape(initial_scale), dtype=common)
    loc_value = initial_loc * scale_shaped_ones
    scale_value = initial_scale * tf.ones_like(loc_value)

    loc_variable = tf.Variable(initial_value=loc_value, name='loc')
    scale_variable = tfp_util.TransformedVariable(
        scale_value, softplus.Softplus(), name='scale')
    surrogate = distribution_fn(
        loc=loc_variable, scale=scale_variable, validate_args=validate_args)

    # Reinterpret batch dimensions as event dimensions unless `event_ndims`
    # is statically known to be non-positive.
    known_event_ndims = tf.get_static_value(event_ndims)
    needs_independent = known_event_ndims is None or known_event_ndims > 0
    if needs_independent:
      surrogate = independent.Independent(
          surrogate,
          reinterpreted_batch_ndims=event_ndims,
          validate_args=validate_args)
    return surrogate
def build_trainable_linear_operator_tril(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
  """Creates a trainable `LinearOperatorLowerTriangular`.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Standard deviation of the zero-mean normal samples
      that are pushed through the scale-tril bijector to initialize the
      variable.
    diag_bijector: Bijector to apply to the diagonal of the operator. A
      `_DefaultScaleDiagonal()` instance is used when this is falsy.
    dtype: `tf.dtype` of the `LinearOperator`. When `None`, it is inferred
      from `scale_initializer` with a `tf.float32` hint.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorLowerTriangular`.
  """
  with tf.name_scope(name or 'build_trainable_linear_operator_tril'):
    if dtype is None:
      dtype = dtype_util.common_dtype(
          [scale_initializer], dtype_hint=tf.float32)
    stddev = tf.convert_to_tensor(scale_initializer, dtype=dtype)
    diagonal_bijector = diag_bijector or _DefaultScaleDiagonal()

    # Separate the leading batch dimensions from the trailing matrix size.
    batch_dims, size = ps.split(shape, num_or_size_splits=[-1, 1])

    tril_bijector = fill_scale_tril.FillScaleTriL(
        diagonal_bijector, diag_shift=tf.zeros([], dtype=dtype))

    # One unconstrained entry per element of the lower triangle.
    num_tril_entries = size * (size + 1) // 2
    unconstrained_init = samplers.normal(
        mean=0.,
        stddev=stddev,
        shape=ps.concat([batch_dims, num_tril_entries], axis=0),
        seed=seed,
        dtype=dtype)

    trainable_tril = tfp_util.TransformedVariable(
        tril_bijector.forward(unconstrained_init),
        bijector=tril_bijector,
        name='tril')
    return tf.linalg.LinearOperatorLowerTriangular(
        tril=trainable_tril, is_non_singular=True)
def build_trainable_linear_operator_diag(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
  """Creates a trainable `LinearOperatorDiag`.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Standard deviation of the zero-mean normal samples
      that are pushed through `diag_bijector` to initialize the variable.
    diag_bijector: Bijector to apply to the diagonal of the operator. A
      `_DefaultScaleDiagonal()` instance is used when this is falsy.
    dtype: `tf.dtype` of the `LinearOperator`. When `None`, it is inferred
      from `scale_initializer` with a `tf.float32` hint.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorDiag`.
  """
  with tf.name_scope(name or 'build_trainable_linear_operator_diag'):
    if dtype is None:
      dtype = dtype_util.common_dtype(
          [scale_initializer], dtype_hint=tf.float32)
    stddev = tf.convert_to_tensor(scale_initializer, dtype=dtype)
    diagonal_bijector = diag_bijector or _DefaultScaleDiagonal()

    # Unconstrained draw; the bijector maps it to the initial diagonal.
    unconstrained_init = samplers.normal(
        mean=0., stddev=stddev, shape=shape, dtype=dtype, seed=seed)
    trainable_diag = tfp_util.TransformedVariable(
        diagonal_bijector.forward(unconstrained_init),
        bijector=diagonal_bijector,
        name='diag')
    return tf.linalg.LinearOperatorDiag(trainable_diag, is_non_singular=True)
def _make_asvi_trainable_variables(prior, mean_field=False, initial_prior_weight=0.5):
  """Generates one dict of trainable ASVI parameters per prior component.

  Args:
    prior: Distribution exposing `_get_single_sample_distributions()`. Each
      returned component may be wrapped in `Root`, in which case the wrapped
      distribution is used.
    mean_field: Python `bool`. If `True`, no prior-weight variables are
      created and every `ASVIParameters.prior_weight` is `None`.
    initial_prior_weight: Value multiplied with a tensor of ones and passed
      as the initial value of each `Sigmoid`-constrained prior-weight
      variable. Unused when `mean_field` is `True`.

  Returns:
    param_dicts: Python `list` with one `dict` per component distribution,
      mapping each trainable parameter name to an `ASVIParameters` instance.
  """
  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist
      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if it exists; otherwise use the
      # substituted distribution itself.
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      # Build trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)

      # For a `Sample` prior the posterior batch shape is the base batch shape
      # followed by the sample shape; otherwise it is just the batch shape.
      if isinstance(original_dist, sample.Sample):
        posterior_batch_shape = ps.concat([
            actual_dist.batch_shape_tensor(),
            distribution_util.expand_to_vector(original_dist.sample_shape)
        ], axis=0)
      else:
        posterior_batch_shape = actual_dist.batch_shape_tensor()

      for param, value in actual_dist.parameters.items():
        # Skip parameters that are excluded from training or were not set.
        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or value is None:
          continue

        actual_event_shape = parameter_properties[param].shape_fn(
            actual_dist.event_shape_tensor())
        # Fall back to the identity when the parameter declares no default
        # constraining bijector.
        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = identity.Identity()

        if mean_field:
          prior_weight = None
        else:
          # Ones shaped like the posterior batch shape followed by the
          # bijector's inverse event shape, in the dtype of `value`.
          unconstrained_ones = tf.ones(
              shape=ps.concat([
                  posterior_batch_shape,
                  bijector.inverse_event_shape_tensor(
                      actual_event_shape)
              ], axis=0),
              dtype=tf.convert_to_tensor(value).dtype)

          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=sigmoid.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # If the prior distribution was a tfd.Sample wrapping a base
        # distribution, we want to give every single sample in the prior its
        # own lambda and alpha value (rather than having a single lambda and
        # alpha). The reshape inserts one size-1 axis per sample dimension
        # between the batch and event axes; the broadcast then expands those
        # axes to the full posterior batch shape.
        if isinstance(original_dist, sample.Sample):
          value = tf.reshape(
              value,
              ps.concat([
                  actual_dist.batch_shape_tensor(),
                  ps.ones(ps.rank_from_shape(original_dist.sample_shape)),
                  actual_event_shape
              ], axis=0))
          value = tf.broadcast_to(
              value,
              ps.concat([posterior_batch_shape, actual_event_shape], axis=0))

        # The mean-field parameter is initialized at the prior's value.
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                value,
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(dist.name, param)))
      param_dicts.append(new_params_dict)
  return param_dicts
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field,
                                              initial_prior_weight,
                                              sample_shape=None,
                                              variables=None,
                                              seed=None):
  """Creates a trainable surrogate for a (non-meta, non-joint) distribution.

  Args:
    dist: Prior distribution instance whose parameters are made trainable.
    mean_field: Python `bool`. If `True`, each trainable parameter of the
      surrogate is the mean-field parameter alone and no prior weights are
      created. Otherwise each parameter is the convex combination
      `prior_weight * prior_value + (1 - prior_weight) * mean_field_parameter`.
    initial_prior_weight: Initial value used to fill every prior-weight
      variable. Ignored when `mean_field` is `True`.
    sample_shape: Optional shape appended to `dist.batch_shape_tensor()` to
      form the batch shape of the created variables.
      Default value: `None`.
    variables: Optional `dict` mapping parameter names to existing
      `ASVIParameters`. Parameters already present are reused rather than
      recreated. The dict is updated in place and returned.
      Default value: `None` (i.e., start from an empty dict).
    seed: Seed used to initialize the mean-field parameters.

  Returns:
    surrogate: Instance of `type(dist)` built from the trainable parameters.
    variables: `dict` mapping each trainable parameter name to its
      `ASVIParameters`.
  """
  if variables is None:
    variables = {}

  # Both are loop-invariant, so compute them once.
  fixed_params = _NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS
  dist_name = _get_name(dist)

  posterior_batch_shape = dist.batch_shape_tensor()
  if sample_shape is not None:
    posterior_batch_shape = ps.concat([
        posterior_batch_shape,
        distribution_util.expand_to_vector(sample_shape)
    ], axis=0)

  # Create variables backing each parameter, if needed.
  all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
  for param, prior_value in dist.parameters.items():
    if param in variables or param in fixed_params or prior_value is None:
      continue

    param_properties = all_parameter_properties[param]
    try:
      bijector = param_properties.default_constraining_bijector_fn()
    except NotImplementedError:
      bijector = identity.Identity()

    # Posterior batch shape followed by the parameter's trailing event dims.
    param_shape = ps.concat([
        posterior_batch_shape,
        ps.shape(prior_value)[
            ps.rank(prior_value) - param_properties.event_ndims:]
    ], axis=0)
    param_dtype = tf.convert_to_tensor(prior_value).dtype

    if mean_field:
      prior_weight = None
    else:
      prior_weight = tfp_util.TransformedVariable(
          initial_value=tf.fill(
              dims=param_shape,
              value=tf.cast(initial_prior_weight, param_dtype)),
          bijector=sigmoid.Sigmoid(),
          name='prior_weight/{}/{}'.format(dist_name, param))

    # Draw the initial sample in the prior value's dtype so that the convex
    # combination below does not mix dtypes (e.g. a float64 prior weight
    # with a float32 mean-field parameter). Non-floating prior values keep
    # the sampler's previous float32 default.
    sample_dtype = param_dtype if param_dtype.is_floating else tf.float32

    # Initialize the mean-field parameter as a (constrained) standard
    # normal sample.
    seed, param_seed = samplers.split_seed(seed)
    variables[param] = ASVIParameters(
        prior_weight=prior_weight,
        mean_field_parameter=tfp_util.TransformedVariable(
            initial_value=bijector.forward(
                samplers.normal(
                    shape=bijector.inverse_event_shape(param_shape),
                    dtype=sample_dtype,
                    seed=param_seed)),
            bijector=bijector,
            name='mean_field_parameter/{}/{}'.format(dist_name, param)))

  # Assemble the surrogate's constructor arguments. Non-trainable and unset
  # parameters are forwarded unchanged from the prior.
  temp_params_dict = {'name': dist_name}
  for param, prior_value in dist.parameters.items():
    if param in fixed_params or prior_value is None:
      temp_params_dict[param] = prior_value
    elif mean_field:
      temp_params_dict[param] = variables[param].mean_field_parameter
    else:
      temp_params_dict[param] = (
          variables[param].prior_weight * prior_value +
          ((1. - variables[param].prior_weight) *
           variables[param].mean_field_parameter))
  return type(dist)(**temp_params_dict), variables
def build_trainable_highway_flow(width,
                                 residual_fraction_initial_value=0.5,
                                 activation_fn=None,
                                 gate_first_n=None,
                                 seed=None,
                                 validate_args=False):
  """Constructs a `HighwayFlow` whose parameters are trainable variables.

  Each parameter is backed by a variable that enforces its constraint:

  - `residual_fraction` is a `Sigmoid`-transformed variable.
  - `upper_diagonal_weights_matrix` is a `width x width` matrix produced by
    `FillScaleTriL` with a `Softplus` diagonal, randomly initialized.
  - `lower_diagonal_weights_matrix` is a `width x width` matrix produced by
    filling a triangle, padding it by one row and one column, and shifting
    the diagonal by one, randomly initialized.
  - `bias` is a randomly initialized vector of size `width`.

  Args:
    width: Input dimension of the bijector.
    residual_fraction_initial_value: Initial value for gating parameter, must
      be between 0 and 1.
    activation_fn: Callable invertible activation function
      (e.g., `tf.nn.softplus`), or `None`.
    gate_first_n: Decides which part of the input should be gated (useful for
      example when using auxiliary variables).
    seed: Seed for random initialization of the weights.
    validate_args: Python `bool`. Whether to validate input with runtime
      assertions.
      Default value: `False`.

  Returns:
    trainable_highway_flow: The initialized bijector.
  """
  residual_fraction_initial_value = tf.convert_to_tensor(
      residual_fraction_initial_value,
      dtype_hint=tf.float32,
      name='residual_fraction_initial_value')
  dtype = residual_fraction_initial_value.dtype

  bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)
  matrix_shape = [width, width]

  # Lower factor: fill a triangle, pad it, then shift the diagonal by one.
  lower_bijector = tfb.Chain([
      tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
      tfb.Pad(paddings=[(1, 0), (0, 1)]),
      tfb.FillTriangular()
  ])
  lower_unconstrained = samplers.normal(
      shape=lower_bijector.inverse_event_shape(matrix_shape),
      mean=0.,
      stddev=.01,
      seed=lower_seed)

  # Upper factor: triangular fill with a softplus-constrained diagonal.
  upper_bijector = tfb.FillScaleTriL(
      diag_bijector=tfb.Softplus(), diag_shift=None)
  upper_unconstrained = samplers.normal(
      shape=upper_bijector.inverse_event_shape(matrix_shape),
      mean=0.,
      stddev=.01,
      seed=upper_seed)

  # Variables are created in the same order as the constructor arguments.
  residual_fraction = util.TransformedVariable(
      initial_value=residual_fraction_initial_value,
      bijector=tfb.Sigmoid(),
      dtype=dtype)
  bias = tf.Variable(
      samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
      dtype=dtype)
  upper_weights = util.TransformedVariable(
      initial_value=upper_bijector.forward(upper_unconstrained),
      bijector=upper_bijector,
      dtype=dtype)
  lower_weights = util.TransformedVariable(
      initial_value=lower_bijector.forward(lower_unconstrained),
      bijector=lower_bijector,
      dtype=dtype)

  return HighwayFlow(
      residual_fraction=residual_fraction,
      activation_fn=activation_fn,
      bias=bias,
      upper_diagonal_weights_matrix=upper_weights,
      lower_diagonal_weights_matrix=lower_weights,
      gate_first_n=gate_first_n,
      validate_args=validate_args)