def testGetAndMakeFromParameters(self):
  one = tf.constant(1.0)
  d = tfp.distributions.Normal(loc=one, scale=3.0, validate_args=True)
  d = tfp.bijectors.Tanh()(d)
  d = tfp.bijectors.Tanh()(d)

  p = utils.get_parameters(d)

  expected_p = utils.Params(
      tfp.distributions.TransformedDistribution,
      params={
          'bijector': utils.Params(
              tfp.bijectors.Chain,
              params={
                  'bijectors': [
                      utils.Params(tfp.bijectors.Tanh, params={}),
                      utils.Params(tfp.bijectors.Tanh, params={}),
                  ]
              }),
          'distribution': utils.Params(
              tfp.distributions.Normal,
              params={
                  'validate_args': True,
                  'scale': 3.0,
                  'loc': one
              }),
      })

  self.compare_params(p, expected_p)

  d_recreated = utils.make_from_parameters(p)

  points = [0.01, 0.25, 0.5, 0.75, 0.99]
  self.assertAllClose(d.log_prob(points), d_recreated.log_prob(points))
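# --- Illustrative sketch (not part of the original test file) ---
# A minimal round trip with `get_parameters` / `make_from_parameters`,
# assuming `utils` is `tf_agents.distributions.utils`; the import path is
# inferred from the names used in this test and should be treated as an
# assumption.

import tensorflow_probability as tfp
from tf_agents.distributions import utils

# Decompose a transformed distribution into a nested `Params` tree ...
d = tfp.bijectors.Tanh()(tfp.distributions.Normal(loc=0.0, scale=1.0))
p = utils.get_parameters(d)
# ... then rebuild an equivalent distribution from that tree.
d_copy = utils.make_from_parameters(p)
assert abs(float(d.log_prob(0.5)) - float(d_copy.log_prob(0.5))) < 1e-6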
def _convert_to_spec_and_remove_singleton_batch_dim(
    parameters: distribution_utils.Params,
    outer_ndim: int) -> distribution_utils.Params:
  """Converts a `Params` object of tensors to one containing unbatched specs.

  Note: The `Params` provided to this function typically contain tensors
  generated by layers, and therefore carry an outer singleton batch
  dimension.  Since TF-Agents specs exclude batch and time prefixes, we
  remove the singleton batch dimension from the specs created from these
  input tensors.

  Args:
    parameters: Distribution parameters, including input tensors.
    outer_ndim: Number of singleton outer dimensions expected in tensors
      found in `parameters`.

  Returns:
    A `Params` object containing `tf.TypeSpec` in place of tensors, with
    up to `outer_ndim` outer singleton dimensions removed.
  """
  def _maybe_convert_to_spec(p):
    if isinstance(p, distribution_utils.Params):
      return _convert_to_spec_and_remove_singleton_batch_dim(p, outer_ndim)
    elif tf.is_tensor(p):
      return nest_utils.remove_singleton_batch_spec_dim(
          tf.type_spec_from_value(p), outer_ndim=outer_ndim)
    else:
      return p

  return distribution_utils.Params(
      type_=parameters.type_,
      params=tf.nest.map_structure(_maybe_convert_to_spec, parameters.params))
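# --- Illustrative sketch (not part of the original module) ---
# How the helper above is expected to behave, assuming `distribution_utils`
# is `tf_agents.distributions.utils` and that a layer emitted tensors with
# a singleton batch dimension; the `Params` contents are hypothetical.

import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.distributions import utils as distribution_utils

batched_params = distribution_utils.Params(
    type_=tfp.distributions.Normal,
    params={'loc': tf.zeros([1, 3]), 'scale': tf.ones([1, 3])})
spec_params = _convert_to_spec_and_remove_singleton_batch_dim(
    batched_params, outer_ndim=1)
# Each tensor is replaced by its spec with the outer singleton dim dropped,
# e.g. `spec_params.params['loc'] == tf.TensorSpec([3], tf.float32)`.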
def testGetAndMakeNontrivialBijectorFromParameters(self):
  scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
  d = tfp.distributions.MultivariateNormalDiag(
      loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
  b = tfp.bijectors.ScaleMatvecLinearOperator(
      scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
      adjoint=True)
  b_d = b(d)

  p = utils.get_parameters(b_d)

  expected_p = utils.Params(
      tfp.distributions.TransformedDistribution,
      params={
          'bijector': utils.Params(
              tfp.bijectors.ScaleMatvecLinearOperator,
              params={
                  'adjoint': True,
                  'scale': utils.Params(
                      tf.linalg.LinearOperatorFullMatrix,
                      params={'matrix': scale_matrix}),
              }),
          'distribution': utils.Params(
              tfp.distributions.MultivariateNormalDiag,
              params={
                  'validate_args': True,
                  'scale_diag': [2.0, 3.0],
                  'loc': [1.0, 1.0]
              }),
      })

  self.compare_params(p, expected_p)

  b_d_recreated = utils.make_from_parameters(p)

  points = [[-1.0, -2.0], [0.0, 0.0], [3.0, -5.0], [5.0, 5.0],
            [1.0, np.inf], [-np.inf, 0.0]]
  self.assertAllClose(b_d.log_prob(points), b_d_recreated.log_prob(points))
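# --- Illustrative sketch (not part of the original test file) ---
# Walking the nested `Params` tree produced in the test above; this
# continues from the test's `b_d` and `scale_matrix`, and the
# `type_` / `params` attribute names follow the usage in this file.
# Because constructor arguments are stored by reference, the recreated
# distribution is expected to share the underlying `tf.Variable`.

p = utils.get_parameters(b_d)
assert p.type_ is tfp.distributions.TransformedDistribution
bijector_p = p.params['bijector']
assert bijector_p.type_ is tfp.bijectors.ScaleMatvecLinearOperator
# The LinearOperator argument is itself decomposed into a nested `Params`:
operator_p = bijector_p.params['scale']
assert operator_p.type_ is tf.linalg.LinearOperatorFullMatrix
assert operator_p.params['matrix'] is scale_matrix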
class CreateVariablesTest(parameterized.TestCase, tf.test.TestCase):

  def testNetworkCreate(self):
    observation_spec = specs.TensorSpec([1], tf.float32, 'observation')
    action_spec = specs.TensorSpec([2], tf.float32, 'action')
    net = MockNetwork(observation_spec, action_spec)

    self.assertFalse(net.built)
    with self.assertRaises(ValueError):
      net.variables  # pylint: disable=pointless-statement

    output_spec = network.create_variables(net)
    # MockNetwork adds some variables to observation, which has shape [bs, 1].
    self.assertEqual(output_spec, tf.TensorSpec([1], dtype=tf.float32))
    self.assertTrue(net.built)
    self.assertLen(net.variables, 2)
    self.assertLen(net.trainable_variables, 1)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      (
          'Dense',
          lambda: tf.keras.layers.Dense(3),
          tf.TensorSpec((5,), tf.float32),  # input_spec
          tf.TensorSpec((3,), tf.float32),  # expected_output_spec
          (),  # expected_state_spec
      ),
      (
          'LSTMCell',
          lambda: tf.keras.layers.LSTMCell(3),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTMCellInRNN',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.RNN(
                  tf.keras.layers.LSTMCell(3),
                  return_state=True,
                  return_sequences=True)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTM',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.LSTM(
                  3, return_state=True, return_sequences=True)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'TimeDistributed',
          lambda: tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          (),
      ),
      (
          'Conv2D',
          lambda: tf.keras.layers.Conv2D(2, 3),
          tf.TensorSpec((28, 28, 5), tf.float32),
          tf.TensorSpec((26, 26, 2), tf.float32),
          (),
      ),
      (
          'SequentialOfDense',
          lambda: tf.keras.Sequential([tf.keras.layers.Dense(3)] * 2),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          (),
      ),
      (
          'NormalDistribution',
          lambda: tf.keras.Sequential([
              tf.keras.layers.Dense(3),
              tf.keras.layers.Lambda(
                  lambda x: tfd.Normal(loc=x, scale=x**2)),
          ]),
          tf.TensorSpec((5,), tf.float32),
          distribution_utils.DistributionSpecV2(
              event_shape=tf.TensorShape(()),
              dtype=tf.float32,
              parameters=distribution_utils.Params(
                  type_=tfd.Normal,
                  params=dict(
                      loc=tf.TensorSpec((3,), tf.float32),
                      scale=tf.TensorSpec((3,), tf.float32)))),
          (),
      ),
  )
  # pylint: enable=g-long-lambda
  def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec,
                           expected_state_spec):
    layer = layer_fn()
    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
      network.create_variables(layer)
    output_spec = network.create_variables(layer, input_spec)
    self.assertTrue(layer.built)
    self.assertEqual(
        output_spec, expected_output_spec,
        '\n{}\nvs.\n{}\n'.format(output_spec, expected_output_spec))
    output_spec_2 = network.create_variables(layer, input_spec)
    self.assertEqual(output_spec_2, expected_output_spec)
    state_spec = getattr(layer, '_network_state_spec', None)
    self.assertEqual(state_spec, expected_state_spec)
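# --- Illustrative sketch (not part of the original test file) ---
# The API exercised by `testKerasLayerCreate`, assuming `network` is
# `tf_agents.networks.network`: `create_variables` builds a Keras layer
# (or TF-Agents `Network`) against an input spec and returns the output
# spec, so variables exist before any real data flows through the layer.

import tensorflow as tf
from tf_agents.networks import network

layer = tf.keras.layers.Dense(3)
output_spec = network.create_variables(layer, tf.TensorSpec((5,), tf.float32))
assert layer.built
assert output_spec == tf.TensorSpec((3,), tf.float32)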