def testCreateAgentWithPrebuiltPreprocessingLayers(self, agent_class):
    """Agent construction must reject q/target networks that share weights.

    Both scenarios below are expected to raise a ValueError whose message
    contains 'shares weights with the original network':
      1. Only `q_network` is passed, and it wraps a Keras layer instance that
         was created outside the network (presumably the target network the
         agent derives from `q_network` ends up reusing that layer -- confirm
         against the agent implementation).
      2. A `target_q_network` is passed explicitly, but it wraps the very same
         layer instance as `q_network`.

    Args:
      agent_class: The agent class under test; called with
        (time_step_spec, action_spec, q_network=..., optimizer=..., and
        optionally target_q_network=...).
    """
    dense_layer = tf.keras.layers.Dense(2)
    q_net = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, dense_layer)
    # `assertRaisesRegex` is the supported spelling; the old
    # `assertRaisesRegexp` alias was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(
            ValueError, 'shares weights with the original network'):
        agent_class(
            self._time_step_spec,
            self._action_spec,
            q_network=q_net,
            optimizer=None)

    # Explicitly share weights between q and target networks.
    # This would be an unusual setup so we check that an error is thrown.
    q_target_net = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, dense_layer)
    with self.assertRaisesRegex(
            ValueError, 'shares weights with the original network'):
        agent_class(
            self._time_step_spec,
            self._action_spec,
            q_network=q_net,
            optimizer=None,
            target_q_network=q_target_net)
def test_check_matching_networks(self):
    """Networks wrapping distinct but identically configured layers match.

    Two fresh Dense(3) layers are each wrapped in a KerasLayersNet that uses
    the same observation/action specs; `common.check_matching_networks` must
    accept the pair without raising.
    """
    dense_layers = [tf.keras.layers.Dense(3) for _ in range(2)]
    net_a, net_b = [
        networks_test_utils.KerasLayersNet(
            self._observation_spec, self._action_spec, layer)
        for layer in dense_layers
    ]
    common.check_matching_networks(net_a, net_b)
def test_check_no_shared_variables_expect_fail(self):
    """`common.check_no_shared_variables` raises when networks share a layer.

    Both networks wrap the *same* Dense(3) instance, so once built they hold
    the same variable objects. The ValueError must come from the
    shared-variable check itself.
    """
    dense_layer = tf.keras.layers.Dense(3)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, dense_layer)
    q_net_2 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, dense_layer)
    # Build both networks first, exactly like the positive
    # test_check_no_shared_variables does. Without this, a ValueError raised
    # for an unrelated reason (e.g. reading variables of an unbuilt network)
    # would also satisfy assertRaises and the test would pass vacuously.
    q_net_1.create_variables()
    q_net_2.create_variables()
    with self.assertRaises(ValueError):
        common.check_no_shared_variables(q_net_1, q_net_2)
def test_check_no_shared_variables(self):
    """Built networks over separate Dense(3) layers share no variables.

    Each network gets its own layer instance. After both networks create
    their variables, `common.check_no_shared_variables` must not raise.
    """
    dense_layers = [tf.keras.layers.Dense(3) for _ in range(2)]
    nets = [
        networks_test_utils.KerasLayersNet(
            self._observation_spec, self._action_spec, layer)
        for layer in dense_layers
    ]
    for net in nets:
        net.create_variables()
    common.check_no_shared_variables(*nets)
def test_check_matching_networks_different_vars(self):
    """Networks with differing variable counts fail the matching check.

    One network wraps a Dense(3) layer and the other a GRU(3) layer. They are
    expected to end up with a different number of variables, so
    `common.check_matching_networks` should raise a ValueError mentioning
    'Variables lengths do not match'.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.GRU(3)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_2)
    q_net_1.create_variables()
    q_net_2.create_variables()
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Variables lengths do not match'):
        common.check_matching_networks(q_net_1, q_net_2)
def test_check_matching_networks_different_shape(self):
    """Networks whose variables differ in shape fail the matching check.

    A Dense(3) network and a Dense(4) network use the same input spec but
    have differently sized layers. `common.check_matching_networks` should
    raise a ValueError mentioning 'Variable dtypes or shapes do not match'.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.Dense(4)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_2)
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
            ValueError, 'Variable dtypes or shapes do not match'):
        common.check_matching_networks(q_net_1, q_net_2)
def test_check_matching_networks_different_input_spec(self):
    """Networks with different input tensor specs fail the matching check.

    The second network is built on an explicit TensorSpec([3], tf.float32).
    This presumably differs from `self._observation_spec` (set up elsewhere;
    confirm). `common.check_matching_networks` should raise a ValueError
    about mismatched input tensor specs.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.Dense(3)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        tensor_spec.TensorSpec([3], tf.float32), self._action_spec, layer_2)
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
            ValueError,
            'Input tensor specs of network and target network '
            'do not match'):
        common.check_matching_networks(q_net_1, q_net_2)
def testCreateAgentWithPrebuiltPreprocessingLayers(self, agent_class):
    """Implicit weight sharing is rejected; explicit sharing is accepted.

    1. Passing only `q_network` that wraps an externally created Dense(3)
       layer must raise a ValueError mentioning
       'shares weights with the original network'.
    2. Passing an explicit `target_q_network` that wraps the same layer
       instance must construct the agent without raising.

    Args:
      agent_class: The agent class under test; called with
        (time_step_spec, action_spec, q_network=..., optimizer=..., and
        optionally target_q_network=...).
    """
    dense_layer = tf.keras.layers.Dense(3)
    # Only the first element of the observation spec is fed to the network.
    q_net = networks_test_utils.KerasLayersNet(
        self._observation_spec[0], self._action_spec, dense_layer)
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
            ValueError, 'shares weights with the original network'):
        agent_class(
            self._time_step_spec,
            self._action_spec,
            q_network=q_net,
            optimizer=None)

    # Explicitly share weights between q and target networks; this is ok.
    q_target_net = networks_test_utils.KerasLayersNet(
        self._observation_spec[0], self._action_spec, dense_layer)
    agent_class(
        self._time_step_spec,
        self._action_spec,
        q_network=q_net,
        optimizer=None,
        target_q_network=q_target_net)