def testActionSelectionWithEpsilonDecay(self):
  """Epsilon decays toward 0.4 over the first 10 actions, then stays there.

  The RNG is pinned: a 0.8 draw selects the random policy while epsilon is
  still 0.9 / 0.85, selects the greedy policy for the next eight actions as
  epsilon moves from 0.8 down to 0.4, and a 0.399 draw switches selection
  back to the random policy afterwards.
  """
  policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
      self.greedy_policy,
      0.9,
      random_policy=self.random_policy,
      epsilon_decay_end_count=10,
      epsilon_decay_end_value=0.4)
  time_step = mock.MagicMock()

  # Swap the policy's RNG for a mock so every draw returns a chosen value.
  fake_rng = mock.MagicMock()
  policy._rng = fake_rng

  # 0.8 is below epsilon for both of the first two actions (0.9, then 0.85).
  fake_rng.rand.return_value = 0.8
  for _ in range(2):
    policy.action(time_step)
    self.random_policy.action.assert_called_with(time_step)
  self.assertEqual(2, self.random_policy.action.call_count)
  self.assertEqual(0, self.greedy_policy.action.call_count)

  # Epsilon now runs from 0.8 down to 0.4; the 0.8 draw is no longer below
  # it, so each of these eight actions is expected to be greedy.
  for _ in range(8):
    policy.action(time_step)
    self.greedy_policy.action.assert_called_with(time_step, policy_state=())
  self.assertEqual(2, self.random_policy.action.call_count)
  self.assertEqual(8, self.greedy_policy.action.call_count)

  # Epsilon has bottomed out at 0.4; 0.399 < 0.4 selects random again.
  fake_rng.rand.return_value = 0.399
  self.random_policy.reset_mock()
  for _ in range(5):
    policy.action(time_step)
    self.random_policy.action.assert_called_with(time_step)
  self.assertEqual(5, self.random_policy.action.call_count)
  # The greedy policy must not have been used again.
  self.assertEqual(8, self.greedy_policy.action.call_count)
def test_logged_retry_on_retriable_http_error_reraises_error(self):
  """An HttpError raised by the wrapped callable reaches the caller."""
  failing_call = mock.MagicMock(
      side_effect=errors.HttpError(mock.MagicMock(), _BYTES_ERROR_MESSAGE))
  wrapped = retry_utils.logged_retry_on_retriable_http_error(failing_call)

  with self.assertRaises(errors.HttpError):
    wrapped()
def __init__(self, data_generator=None, table_generator=None, fields=None):
  """Stores the fake data sources and, optionally, a mocked BigQuery service.

  Args:
    data_generator: Stored unchanged on the instance.
    table_generator: Stored unchanged. When truthy, `self.service` is also
      built as a MagicMock whose `tables().list().execute` is bound to
      `self.service_tables_list`; when falsy, no `service` attribute is set.
    fields: Stored unchanged on the instance.
  """
  self.project_id = 'test_project'
  self.data_generator = data_generator
  self.table_generator = table_generator
  self.fields = fields

  if not table_generator:
    return

  service = mock.MagicMock()
  self.service = service
  service.tables = mock.MagicMock()
  service.tables.return_value = mock.MagicMock()
  service.tables().list = mock.MagicMock()
  service.tables().list.return_value = mock.MagicMock()
  # Route the final `.execute()` of the list call to this object's method.
  service.tables().list().execute = self.service_tables_list
def setUp(self):
  """Creates greedy and random policy mocks with empty (unit) specs."""
  super(EpsilonGreedyPolicyTest, self).setUp()
  self.greedy_policy = mock.MagicMock()
  self.random_policy = mock.MagicMock()

  # Both fakes get identical spec attributes; the time-step spec is built
  # separately for each of them.
  for fake_policy in (self.greedy_policy, self.random_policy):
    fake_policy.time_step_spec = ts.time_step_spec()
    fake_policy.action_spec = ()
    fake_policy.info_spec = ()
    fake_policy.policy_state_spec = ()

  self.random_policy.action.return_value = policy_step.PolicyStep(0, ())
def _setup_mocks(self):
  """Builds a Pong TrainEval whose env, buffer and collect policy are mocks."""
  trainer = train_eval_atari.TrainEval(
      self.get_temp_dir(), 'Pong-v0', terminal_on_life_loss=True)
  self.trainer = trainer

  trainer._env = mock.MagicMock()
  trainer._env.envs[0].game_over = False
  trainer._replay_buffer = mock.MagicMock()

  # The collect policy always answers with action 1.
  collect_policy = mock.MagicMock()
  trainer._collect_policy = collect_policy
  collect_policy.action.return_value = policy_step.PolicyStep(action=1)

  self.observer = mock.MagicMock()
  self.metric_observers = [self.observer]
def setUp(self, mocked_hook):
  """Builds a normal hook and an 'error_table' hook with mocked connections.

  `BigQueryHook._get_field` (returning 'test_project') and the module-level
  `_DEFAULT_PAGE_SIZE` (set to 1) are patched for the duration of each test
  and restored by cleanups, so the overrides cannot leak into other tests.

  Args:
    mocked_hook: Mock presumably injected by a `mock.patch` decorator outside
      this view; its return value is replaced with a MagicMock carrying
      `bigquery_conn_id='test_conn'`.
  """
  super(BigqueryHookTest, self).setUp()
  # NOTE(review): `autospec=True` here only sets an attribute named
  # `autospec` on the MagicMock; it does not enable autospeccing.
  mocked_hook.return_value = mock.MagicMock(bigquery_conn_id='test_conn',
                                            autospec=True)

  # Patch instead of assigning to the class: a plain assignment would stay in
  # effect after this test finishes and silently affect later tests.
  get_field_patcher = mock.patch.object(
      bq_hook.BigQueryHook, '_get_field', return_value='test_project')
  get_field_patcher.start()
  self.addCleanup(get_field_patcher.stop)

  self.hook = bq_hook.BigQueryHook(
      conn_id='test_conn',
      dataset_id='test_dataset',
      table_id='test_table',
  )
  self.hook.get_conn = mock.MagicMock()
  self.hook.get_conn.return_value = mock.MagicMock()
  self.hook.get_conn.cursor = mock.MagicMock()

  self.error_hook = bq_hook.BigQueryHook(
      conn_id='test_conn',
      dataset_id='test_dataset',
      table_id='error_table',
  )
  self.error_hook.get_conn = mock.MagicMock()
  self.error_hook.get_conn.return_value = mock.MagicMock()
  self.error_hook.get_conn.cursor = mock.MagicMock()

  # Applied after the hooks are built, matching the original ordering, and
  # undone after the test instead of mutating the module permanently.
  page_size_patcher = mock.patch.object(bq_hook, '_DEFAULT_PAGE_SIZE', 1)
  page_size_patcher.start()
  self.addCleanup(page_size_patcher.stop)

  self.fields = [{
      'name': 'a',
      'type': 'STRING'
  }, {
      'name': 'b',
      'type': 'STRING'
  }]
def test_logged_retry_on_retriable_http_error_no_retry_on_status(self, status):
  """A non-retriable status is attempted exactly once and then re-raised.

  Args:
    status: HTTP status code, presumably supplied by a parameterizing
      decorator outside this view, that the helper should not retry on.
  """
  error = errors.HttpError(mock.MagicMock(status=status), _BYTES_ERROR_MESSAGE)
  mocked_decorated_function_that_throws_error = mock.MagicMock(
      side_effect=error)
  decorated_func = retry_utils.logged_retry_on_retriable_http_error(
      mocked_decorated_function_that_throws_error)

  # assertRaises rather than try/except-pass, so the test also fails if the
  # helper swallows the error instead of propagating it.
  with self.assertRaises(errors.HttpError):
    decorated_func()
  mocked_decorated_function_that_throws_error.assert_called_once()
def test_logged_retry_on_retriable_http_error_retry_on_status(self, status):
  """A retriable status is attempted `_RETRY_UTILS_MAX_RETRIES` times."""
  response = mock.MagicMock(status=status)
  failing_call = mock.MagicMock(
      side_effect=errors.HttpError(response, _BYTES_ERROR_MESSAGE))
  wrapped = retry_utils.logged_retry_on_retriable_http_error(failing_call)

  # An HttpError escaping after the last attempt is tolerated; what is
  # checked is the number of attempts made.
  try:
    wrapped()
  except errors.HttpError:
    pass

  self.assertEqual(failing_call.call_count,
                   retry_utils._RETRY_UTILS_MAX_RETRIES)
def test_tpu_strategy(self, mock_tpu_cluster_resolver,
                      mock_experimental_connect_to_cluster,
                      mock_initialize_tpu_system, mock_tpu_strategy):
  """get_strategy(tpu=...) resolves, connects to and initializes the TPU."""
  fake_resolver = mock.MagicMock()
  fake_strategy = mock.MagicMock()
  mock_tpu_cluster_resolver.return_value = fake_resolver
  mock_tpu_strategy.return_value = fake_strategy

  result = strategy_utils.get_strategy(tpu='bns_address', use_gpu=False)

  mock_tpu_cluster_resolver.assert_called_with(tpu='bns_address')
  mock_experimental_connect_to_cluster.assert_called_with(fake_resolver)
  mock_initialize_tpu_system.assert_called_with(fake_resolver)
  self.assertIs(result, fake_strategy)
def setUp(self):
  """Creates a payload builder, sample events and an HTTP-stubbed hook."""
  super(GoogleAnalyticsHookTest, self).setUp()
  tracking_id = 'UA-12323-4'
  self.test_tracking_id = tracking_id
  self.payload_builder = ga_hook.PayloadBuilder(tracking_id)

  self.event_test_data = dict(
      ec='ClientID',
      ea='test_event_action',
      el='20190423',
      ev=1,
      cid='12345.456789')
  self.small_event = dict(
      cid='12345.67890',
      ec='ClientID',
      ea='test_event_action',
      el='20190423',
      ev=1,
      z='1558517072202080')
  # Both variants below only replace 'ea' and are sized to roughly 4K of
  # payload data; 'ea' keeps its position in the key order.
  self.medium_event = dict(self.small_event, ea='x' * 3800)
  self.utf8_event = dict(self.small_event, ea=b'\xf0\xa9\xb8\xbd' * 320)

  hook = ga_hook.GoogleAnalyticsHook(tracking_id, self.event_test_data)
  # Never hit the network: every request "succeeds" with status 200.
  hook._send_http_request = mock.MagicMock(autospec=True)
  hook._send_http_request.return_value = mock.Mock(ok=True, status=200)
  self.test_hook = hook
def test_render_kwargs(self):
  """render_kwargs given to GymWrapper are forwarded to gym's render()."""
  cartpole_env = gym.spec('CartPole-v1').make()

  def _fake_render(mode='rgb_array', **kwargs):
    # Only the signature matters: a duplicate `mode` must raise TypeError.
    del mode, kwargs

  cartpole_env.render = mock.MagicMock(side_effect=_fake_render)

  # Hypothetical usage: the real CartPole render() accepts no such kwargs,
  # which is why it is replaced by a mock above.
  env = gym_wrapper.GymWrapper(
      cartpole_env, render_kwargs={'width': 96, 'height': 128})
  env.render()
  cartpole_env.render.assert_called_with('rgb_array', width=96, height=128)
  env.render(mode='human')
  cartpole_env.render.assert_called_with('human', width=96, height=128)

  # Corner case: a `mode` inside render_kwargs collides with the positional
  # mode argument ("got multiple values for argument 'mode'").
  cartpole_env.render.reset_mock()
  env = gym_wrapper.GymWrapper(
      cartpole_env,
      render_kwargs={'mode': 'human', 'width': 96, 'height': 128})
  with self.assertRaisesRegex(TypeError, 'multiple values'):
    env.render()
def test_wrapped_method_propagation(self):
  """Each TFEnvironmentBaseWrapper call is forwarded once to the inner env."""
  mock_env = mock.MagicMock()
  env = tf_wrappers.TFEnvironmentBaseWrapper(mock_env)

  forwarded_without_args = (
      'time_step_spec',
      'action_spec',
      'observation_spec',
      'batched',
      'batch_size',
      'current_time_step',
      'reset',
  )
  for method_name in forwarded_without_args:
    getattr(env, method_name)()
    self.assertEqual(1, getattr(mock_env, method_name).call_count)

  env.step(0)
  self.assertEqual(1, mock_env.step.call_count)
  mock_env.step.assert_called_with(0)

  env.render()
  self.assertEqual(1, mock_env.render.call_count)
def test_obs_dtype(self):
  """The observation spec dtype matches the dtype of a reset observation."""
  cartpole_env = gym.spec('CartPole-v1').make()
  cartpole_env.render = mock.MagicMock()
  wrapped = gym_wrapper.GymWrapper(cartpole_env)

  observed_dtype = wrapped.reset().observation.dtype
  self.assertEqual(wrapped.observation_spec().dtype, observed_dtype)
def test_wrap_gym_render_kwargs(self):
  """wrap_env passes render_kwargs through to the gym env's render()."""
  gym_env = gym.make('CartPole-v1')
  wrapped = suite_gym.wrap_env(
      gym_env, render_kwargs={'width': 96, 'height': 128})
  gym_env.render = mock.MagicMock()

  # render_kwargs should reach the underlying gym env's render().
  wrapped.render()
  gym_env.render.assert_called_with('rgb_array', width=96, height=128)
def test_execute_request_retries_on_service_unavailable_http_error(self):
  """A 503 on the first attempt is retried, and the second attempt succeeds."""
  mock_request = mock.Mock(http.HttpRequest)
  service_unavailable = errors.HttpError(mock.MagicMock(status=503), b'')
  # First execute() raises, second returns None.
  mock_request.execute.side_effect = [service_unavailable, None]

  utils.execute_request(mock_request)

  self.assertEqual(mock_request.execute.call_count, 2)
def testActionAlwaysRandom(self):
  """With epsilon fixed at 1, every action comes from the random policy."""
  num_actions = 5
  policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
      self.greedy_policy, 1, random_policy=self.random_policy)
  time_step = mock.MagicMock()

  for _ in range(num_actions):
    policy.action(time_step)
    self.random_policy.action.assert_called_with(time_step)

  self.assertEqual(num_actions, self.random_policy.action.call_count)
  self.assertEqual(0, self.greedy_policy.action.call_count)
def _get_mock_env_episode(self):
  """Returns a mock env whose step() yields one fixed 4-step episode."""
  # (step type, reward, observation); every step uses a discount of 1.
  episode_plan = [
      (ts.StepType.FIRST, 2, [0]),
      (ts.StepType.MID, 3, [1]),
      (ts.StepType.MID, 5, [2]),
      (ts.StepType.LAST, 7, [3]),
  ]
  mock_env = mock.MagicMock()
  mock_env.step.side_effect = [
      ts.TimeStep(step_type, reward, 1, observation)
      for step_type, reward, observation in episode_plan
  ]
  return mock_env
def test_load_gym_render_kwargs(self):
  """suite_gym.load forwards render_kwargs to the gym env's render()."""
  env = suite_gym.load(
      'CartPole-v1', render_kwargs={'width': 96, 'height': 128})
  underlying_env = env.gym
  self.assertIsInstance(underlying_env, gym.Env)
  underlying_env.render = mock.MagicMock()

  env.render()
  underlying_env.render.assert_called_with('rgb_array', width=96, height=128)
def testKlPenaltyLoss(self):
  """kl_penalty_loss is expected to equal the two stubbed KL terms (3 + 4)."""
  observation_spec = self._time_step_spec.observation
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec, self._action_spec, fc_layer_params=None)
  value_net = value_network.ValueNetwork(
      observation_spec, fc_layer_params=None)
  agent = ppo_agent.PPOAgent(
      self._time_step_spec,
      self._action_spec,
      tf.compat.v1.train.AdamOptimizer(),
      actor_net=actor_net,
      value_net=value_net,
      kl_cutoff_factor=5.0,
      adaptive_kl_target=0.1,
      kl_cutoff_coef=100,
  )
  # Replace both KL components with constants so only their combination
  # inside kl_penalty_loss is exercised.
  agent.kl_cutoff_loss = mock.MagicMock(
      return_value=tf.constant(3.0, dtype=tf.float32))
  agent.adaptive_kl_loss = mock.MagicMock(
      return_value=tf.constant(4.0, dtype=tf.float32))

  time_steps = ts.restart(
      tf.constant([[1, 2], [3, 4]], dtype=tf.float32), batch_size=2)
  action_distribution_parameters = {
      'loc': tf.constant([1.0, 1.0], dtype=tf.float32),
      'scale': tf.constant([1.0, 1.0], dtype=tf.float32),
  }
  dummy_actor = DummyActorNet(self._obs_spec, self._action_spec)
  current_policy_distribution, _ = dummy_actor(time_steps.observation,
                                               time_steps.step_type, ())
  weights = tf.ones_like(time_steps.discount)

  loss_tensor = agent.kl_penalty_loss(time_steps,
                                      action_distribution_parameters,
                                      current_policy_distribution, weights)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual(7.0, self.evaluate(loss_tensor))
def _get_mock_env_step(self):
  """Returns a mock env with three int32 specs, one reset and one step."""
  mock_env = mock.MagicMock()
  # Three separate (but equal) specs, one per observation_spec() call.
  mock_env.observation_spec.side_effect = [
      array_spec.BoundedArraySpec((3,), np.int32, -10, 10) for _ in range(3)
  ]
  mock_env.reset.side_effect = [
      ts.TimeStep(ts.StepType.MID, 5, 1, [3, 5, 2])
  ]
  mock_env.step.side_effect = [
      ts.TimeStep(ts.StepType.MID, 5, 1, [1, 2, 3])
  ]
  return mock_env
def get_mock_env(action_spec, observation_spec, step_return):
  """Builds a MagicMock environment with fixed specs and canned transitions.

  Args:
    action_spec: Value returned by `env.action_spec()`.
    observation_spec: Value returned by `env.observation_spec()`; also passed
      to `ts.time_step_spec` to build the value of `env.time_step_spec()`.
    step_return: Value returned by both `env.step(...)` and `env.reset()`.

  Returns:
    The configured `mock.MagicMock`.
  """
  env = mock.MagicMock()
  env.observation_spec = lambda: observation_spec
  time_step_spec = ts.time_step_spec(observation_spec)
  env.time_step_spec = lambda: time_step_spec
  env.action_spec = lambda: action_spec
  # Accept and ignore arguments so both `step()` and the usual `step(action)`
  # call forms work.
  env.step = lambda *unused_args, **unused_kwargs: step_return
  # Set on `env` itself. Assigning to `env.step.reset` only attached an
  # attribute to the lambda, so `env.reset()` fell back to a bare MagicMock.
  env.reset = lambda: step_return
  return env
def _get_mock_env_episode(self):
  """Returns a mock env whose step() yields one fixed 5-step episode."""
  # (step type, reward, observation); every step uses a discount of 1.
  # In practice the first reward would be 0, but a reward of 1 is tested.
  episode_plan = [
      (ts.StepType.FIRST, 1, [0]),
      (ts.StepType.MID, 2, [1]),
      (ts.StepType.MID, 3, [2]),
      (ts.StepType.MID, 5, [3]),
      (ts.StepType.LAST, 7, [4]),
  ]
  mock_env = mock.MagicMock()
  mock_env.step.side_effect = [
      ts.TimeStep(step_type, reward, 1, observation)
      for step_type, reward, observation in episode_plan
  ]
  return mock_env
def testActionSelection(self):
  """Draws below epsilon pick the random policy; draws above pick greedy."""
  policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
      self.greedy_policy, 0.9, random_policy=self.random_policy)
  time_step = mock.MagicMock()

  # Swap the policy's RNG for a mock so the draw is chosen by the test.
  fake_rng = mock.MagicMock()
  policy._rng = fake_rng

  # 0.8 < 0.9: the random policy acts.
  fake_rng.rand.return_value = 0.8
  policy.action(time_step)
  self.random_policy.action.assert_called_with(time_step)
  self.assertEqual(1, self.random_policy.action.call_count)
  self.assertEqual(0, self.greedy_policy.action.call_count)

  # 0.91 > 0.9: the greedy policy acts.
  fake_rng.rand.return_value = 0.91
  policy.action(time_step)
  self.greedy_policy.action.assert_called_with(time_step, policy_state=())
  self.assertEqual(1, self.random_policy.action.call_count)
  self.assertEqual(1, self.greedy_policy.action.call_count)
def test_method_propagation(self):
  """GymWrapper forwards render, seed and close to the wrapped gym env."""
  # (method name, positional args used for the call and the check)
  expectations = (
      ('render', ()),
      ('seed', (0,)),
      ('close', ()),
  )
  cartpole_env = gym.spec('CartPole-v1').make()
  for method_name, _ in expectations:
    setattr(cartpole_env, method_name, mock.MagicMock())
  env = gym_wrapper.GymWrapper(cartpole_env)

  for method_name, call_args in expectations:
    getattr(env, method_name)(*call_args)
    forwarded = getattr(cartpole_env, method_name)
    self.assertEqual(1, forwarded.call_count)
    if call_args:
      forwarded.assert_called_with(*call_args)
def setUp(self, mocked_hook):
  """Builds a JSON GCS hook with the connection, listing and chunking patched.

  Args:
    mocked_hook: Mock presumably injected by a `mock.patch` decorator outside
      this view; its return value gets `gcp_conn_id='test_conn'`.
  """
  super(JSONGoogleCloudStorageHookTest, self).setUp()
  # Every patcher started below is stopped by this single cleanup.
  self.addCleanup(mock.patch.stopall)
  mocked_hook.return_value = mock.MagicMock(gcp_conn_id='test_conn',
                                            autospec=True)

  self.gcs_hook = gcs_hook.GoogleCloudStorageHook(
      bucket='bucket',
      content_type=gcs_hook.BlobContentTypes.JSON.name,
      prefix='')

  def _start_autospec_patch(target, attribute):
    return mock.patch.object(target, attribute, autospec=True).start()

  base_cls = base_hook.GoogleCloudStorageHook
  self.mocked_conn = _start_autospec_patch(base_cls, 'get_conn')
  self.mocked_conn.return_value.objects = mock.MagicMock()
  self.mocked_list = _start_autospec_patch(base_cls, 'list')
  self.patched_chunk_generator = _start_autospec_patch(
      gcs_hook.GoogleCloudStorageHook, '_gcs_blob_chunk_generator')
def test_wrapped_method_propagation(self):
  """Each PyEnvironmentBaseWrapper call is forwarded once to the inner env."""
  mock_env = mock.MagicMock()
  env = wrappers.PyEnvironmentBaseWrapper(mock_env)

  # (method name, positional args used for the call and the check)
  expectations = (
      ('reset', ()),
      ('step', (0,)),
      ('seed', (0,)),
      ('render', ()),
      ('close', ()),
  )
  for method_name, call_args in expectations:
    getattr(env, method_name)(*call_args)
    forwarded = getattr(mock_env, method_name)
    self.assertEqual(1, forwarded.call_count)
    if call_args:
      forwarded.assert_called_with(*call_args)
def test_logged_retry_on_retriable_http_airflow_exception_retry_on_status(
    self, status):
  """An AirflowException carrying a retriable status is retried to the limit."""
  failing_call = mock.MagicMock(
      side_effect=exceptions.AirflowException(f'{status}:Error'))
  retry_decorator = (
      retry_utils.logged_retry_on_retriable_http_airflow_exception)
  wrapped = retry_decorator(failing_call)

  # An AirflowException escaping after the last attempt is tolerated; what
  # is checked is the number of attempts made.
  try:
    wrapped()
  except exceptions.AirflowException:
    pass

  self.assertEqual(failing_call.call_count,
                   retry_utils._RETRY_UTILS_MAX_RETRIES)
def test_wrapped_method_propagation(self):
  """Each AlfEnvironmentBaseWrapper call is forwarded once to the inner env."""
  mock_env = mock.MagicMock()
  env = alf_wrappers.AlfEnvironmentBaseWrapper(mock_env)

  def assert_forwarded_once(method_name, *expected_args):
    forwarded = getattr(mock_env, method_name)
    self.assertEqual(1, forwarded.call_count)
    if expected_args:
      forwarded.assert_called_with(*expected_args)

  env.reset()
  assert_forwarded_once('reset')

  # A 0-d int64 array action must compare equal to the plain int 0.
  env.step(np.array(0, dtype=np.int64))
  assert_forwarded_once('step', 0)

  env.seed(0)
  assert_forwarded_once('seed', 0)

  env.render()
  assert_forwarded_once('render')

  env.close()
  assert_forwarded_once('close')
def setUpClass(cls):
  """Shares a fake embedding model, fixtures and an example DataFrame."""
  super(KeywordClusteringTest, cls).setUpClass()
  cls.phrase = "the hello world"
  cls.phrase_embedding = np.full((2, 50), 0.5)
  cls.phrase_embedding_avg = np.full(50, 0.5)

  # The fake model returns one constant 50-dim row of 0.5 per input item.
  cls.model = mock.MagicMock(
      side_effect=lambda x: tf.constant(0.5, shape=(len(x), 50)))

  cls.cluster_output_shape = (11, 5)
  cls.cluster_description_output_shape = (2, 2)
  cls.kw_clustering = keyword_clustering.KeywordClustering(model=cls.model)

  # The DataFrame JSON is stored on the second line of the resource file.
  resource_text = importlib_resources.read_text(preprocess_data,
                                                "example_cluster_df.txt")
  cls.test_df = pd.read_json(resource_text.split("\n")[1])
def test_resets_after_limit(self):
  """game_over is set once the step budget runs out; the next step resets."""
  max_steps = 5
  action = 1
  base_env = mock.MagicMock()
  wrapped_env = atari_wrappers.AtariTimeLimit(base_env, max_steps)
  base_env.gym.game_over = False
  base_env.reset.return_value = ts.restart(1)
  base_env.step.return_value = ts.transition(2, 0)

  # Exhaust the budget: one more step than max_steps.
  for _ in range(max_steps + 1):
    wrapped_env.step(action)
  self.assertTrue(wrapped_env.game_over)
  self.assertEqual(1, base_env.reset.call_count)

  # The following step starts a fresh episode via a second reset.
  wrapped_env.step(action)
  self.assertFalse(wrapped_env.game_over)
  self.assertEqual(2, base_env.reset.call_count)