def testGetAndMakeFromParameters(self):
  """Round-trips a doubly Tanh-transformed Normal through utils.Params.

  Checks that `utils.get_parameters` produces the expected nested Params
  structure, and that `utils.make_from_parameters` rebuilds a distribution
  whose `log_prob` matches the original.
  """
  one = tf.constant(1.0)
  d = tfp.distributions.Normal(loc=one, scale=3.0, validate_args=True)
  d = tfp.bijectors.Tanh()(d)
  d = tfp.bijectors.Tanh()(d)
  p = utils.get_parameters(d)
  # Expected: a single TransformedDistribution whose bijector is a Chain of
  # the two Tanh bijectors applied above, wrapping the original Normal.
  expected_p = utils.Params(
      tfp.distributions.TransformedDistribution,
      params={
          'bijector': utils.Params(
              tfp.bijectors.Chain,
              params={
                  'bijectors': [
                      utils.Params(tfp.bijectors.Tanh, params={}),
                      utils.Params(tfp.bijectors.Tanh, params={}),
                  ]
              }),
          'distribution': utils.Params(
              tfp.distributions.Normal,
              params={
                  'validate_args': True,
                  'scale': 3.0,
                  'loc': one
              })
      })
  self.compare_params(p, expected_p)
  d_recreated = utils.make_from_parameters(p)
  # tanh(tanh(x)) maps the real line onto (-tanh(1), tanh(1)), roughly
  # (-0.7616, 0.7616). A point such as 0.99 lies outside that support, so both
  # log_probs would be NaN and the comparison would verify nothing. Keep every
  # evaluation point strictly inside the support.
  points = [-0.5, 0.01, 0.25, 0.5, 0.75]
  self.assertAllClose(d.log_prob(points), d_recreated.log_prob(points))
def testParametersFromDictMissingNestedDictKeyFailure(self):
  """Merging must reject a dict missing part of a flattened nested list.

  Removing a plain leaf entry is tolerated, but removing one element of the
  flattened `bijectors` list must raise a KeyError.
  """
  base = tfp.distributions.Normal(
      loc=tf.constant(1.0), scale=3.0, validate_args=True)
  once = tfp.bijectors.Tanh()(base)
  twice = tfp.bijectors.Tanh()(once)
  params = utils.get_parameters(twice)
  flat = utils.parameters_to_dict(params)

  # Dropping a non-nested leaf is allowed: the merge falls back to the value
  # already held by `params`.
  del flat['distribution']['validate_args']
  utils.merge_to_parameters_from_dict(params, flat)

  # Dropping one element of a flattened list is rejected. Flattening loses the
  # list structure, so a partial override (e.g. removing a bijector from the
  # list) could silently produce the wrong result.
  del flat['bijector']['bijectors:1']
  expected_message = (
      r'Only saw partial information from the dictionary for nested key '
      r'\'bijectors\' in params_dict.*'
      r'Entries provided: \[\'bijectors:0\'\].*'
      r'Entries required: \[\'bijectors:0\', \'bijectors:1\'\]')
  with self.assertRaisesRegex(KeyError, expected_message):
    utils.merge_to_parameters_from_dict(params, flat)
def testParametersToAndFromDict(self, tensors_only):
  """Checks parameters_to_dict output and merging an edited dict back.

  Args:
    tensors_only: Forwarded to `utils.parameters_to_dict`. When True, only
      the tensor-valued entries are expected to appear in the dict.
  """
  scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
  base = tfp.distributions.MultivariateNormalDiag(
      loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
  bijector = tfp.bijectors.ScaleMatvecLinearOperator(
      scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
      adjoint=True)
  params = utils.get_parameters(bijector(base))
  flat = utils.parameters_to_dict(params, tensors_only=tensors_only)

  if not tensors_only:
    expected_flat = {
        'bijector': {
            'adjoint': True,
            'scale': {'matrix': scale_matrix},
        },
        'distribution': {
            'validate_args': True,
            # `loc` and `scale_diag` were passed as Python lists instead of
            # numpy arrays, so each element is flattened into its own
            # indexed key.
            'scale_diag:0': 2.0,
            'scale_diag:1': 3.0,
            'loc:0': 1.0,
            'loc:1': 1.0,
        },
    }
  else:
    expected_flat = {
        'bijector': {'scale': {'matrix': scale_matrix}},
        'distribution': {},
    }
  tf.nest.map_structure(self.assertAllEqual, flat, expected_flat)

  # Adding 1.0 also turns the tf.Variable entry into a plain tf.Tensor.
  scale_entry = flat['bijector']['scale']
  scale_entry['matrix'] = scale_entry['matrix'] + 1.0

  # With tensors_only=True, this also checks that merging into `params` works
  # from a dict that holds nothing but tensors.
  recreated = utils.merge_to_parameters_from_dict(params, flat)
  recreated_matrix = (
      recreated.params['bijector'].params['scale'].params['matrix'])
  original_matrix = params.params['bijector'].params['scale'].params['matrix']
  self.assertAllClose(recreated_matrix, original_matrix + 1.0)

  # Tensor values were compared explicitly above; check only the rest here.
  self.compare_params(params, recreated, skip_tensor_values=True)
def _calc_unbatched_spec(x):
  """Builds a spec for `x` with its singleton batch dimension(s) removed.

  Note: `outer_ndim` is not defined in this block. It is read from the
  enclosing scope (presumably the function this closure is nested in).

  Args:
    x: A `tfp.distributions.Distribution`, or any value accepted by
      `tf.type_spec_from_value`.

  Returns:
    A `distribution_utils.DistributionSpecV2` for distributions. Otherwise,
    the type spec of `x` passed through
    `nest_utils.remove_singleton_batch_spec_dim`.
  """
  if not isinstance(x, tfp.distributions.Distribution):
    value_spec = tf.type_spec_from_value(x)
    return nest_utils.remove_singleton_batch_spec_dim(
        value_spec, outer_ndim=outer_ndim)

  dist_parameters = distribution_utils.get_parameters(x)
  unbatched_parameters = _convert_to_spec_and_remove_singleton_batch_dim(
      dist_parameters, outer_ndim=outer_ndim)
  return distribution_utils.DistributionSpecV2(
      event_shape=x.event_shape,
      dtype=x.dtype,
      parameters=unbatched_parameters)
def _calc_unbatched_spec(x):
  """Builds a Network output spec by stripping the added batch dimension.

  Args:
    x: A `tfp.distributions.Distribution` or a Tensor.

  Returns:
    For a distribution, a `distribution_utils.DistributionSpecV2` whose
    parameter specs have one singleton batch dimension removed. Otherwise,
    the type spec of `x` with one outer dimension removed.
  """
  if not isinstance(x, tfp.distributions.Distribution):
    return tensor_spec.remove_outer_dims_nest(
        tf.type_spec_from_value(x), num_outer_dims=1)

  dist_parameters = distribution_utils.get_parameters(x)
  unbatched_parameters = _convert_to_spec_and_remove_singleton_batch_dim(
      dist_parameters, outer_ndim=1)
  return distribution_utils.DistributionSpecV2(
      event_shape=x.event_shape,
      dtype=x.dtype,
      parameters=unbatched_parameters)
def testGetAndMakeNontrivialBijectorFromParameters(self):
  """Round-trips an MVN transformed by a linear-operator bijector.

  The bijector holds a nested LinearOperator built from a tf.Variable, so the
  Params structure is nested more than one level deep.
  """
  scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
  base = tfp.distributions.MultivariateNormalDiag(
      loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
  bijector = tfp.bijectors.ScaleMatvecLinearOperator(
      scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
      adjoint=True)
  transformed = bijector(base)
  params = utils.get_parameters(transformed)

  expected_bijector = utils.Params(
      tfp.bijectors.ScaleMatvecLinearOperator,
      params={
          'adjoint': True,
          'scale': utils.Params(
              tf.linalg.LinearOperatorFullMatrix,
              params={'matrix': scale_matrix}),
      })
  expected_distribution = utils.Params(
      tfp.distributions.MultivariateNormalDiag,
      params={
          'validate_args': True,
          'scale_diag': [2.0, 3.0],
          'loc': [1.0, 1.0],
      })
  expected = utils.Params(
      tfp.distributions.TransformedDistribution,
      params={
          'bijector': expected_bijector,
          'distribution': expected_distribution,
      })
  self.compare_params(params, expected)

  recreated = utils.make_from_parameters(params)
  # Some points have infinite coordinates; both distributions must agree on
  # those as well as on the finite ones.
  points = [
      [-1.0, -2.0],
      [0.0, 0.0],
      [3.0, -5.0],
      [5.0, 5.0],
      [1.0, np.inf],
      [-np.inf, 0.0],
  ]
  self.assertAllClose(
      transformed.log_prob(points), recreated.log_prob(points))
def testParametersFromDictMissingNestedParamsKeyFailure(self):
  """Merging must reject a dict that adds an extra flattened list entry."""
  base = tfp.distributions.Normal(
      loc=tf.constant(1.0), scale=3.0, validate_args=True)
  once = tfp.bijectors.Tanh()(base)
  twice = tfp.bijectors.Tanh()(once)
  params = utils.get_parameters(twice)
  flat = utils.parameters_to_dict(params)
  bijector_entries = flat['bijector']

  # Append a third bijector entry, which changes the structure of the nest.
  for key in ('bijectors:0', 'bijectors:1'):
    self.assertIn(key, bijector_entries.keys())
  bijector_entries['bijectors:2'] = bijector_entries['bijectors:0']

  # The flattened dict has lost the list structure, so the extra entry cannot
  # be mapped back into the bijector list. Merging must fail instead.
  expected_message = (
      r'params_dict keys: \[\'bijectors:0\', \'bijectors:1\', '
      r'\'bijectors:2\'\], value.params processed keys: '
      r'\[\'bijectors:0\', \'bijectors:1\'\]')
  with self.assertRaisesRegex(ValueError, expected_message):
    utils.merge_to_parameters_from_dict(params, flat)
def dist_params_dict(d):
  """Returns the parameters of distribution `d` as a dict of tensors only.

  Args:
    d: The distribution whose parameters are extracted.

  Returns:
    The result of `distribution_utils.parameters_to_dict` called with
    `tensors_only=True`.
  """
  params = distribution_utils.get_parameters(d)
  return distribution_utils.parameters_to_dict(params, tensors_only=True)
def dist_params_dict(d):
  """Returns all parameters of distribution `d` converted to a dict.

  Args:
    d: The distribution whose parameters are extracted.

  Returns:
    The result of `distribution_utils.parameters_to_dict` with its default
    options, so non-tensor entries are included too.
  """
  params = distribution_utils.get_parameters(d)
  return distribution_utils.parameters_to_dict(params)