def testParametersFromDictMissingNestedDictKeyFailure(self):
    """Dropping one flattened entry of a nested list must raise KeyError."""
    # Normal wrapped twice by Tanh; its parameters carry a nested list of
    # bijectors that parameters_to_dict flattens into 'bijectors:<i>' keys.
    loc = tf.constant(1.0)
    dist = tfp.distributions.Normal(loc=loc, scale=3.0, validate_args=True)
    for _ in range(2):
        dist = tfp.bijectors.Tanh()(dist)
    params = utils.get_parameters(dist)
    params_dict = utils.parameters_to_dict(params)

    # A plain (non-nested) key may be omitted; merging falls back to the
    # value already held in `params`.
    del params_dict['distribution']['validate_args']
    utils.merge_to_parameters_from_dict(params, params_dict)

    # Omitting one element of the flattened bijector list is rejected: the
    # flat dict no longer records the list's structure, so a partial override
    # of the nested list could lead to subtle errors.
    del params_dict['bijector']['bijectors:1']
    expected_message = (
        r'Only saw partial information from the dictionary for nested key '
        r'\'bijectors\' in params_dict.*'
        r'Entries provided: \[\'bijectors:0\'\].*'
        r'Entries required: \[\'bijectors:0\', \'bijectors:1\'\]')
    with self.assertRaisesRegex(KeyError, expected_message):
        utils.merge_to_parameters_from_dict(params, params_dict)
def nested_dist_params(spec):
    """Convert a distribution spec's parameters into a dict of tensors.

    Args:
      spec: Expected to be a `distribution_utils.DistributionSpecV2`.

    Returns:
      `distribution_utils.parameters_to_dict(spec.parameters,
      tensors_only=True)`.

    Raises:
      ValueError: If `spec` is not a `DistributionSpecV2`.
    """
    if not isinstance(spec, distribution_utils.DistributionSpecV2):
        # Report the value that actually failed the check. The previous code
        # formatted a name from outside this function, which may be undefined
        # here (turning this into a NameError) and is not the offending spec.
        raise ValueError(
            'Unexpected output from `actor_network`. Expected '
            '`Distribution` objects, but saw output spec: {}'
            .format(spec))
    return distribution_utils.parameters_to_dict(
        spec.parameters, tensors_only=True)
def testParametersToAndFromDict(self, tensors_only):
    """Round-trips the parameters of a transformed MVN through a dict."""
    scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
    base_dist = tfp.distributions.MultivariateNormalDiag(
        loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
    linear_op = tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix)
    bijector = tfp.bijectors.ScaleMatvecLinearOperator(
        scale=linear_op, adjoint=True)
    transformed = bijector(base_dist)

    params = utils.get_parameters(transformed)
    params_dict = utils.parameters_to_dict(params, tensors_only=tensors_only)

    if not tensors_only:
        expected = {
            'bijector': {
                'adjoint': True,
                'scale': {'matrix': scale_matrix},
            },
            'distribution': {
                'validate_args': True,
                # Lists (not numpy arrays) were passed for `loc` and
                # `scale_diag`, so each element becomes its own entry.
                'scale_diag:0': 2.0,
                'scale_diag:1': 3.0,
                'loc:0': 1.0,
                'loc:1': 1.0,
            },
        }
    else:
        expected = {
            'bijector': {'scale': {'matrix': scale_matrix}},
            'distribution': {},
        }
    tf.nest.map_structure(self.assertAllEqual, params_dict, expected)

    # Adding 1.0 replaces the tf.Variable entry with a plain tf.Tensor.
    scale_entry = params_dict['bijector']['scale']
    scale_entry['matrix'] = scale_entry['matrix'] + 1.0

    # Merging back must work even when the dict (tensors_only=True) holds
    # nothing but tensors.
    recreated = utils.merge_to_parameters_from_dict(params, params_dict)
    recreated_matrix = (
        recreated.params['bijector'].params['scale'].params['matrix'])
    shifted_original = (
        params.params['bijector'].params['scale'].params['matrix'] + 1.0)
    self.assertAllClose(recreated_matrix, shifted_original)

    # Tensor values were verified above; compare only the remaining structure.
    self.compare_params(params, recreated, skip_tensor_values=True)
def testParametersFromDictMissingNestedParamsKeyFailure(self):
    """An extra flattened entry for a nested list must raise ValueError."""
    loc = tf.constant(1.0)
    dist = tfp.distributions.Normal(loc=loc, scale=3.0, validate_args=True)
    for _ in range(2):
        dist = tfp.bijectors.Tanh()(dist)
    params = utils.get_parameters(dist)
    params_dict = utils.parameters_to_dict(params)

    # Two bijectors are present; inject a third, altering the nest structure.
    bijector_entries = params_dict['bijector']
    self.assertIn('bijectors:0', bijector_entries.keys())
    self.assertIn('bijectors:1', bijector_entries.keys())
    bijector_entries['bijectors:2'] = bijector_entries['bijectors:0']

    # The flat dict does not encode the nested list structure, so the extra
    # entry cannot be mapped back into the bijector list.
    expected_message = (
        r'params_dict keys: \[\'bijectors:0\', \'bijectors:1\', '
        r'\'bijectors:2\'\], value.params processed keys: '
        r'\[\'bijectors:0\', \'bijectors:1\'\]')
    with self.assertRaisesRegex(ValueError, expected_message):
        utils.merge_to_parameters_from_dict(params, params_dict)
def dist_params_dict(d):
    """Return the parameters of distribution `d` as a dict.

    Only tensor entries are kept (`tensors_only=True`).
    """
    params = distribution_utils.get_parameters(d)
    return distribution_utils.parameters_to_dict(params, tensors_only=True)
def dist_params_dict(d):
    """Return all parameters of distribution `d` as a dict.

    Uses `parameters_to_dict` with its default arguments.
    """
    params = distribution_utils.get_parameters(d)
    return distribution_utils.parameters_to_dict(params)