コード例 #1
0
ファイル: network_test.py プロジェクト: yangjue-han/agents
 def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec):
   """Checks that create_variables builds a Keras layer and reports its spec.

   Calling create_variables without an input spec must raise ValueError.
   Calling it with `input_spec` must build the layer and return
   `expected_output_spec`. A repeated call must return the same spec again.
   """
   keras_layer = layer_fn()
   with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
     network.create_variables(keras_layer)
   first_spec = network.create_variables(keras_layer, input_spec)
   self.assertTrue(keras_layer.built)
   self.assertEqual(first_spec, expected_output_spec)
   # Calling again on the already-built layer must yield the same spec.
   self.assertEqual(
       network.create_variables(keras_layer, input_spec),
       expected_output_spec)
コード例 #2
0
    def testKerasLayerFailsIfRecurrentDoesNotReturnState(self):
        """Checks that create_variables rejects misconfigured LSTM layers.

        Two cases must raise ValueError:
        - an LSTM built with return_sequences=True but without return_state;
        - an LSTM built with return_state=True but without return_sequences.
        """
        sequences_only = tf.keras.layers.LSTM(3, return_sequences=True)
        with self.assertRaisesRegex(ValueError, 'with return_state==False'):
            network.create_variables(
                sequences_only, input_spec=tf.TensorSpec((3, ), tf.float32))

        state_only = tf.keras.layers.LSTM(3, return_state=True)
        with self.assertRaisesRegex(ValueError,
                                    'with return_sequences==False'):
            network.create_variables(
                state_only, input_spec=tf.TensorSpec((3, ), tf.float32))
コード例 #3
0
 def testKerasLayerCreate(self, layer_fn, input_spec, expected_output_spec,
                          expected_state_spec):
   """Checks the output spec and state spec of a Keras layer.

   Calling create_variables without an input spec must raise ValueError.
   Calling it with `input_spec` must build the layer, and both the first and
   a repeated call must return `expected_output_spec`. The layer's
   `_network_state_spec` attribute (or None if absent) must equal
   `expected_state_spec`.
   """
   keras_layer = layer_fn()
   with self.assertRaisesRegex(ValueError, 'an input_spec is required'):
     network.create_variables(keras_layer)
   first_spec = network.create_variables(keras_layer, input_spec)
   self.assertTrue(keras_layer.built)
   # Print both specs on failure, since nested specs are hard to compare by eye.
   mismatch_msg = '\n{}\nvs.\n{}\n'.format(first_spec, expected_output_spec)
   self.assertEqual(first_spec, expected_output_spec, mismatch_msg)
   second_spec = network.create_variables(keras_layer, input_spec)
   self.assertEqual(second_spec, expected_output_spec)
   self.assertEqual(
       getattr(keras_layer, '_network_state_spec', None), expected_state_spec)
コード例 #4
0
ファイル: sequential.py プロジェクト: tfboyd/agents
def _infer_specs(
    layers: typing.Sequence[tf.keras.layers.Layer],
    input_spec: types.NestedTensorSpec
) -> typing.Tuple[
    types.NestedTensorSpec,
    types.NestedTensorSpec
]:
  """Infer the output spec and state specs of a sequence of Keras layers.

  This runs `create_variables` on each layer in order. Each layer is fed the
  output spec of the previous one, and the state spec of each layer is
  collected via `network.get_state_spec`. Running `create_variables` is
  necessary because this creates a `_network_state_spec` property on each
  generic (non-Network) Keras layer.

  Args:
    layers: A sequence of Keras layers and/or Networks, applied in order.
    input_spec: The input spec of the first layer.

  Returns:
    A tuple `(output_spec, state_spec)`:
    - `output_spec` is the output spec of the final layer, or `input_spec`
      itself when `layers` is empty.
    - `state_spec` is a tuple holding one state spec per layer, in order.
  """
  state_specs = []
  output_spec = input_spec
  for layer in layers:
    # Each layer consumes the previous layer's output spec.
    output_spec = network.create_variables(layer, output_spec)
    # Must run after create_variables, which sets up the layer's state spec.
    state_specs.append(network.get_state_spec(layer))
  return output_spec, tuple(state_specs)
コード例 #5
0
 def testNetworkCreate(self):
   """Checks that create_variables builds a MockNetwork and returns its spec."""
   obs_spec = specs.TensorSpec([1], tf.float32, 'observation')
   act_spec = specs.TensorSpec([2], tf.float32, 'action')
   mock_net = MockNetwork(obs_spec, act_spec)
   self.assertFalse(mock_net.built)
   # Accessing `variables` before the network is built must raise ValueError.
   with self.assertRaises(ValueError):
     mock_net.variables  # pylint: disable=pointless-statement
   built_spec = network.create_variables(mock_net)
   # MockNetwork adds some variables to observation, which has shape [bs, 1]
   self.assertEqual(built_spec, tf.TensorSpec([1], dtype=tf.float32))
   self.assertTrue(mock_net.built)
   self.assertLen(mock_net.variables, 2)
   self.assertLen(mock_net.trainable_variables, 1)
コード例 #6
0
 def create_variables(self, input_spec=None):
     """Build the wrapped layer's variables and record its output spec.

     Args:
       input_spec: Optional input spec to build with. When `None`, falls back
         to `self._input_tensor_spec`.

     Returns:
       The output spec returned by `network.create_variables` for the wrapped
       layer. It is also stored on `self._network_output_spec`, and
       `self.built` is set to True.
     """
     # Compare against None explicitly: a falsy-but-valid nested spec (e.g. an
     # empty tuple or dict) must not be silently replaced by the default.
     if input_spec is None:
         input_spec = self._input_tensor_spec
     output_spec = network.create_variables(self._layer, input_spec)
     self._network_output_spec = output_spec
     self.built = True
     return output_spec