Exemplo n.º 1
0
 def test_check_matching_networks(self):
   """Checks that identically configured networks are accepted.

   Both networks share the same observation/action specs and wrap a
   `Dense(3)` layer, so `common.check_matching_networks` is expected to
   return without raising.
   """
   # Create both layers first, then both networks, in the same order as
   # before, so any auto-generated layer/network names are unchanged.
   dense_layers = [tf.keras.layers.Dense(3) for _ in range(2)]
   first_net, second_net = [
       networks_test_utils.KerasLayersNet(self._observation_spec,
                                          self._action_spec,
                                          dense)
       for dense in dense_layers
   ]
   common.check_matching_networks(first_net, second_net)
Exemplo n.º 2
0
 def test_check_matching_networks_different_vars(self):
   """Checks that networks with different variable counts are rejected.

   One network wraps a `Dense(3)` layer and the other wraps a `GRU(3)` layer.
   After both networks create their variables, `common.check_matching_networks`
   is expected to raise a ValueError whose message matches
   'Variables lengths do not match'.
   """
   dense_layer = tf.keras.layers.Dense(3)
   gru_layer = tf.keras.layers.GRU(3)
   q_net_1 = networks_test_utils.KerasLayersNet(self._observation_spec,
                                                self._action_spec, dense_layer)
   q_net_2 = networks_test_utils.KerasLayersNet(self._observation_spec,
                                                self._action_spec, gru_layer)
   # Materialize the variables of both networks before comparing them.
   q_net_1.create_variables()
   q_net_2.create_variables()
   # `assertRaisesRegexp` is a deprecated alias that was removed in
   # Python 3.12; `assertRaisesRegex` is the supported spelling.
   with self.assertRaisesRegex(ValueError, 'Variables lengths do not match'):
     common.check_matching_networks(q_net_1, q_net_2)
Exemplo n.º 3
0
 def test_check_matching_networks_different_shape(self):
   """Checks that networks whose variables differ in shape are rejected.

   The networks share specs but wrap `Dense(3)` and `Dense(4)` layers.
   `common.check_matching_networks` is expected to raise a ValueError whose
   message matches 'Variable dtypes or shapes do not match'.
   """
   layer_1 = tf.keras.layers.Dense(3)
   layer_2 = tf.keras.layers.Dense(4)
   q_net_1 = networks_test_utils.KerasLayersNet(self._observation_spec,
                                                self._action_spec,
                                                layer_1)
   q_net_2 = networks_test_utils.KerasLayersNet(self._observation_spec,
                                                self._action_spec,
                                                layer_2)
   # NOTE(review): unlike the different_vars test, variables are not created
   # explicitly here; this presumably relies on check_matching_networks
   # creating them itself -- confirm against `common`.
   # `assertRaisesRegexp` was removed in Python 3.12; use `assertRaisesRegex`.
   with self.assertRaisesRegex(ValueError,
                               'Variable dtypes or shapes do not match'):
     common.check_matching_networks(q_net_1, q_net_2)
Exemplo n.º 4
0
 def test_check_matching_networks_different_input_spec(self):
   """Checks that networks with different input specs are rejected.

   The first network uses the fixture's observation spec. The second uses a
   `TensorSpec([3], tf.float32)`. `common.check_matching_networks` is
   expected to raise a ValueError whose message matches 'Input tensor specs
   of network and target network do not match'.
   """
   layer_1 = tf.keras.layers.Dense(3)
   layer_2 = tf.keras.layers.Dense(3)
   q_net_1 = networks_test_utils.KerasLayersNet(self._observation_spec,
                                                self._action_spec,
                                                layer_1)
   q_net_2 = networks_test_utils.KerasLayersNet(
       tensor_spec.TensorSpec([3], tf.float32), self._action_spec, layer_2)
   # `assertRaisesRegexp` was removed in Python 3.12; use `assertRaisesRegex`.
   with self.assertRaisesRegex(
       ValueError, 'Input tensor specs of network and target network '
       'do not match'):
     common.check_matching_networks(q_net_1, q_net_2)