Example #1
0
    def testGetAndMakeFromParameters(self):
        """Checks get_parameters / make_from_parameters on a chained Tanh.

        A Normal is wrapped by a Tanh bijector twice. The extracted
        parameters are expected to describe a TransformedDistribution whose
        bijector is a Chain of two Tanh bijectors, and the distribution
        rebuilt from those parameters must produce the same log-probabilities
        as the original.
        """
        loc = tf.constant(1.0)
        normal = tfp.distributions.Normal(
            loc=loc, scale=3.0, validate_args=True)
        squashed_once = tfp.bijectors.Tanh()(normal)
        squashed_twice = tfp.bijectors.Tanh()(squashed_once)
        actual_params = utils.get_parameters(squashed_twice)

        # Build the expectation bottom-up: leaves first, then the wrappers.
        tanh_chain = [
            utils.Params(tfp.bijectors.Tanh, params={}),
            utils.Params(tfp.bijectors.Tanh, params={}),
        ]
        chain_params = utils.Params(
            tfp.bijectors.Chain, params={'bijectors': tanh_chain})
        normal_params = utils.Params(
            tfp.distributions.Normal,
            params={'validate_args': True, 'scale': 3.0, 'loc': loc})
        expected_params = utils.Params(
            tfp.distributions.TransformedDistribution,
            params={'bijector': chain_params, 'distribution': normal_params})

        self.compare_params(actual_params, expected_params)

        # The rebuilt distribution must agree with the original one.
        rebuilt = utils.make_from_parameters(actual_params)
        points = [0.01, 0.25, 0.5, 0.75, 0.99]
        original_log_prob = squashed_twice.log_prob(points)
        rebuilt_log_prob = rebuilt.log_prob(points)
        self.assertAllClose(original_log_prob, rebuilt_log_prob)
Example #2
0
    def testParametersFromDictMissingNestedDictKeyFailure(self):
        """Merging fails when part of a flattened nested entry is missing.

        Dropping a plain key from the dict is accepted by
        `merge_to_parameters_from_dict`, but dropping one element of the
        flattened `bijectors` list must raise a KeyError.
        """
        loc = tf.constant(1.0)
        normal = tfp.distributions.Normal(
            loc=loc, scale=3.0, validate_args=True)
        squashed_once = tfp.bijectors.Tanh()(normal)
        squashed_twice = tfp.bijectors.Tanh()(squashed_once)
        params = utils.get_parameters(squashed_twice)
        flat = utils.parameters_to_dict(params)

        # A missing non-nested key is fine: the merge falls back to the
        # value already stored in `params`.
        distribution_entries = flat['distribution']
        del distribution_entries['validate_args']
        utils.merge_to_parameters_from_dict(params, flat)

        # A missing nested entry can lead to subtle errors, so it is
        # rejected.
        bijector_entries = flat['bijector']
        del bijector_entries['bijectors:1']

        # Flattening discarded the nested structure, so a bijector cannot be
        # removed from the list in order to override only a subset of it.
        expected_error = (
            r'Only saw partial information from the dictionary for nested key '
            r'\'bijectors\' in params_dict.*'
            r'Entries provided: \[\'bijectors:0\'\].*'
            r'Entries required: \[\'bijectors:0\', \'bijectors:1\'\]')
        with self.assertRaisesRegex(KeyError, expected_error):
            utils.merge_to_parameters_from_dict(params, flat)
Example #3
0
    def testParametersToAndFromDict(self, tensors_only):
        """Round-trips Params through parameters_to_dict and back.

        Args:
          tensors_only: Forwarded to `utils.parameters_to_dict`. When truthy,
            the expected dict holds only the `matrix` entry (plus an empty
            `distribution` dict); otherwise it also holds the Python-valued
            parameters.
        """
        scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
        base = tfp.distributions.MultivariateNormalDiag(
            loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
        linear_operator = tf.linalg.LinearOperatorFullMatrix(
            matrix=scale_matrix)
        bijector = tfp.bijectors.ScaleMatvecLinearOperator(
            scale=linear_operator, adjoint=True)
        transformed = bijector(base)
        params = utils.get_parameters(transformed)

        flat = utils.parameters_to_dict(params, tensors_only=tensors_only)

        if not tensors_only:
            expected_flat = {
                'bijector': {
                    'adjoint': True,
                    'scale': {'matrix': scale_matrix},
                },
                'distribution': {
                    'validate_args': True,
                    # These are deeply nested because we passed lists
                    # instead of numpy arrays for `loc` and `scale_diag`.
                    'scale_diag:0': 2.0,
                    'scale_diag:1': 3.0,
                    'loc:0': 1.0,
                    'loc:1': 1.0,
                },
            }
        else:
            expected_flat = {
                'bijector': {'scale': {'matrix': scale_matrix}},
                'distribution': {},
            }

        tf.nest.map_structure(self.assertAllEqual, flat, expected_flat)

        # The addition replaces the tf.Variable entry with a tf.Tensor.
        scale_entry = flat['bijector']['scale']
        scale_entry['matrix'] = scale_entry['matrix'] + 1.0

        # With tensors_only=True this checks that a dict from which
        # everything but tensors was dropped can still be merged into p.
        merged = utils.merge_to_parameters_from_dict(params, flat)

        def _matrix_of(p):
            # Walks bijector -> scale -> matrix in a Params tree.
            return p.params['bijector'].params['scale'].params['matrix']

        self.assertAllClose(_matrix_of(merged), _matrix_of(params) + 1.0)

        # Tensor values were verified above, so only compare the rest.
        self.compare_params(params, merged, skip_tensor_values=True)
Example #4
0
 def _calc_unbatched_spec(x):
     """Returns a spec for `x` with its singleton batch dims removed.

     `outer_ndim` is not a parameter; it is read from a surrounding scope
     that is not shown here.

     Args:
       x: Either a `tfp.distributions.Distribution` or some other value that
         `tf.type_spec_from_value` accepts.

     Returns:
       A `DistributionSpecV2` when `x` is a distribution, otherwise the type
       spec of `x` passed through
       `nest_utils.remove_singleton_batch_spec_dim`.
     """
     if not isinstance(x, tfp.distributions.Distribution):
         value_spec = tf.type_spec_from_value(x)
         return nest_utils.remove_singleton_batch_spec_dim(
             value_spec, outer_ndim=outer_ndim)

     # Distributions are described through the specs of their parameters.
     dist_params = distribution_utils.get_parameters(x)
     unbatched_param_specs = _convert_to_spec_and_remove_singleton_batch_dim(
         dist_params, outer_ndim=outer_ndim)
     return distribution_utils.DistributionSpecV2(
         event_shape=x.event_shape,
         dtype=x.dtype,
         parameters=unbatched_param_specs)
Example #5
0
        def _calc_unbatched_spec(x):
            """Builds the Network output spec without the added batch dim.

            Args:
              x: tfp.distributions.Distribution or Tensor.

            Returns:
              Specs without batch dimension representing x.
            """
            if not isinstance(x, tfp.distributions.Distribution):
                value_spec = tf.type_spec_from_value(x)
                return tensor_spec.remove_outer_dims_nest(
                    value_spec, num_outer_dims=1)

            # Distributions are described through their parameter specs.
            dist_params = distribution_utils.get_parameters(x)
            unbatched_specs = _convert_to_spec_and_remove_singleton_batch_dim(
                dist_params, outer_ndim=1)
            return distribution_utils.DistributionSpecV2(
                event_shape=x.event_shape,
                dtype=x.dtype,
                parameters=unbatched_specs)
Example #6
0
    def testGetAndMakeNontrivialBijectorFromParameters(self):
        """Checks the Params round trip for a bijector with nested params.

        A MultivariateNormalDiag is transformed by a
        ScaleMatvecLinearOperator whose scale is a LinearOperatorFullMatrix
        backed by a tf.Variable. The extracted parameters must match the
        expected nested Params, and the rebuilt distribution must produce the
        same log-probabilities as the original.
        """
        scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
        base = tfp.distributions.MultivariateNormalDiag(
            loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
        linear_operator = tf.linalg.LinearOperatorFullMatrix(
            matrix=scale_matrix)
        bijector = tfp.bijectors.ScaleMatvecLinearOperator(
            scale=linear_operator, adjoint=True)
        transformed = bijector(base)
        actual_params = utils.get_parameters(transformed)

        # Build the expectation bottom-up: leaves first, then the wrappers.
        operator_params = utils.Params(
            tf.linalg.LinearOperatorFullMatrix,
            params={'matrix': scale_matrix})
        bijector_params = utils.Params(
            tfp.bijectors.ScaleMatvecLinearOperator,
            params={'adjoint': True, 'scale': operator_params})
        base_params = utils.Params(
            tfp.distributions.MultivariateNormalDiag,
            params={
                'validate_args': True,
                'scale_diag': [2.0, 3.0],
                'loc': [1.0, 1.0],
            })
        expected_params = utils.Params(
            tfp.distributions.TransformedDistribution,
            params={'bijector': bijector_params, 'distribution': base_params})

        self.compare_params(actual_params, expected_params)

        rebuilt = utils.make_from_parameters(actual_params)

        # Finite points plus two rows that contain an infinite coordinate.
        points = [
            [-1.0, -2.0],
            [0.0, 0.0],
            [3.0, -5.0],
            [5.0, 5.0],
            [1.0, np.inf],
            [-np.inf, 0.0],
        ]
        original_log_prob = transformed.log_prob(points)
        rebuilt_log_prob = rebuilt.log_prob(points)
        self.assertAllClose(original_log_prob, rebuilt_log_prob)
Example #7
0
    def testParametersFromDictMissingNestedParamsKeyFailure(self):
        """Merging fails when the dict gains an extra flattened nested entry.

        A third `bijectors:*` key is added to the flattened dict, although
        the original Params only hold two bijectors;
        `merge_to_parameters_from_dict` must raise a ValueError.
        """
        loc = tf.constant(1.0)
        normal = tfp.distributions.Normal(
            loc=loc, scale=3.0, validate_args=True)
        squashed_once = tfp.bijectors.Tanh()(normal)
        squashed_twice = tfp.bijectors.Tanh()(squashed_once)
        params = utils.get_parameters(squashed_twice)
        flat = utils.parameters_to_dict(params)

        # Both existing bijectors must be present before a third is added,
        # which changes the structure of the nest.
        bijector_entries = flat['bijector']
        for key in ('bijectors:0', 'bijectors:1'):
            self.assertIn(key, bijector_entries.keys())
        bijector_entries['bijectors:2'] = bijector_entries['bijectors:0']

        # Flattening discarded the nested structure, so a bijector added to
        # the dict cannot be placed back into the bijector list.
        expected_error = (
            r'params_dict keys: \[\'bijectors:0\', \'bijectors:1\', '
            r'\'bijectors:2\'\], value.params processed keys: '
            r'\[\'bijectors:0\', \'bijectors:1\'\]')
        with self.assertRaisesRegex(ValueError, expected_error):
            utils.merge_to_parameters_from_dict(params, flat)
 def dist_params_dict(d):
     """Returns the parameters of `d` as a dict, with `tensors_only=True`."""
     extracted = distribution_utils.get_parameters(d)
     return distribution_utils.parameters_to_dict(
         extracted, tensors_only=True)
Example #9
0
 def dist_params_dict(d):
     """Returns the parameters of `d` converted to a dict."""
     extracted = distribution_utils.get_parameters(d)
     return distribution_utils.parameters_to_dict(extracted)