コード例 #1
0
ファイル: utils_test.py プロジェクト: tensorflow/agents
    def testParametersFromDictMissingNestedDictKeyFailure(self):
        """Checks that a partially specified nested key is rejected on merge.

        Deleting a plain (non-nested) key from the parameter dict is accepted
        by `merge_to_parameters_from_dict`, whereas deleting only
        `'bijectors:1'` while `'bijectors:0'` remains must raise `KeyError`.
        """
        base = tfp.distributions.Normal(
            loc=tf.constant(1.0), scale=3.0, validate_args=True)
        # Apply Tanh twice so that the flattened parameters carry both the
        # 'bijectors:0' and 'bijectors:1' entries exercised below.
        squashed = tfp.bijectors.Tanh()(base)
        squashed = tfp.bijectors.Tanh()(squashed)
        params = utils.get_parameters(squashed)
        flat = utils.parameters_to_dict(params)

        # Dropping a non-nested key is tolerated: the merge reuses the value
        # already held by `params`, so this call must not raise.
        del flat['distribution']['validate_args']
        utils.merge_to_parameters_from_dict(params, flat)

        # Dropping a single entry of a nested key is not tolerated, because it
        # can lead to subtle errors.
        del flat['bijector']['bijectors:1']

        # Flattening discarded the nested structure, so a subset of the nested
        # bijectors list cannot be overridden on its own.
        expected_error = (
            r'Only saw partial information from the dictionary for nested key '
            r'\'bijectors\' in params_dict.*'
            r'Entries provided: \[\'bijectors:0\'\].*'
            r'Entries required: \[\'bijectors:0\', \'bijectors:1\'\]')
        with self.assertRaisesRegex(KeyError, expected_error):
            utils.merge_to_parameters_from_dict(params, flat)
コード例 #2
0
 def nested_dist_params(spec):
   """Returns the tensor-only parameter dict for a `DistributionSpecV2`.

   Args:
     spec: The output spec to convert.

   Returns:
     The result of `parameters_to_dict` on `spec.parameters` with
     `tensors_only=True`.

   Raises:
     ValueError: If `spec` is not a `DistributionSpecV2`. The message reports
       `actor_output_spec`, which is taken from the enclosing scope.
   """
   if isinstance(spec, distribution_utils.DistributionSpecV2):
     return distribution_utils.parameters_to_dict(
         spec.parameters, tensors_only=True)

   message = (
       'Unexpected output from `actor_network`.  Expected '
       '`Distribution` objects, but saw output spec: {}'
   ).format(actor_output_spec)
   raise ValueError(message)
コード例 #3
0
ファイル: utils_test.py プロジェクト: tensorflow/agents
    def testParametersToAndFromDict(self, tensors_only):
        """Round-trips parameters through a dict and merges the dict back.

        Args:
          tensors_only: Forwarded to `parameters_to_dict`; selects which of the
            two expected dictionaries below is compared against.
        """
        matrix_var = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
        base_dist = tfp.distributions.MultivariateNormalDiag(
            loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
        bijector = tfp.bijectors.ScaleMatvecLinearOperator(
            scale=tf.linalg.LinearOperatorFullMatrix(matrix=matrix_var),
            adjoint=True)
        params = utils.get_parameters(bijector(base_dist))

        flat = utils.parameters_to_dict(params, tensors_only=tensors_only)

        # Entries expected in both modes: only the scale matrix.
        expected = {
            'bijector': {'scale': {'matrix': matrix_var}},
            'distribution': {},
        }
        if not tensors_only:
            # The full conversion additionally exposes the non-tensor values.
            expected['bijector']['adjoint'] = True
            expected['distribution'].update({
                'validate_args': True,
                # These are deeply nested because we passed lists
                # instead of numpy arrays for `loc` and `scale_diag`.
                'scale_diag:0': 2.0,
                'scale_diag:1': 3.0,
                'loc:0': 1.0,
                'loc:1': 1.0,
            })

        tf.nest.map_structure(self.assertAllEqual, flat, expected)

        # Adding 1.0 turns the tf.Variable entry into a tf.Tensor.
        shifted_matrix = flat['bijector']['scale']['matrix'] + 1.0
        flat['bijector']['scale']['matrix'] = shifted_matrix

        # With tensors_only=True this also verifies that merging works from a
        # dict in which everything but tensors was dropped.
        rebuilt = utils.merge_to_parameters_from_dict(params, flat)

        def scale_matrix_of(parameters):
            # Walks bijector -> scale -> matrix through the nested `.params`.
            nested = parameters.params['bijector'].params['scale']
            return nested.params['matrix']

        self.assertAllClose(
            scale_matrix_of(rebuilt), scale_matrix_of(params) + 1.0)

        # Tensor values were already checked above, so skip them here.
        self.compare_params(params, rebuilt, skip_tensor_values=True)
コード例 #4
0
    def testParametersFromDictMissingNestedParamsKeyFailure(self):
        """Checks that adding an extra entry to a nested key is rejected.

        A third `'bijectors:2'` entry is injected into the flattened dict; the
        merge back into the original parameters must raise `ValueError`.
        """
        base = tfp.distributions.Normal(
            loc=tf.constant(1.0), scale=3.0, validate_args=True)
        squashed = tfp.bijectors.Tanh()(base)
        squashed = tfp.bijectors.Tanh()(squashed)
        params = utils.get_parameters(squashed)
        flat = utils.parameters_to_dict(params)

        # Confirm the two existing entries, then add a third one, which
        # changes the structure of the nest.
        bijector_entries = flat['bijector']
        for existing_key in ('bijectors:0', 'bijectors:1'):
            self.assertIn(existing_key, bijector_entries.keys())
        bijector_entries['bijectors:2'] = bijector_entries['bijectors:0']

        # Flattening discarded the nested structure, so a new bijector added
        # to the dict cannot be placed back into the bijector list.
        expected_error = (
            r'params_dict keys: \[\'bijectors:0\', \'bijectors:1\', '
            r'\'bijectors:2\'\], value.params processed keys: '
            r'\[\'bijectors:0\', \'bijectors:1\'\]')
        with self.assertRaisesRegex(ValueError, expected_error):
            utils.merge_to_parameters_from_dict(params, flat)
コード例 #5
0
 def dist_params_dict(d):
     """Returns the parameters of `d` as a dict with `tensors_only=True`."""
     parameters = distribution_utils.get_parameters(d)
     return distribution_utils.parameters_to_dict(
         parameters, tensors_only=True)
コード例 #6
0
 def dist_params_dict(d):
     """Returns the parameters of `d` converted to a dict."""
     parameters = distribution_utils.get_parameters(d)
     return distribution_utils.parameters_to_dict(parameters)