예제 #1
0
 def _calc_unbatched_spec(x):
     """Return a spec describing `x` with its outer batch dimension(s) removed.

     Non-distribution values are turned into a type spec with
     `tf.type_spec_from_value` and passed through
     `nest_utils.remove_singleton_batch_spec_dim`. Distributions are described
     by a `distribution_utils.DistributionSpecV2` whose parameter specs are
     produced by `_convert_to_spec_and_remove_singleton_batch_dim`.

     Args:
       x: A `tfp.distributions.Distribution`, or any value accepted by
         `tf.type_spec_from_value`.

     Returns:
       The unbatched spec for `x`.
     """
     # `outer_ndim` is not a parameter: it is read from the enclosing scope.
     if not isinstance(x, tfp.distributions.Distribution):
         value_spec = tf.type_spec_from_value(x)
         return nest_utils.remove_singleton_batch_spec_dim(
             value_spec, outer_ndim=outer_ndim)

     unbatched_params = _convert_to_spec_and_remove_singleton_batch_dim(
         distribution_utils.get_parameters(x), outer_ndim=outer_ndim)
     return distribution_utils.DistributionSpecV2(
         event_shape=x.event_shape,
         dtype=x.dtype,
         parameters=unbatched_params)
예제 #2
0
        def _calc_unbatched_spec(x):
            """Build a Network output spec by dropping the added batch dimension.

            Args:
              x: A `tfp.distributions.Distribution`, or a value accepted by
                `tf.type_spec_from_value` (e.g. a Tensor).

            Returns:
              A spec for `x` with one outer dimension removed: a
              `DistributionSpecV2` for distributions, otherwise the result of
              `tensor_spec.remove_outer_dims_nest`.
            """
            if not isinstance(x, tfp.distributions.Distribution):
                return tensor_spec.remove_outer_dims_nest(
                    tf.type_spec_from_value(x), num_outer_dims=1)

            dist_params = distribution_utils.get_parameters(x)
            params_spec = _convert_to_spec_and_remove_singleton_batch_dim(
                dist_params, outer_ndim=1)
            return distribution_utils.DistributionSpecV2(
                event_shape=x.event_shape,
                dtype=x.dtype,
                parameters=params_spec)
예제 #3
0
class CreateVariablesTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `network.create_variables` on Networks and bare Keras layers."""

  def testNetworkCreate(self):
    """create_variables builds a MockNetwork and returns its output spec."""
    observation_spec = specs.TensorSpec([1], tf.float32, 'observation')
    action_spec = specs.TensorSpec([2], tf.float32, 'action')
    net = MockNetwork(observation_spec, action_spec)
    self.assertFalse(net.built)
    # Accessing variables before the network is built is expected to fail.
    with self.assertRaises(ValueError):
      net.variables  # pylint: disable=pointless-statement
    output_spec = network.create_variables(net)
    # MockNetwork adds some variables to observation, which has shape [bs, 1]
    self.assertEqual(output_spec, tf.TensorSpec([1], dtype=tf.float32))
    self.assertTrue(net.built)
    self.assertLen(net.variables, 2)
    self.assertLen(net.trainable_variables, 1)

  # Each case is:
  #   (name, layer_fn, input_spec, expected_output_spec, expected_state_spec)
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      (
          'Dense',
          lambda: tf.keras.layers.Dense(3),
          tf.TensorSpec((5,), tf.float32),  # input_spec
          tf.TensorSpec((3,), tf.float32),  # expected_output_spec
          (),  # expected_state_spec
      ),
      (
          'LSTMCell',
          lambda: tf.keras.layers.LSTMCell(3),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTMCellInRNN',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.RNN(
                  tf.keras.layers.LSTMCell(3),
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'LSTM',
          lambda: rnn_wrapper.RNNWrapper(
              tf.keras.layers.LSTM(
                  3,
                  return_state=True,
                  return_sequences=True)
          ),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          [tf.TensorSpec((3,), tf.float32),
           tf.TensorSpec((3,), tf.float32)],
      ),
      (
          'TimeDistributed',
          lambda: tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'Conv2D',
          lambda: tf.keras.layers.Conv2D(2, 3),
          tf.TensorSpec((28, 28, 5), tf.float32),
          tf.TensorSpec((26, 26, 2), tf.float32),
          ()
      ),
      (
          'SequentialOfDense',
          lambda: tf.keras.Sequential([tf.keras.layers.Dense(3)] * 2),
          tf.TensorSpec((5,), tf.float32),
          tf.TensorSpec((3,), tf.float32),
          ()
      ),
      (
          'NormalDistribution',
          lambda: tf.keras.Sequential(
              [tf.keras.layers.Dense(3),
               tf.keras.layers.Lambda(
                   lambda x: tfd.Normal(loc=x, scale=x**2))]),
          tf.TensorSpec((5,), tf.float32),
          distribution_utils.DistributionSpecV2(
              event_shape=tf.TensorShape(()),
              dtype=tf.float32,
              parameters=distribution_utils.Params(
                  type_=tfd.Normal,
                  params=dict(
                      loc=tf.TensorSpec((3,), tf.float32),
                      scale=tf.TensorSpec((3,), tf.float32),
                  ))),
          ()
      ),
  )
  # pylint: enable=g-long-lambda
  def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec,
                           expected_state_spec):
    """create_variables on a bare Keras layer.

    Checks that the call fails without an input spec, and that with
    `input_spec` it builds the layer and returns `expected_output_spec`. It
    also checks that a second call returns the same spec, and that the
    layer's `_network_state_spec` attribute (None if absent) equals
    `expected_state_spec`.

    Args:
      layer_fn: Zero-argument callable returning a fresh layer.
      input_spec: Spec passed to `network.create_variables`.
      expected_output_spec: Expected return value of `create_variables`.
      expected_state_spec: Expected value of `layer._network_state_spec`.
    """
    layer = layer_fn()
    # Bare Keras layers carry no input spec of their own, so one is required.
    with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
      network.create_variables(layer)
    output_spec = network.create_variables(layer, input_spec)
    self.assertTrue(layer.built)
    self.assertEqual(
        output_spec, expected_output_spec,
        '\n{}\nvs.\n{}\n'.format(output_spec, expected_output_spec))
    # Calling again on an already-built layer must yield the same spec.
    output_spec_2 = network.create_variables(layer, input_spec)
    self.assertEqual(output_spec_2, expected_output_spec)
    state_spec = getattr(layer, '_network_state_spec', None)
    self.assertEqual(state_spec, expected_state_spec)