def testActionSelectionWithEpsilonDecay(self):
    """Checks which policy acts while epsilon decays from 0.9 towards 0.4.

    The policy's random generator is swapped for a mock so that every draw
    is fixed; the calls received by the mocked random and greedy policies
    are then counted across three phases.
    """
    policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
        self.greedy_policy,
        0.9,
        random_policy=self.random_policy,
        epsilon_decay_end_count=10,
        epsilon_decay_end_value=0.4)
    step = mock.MagicMock()
    # Deterministic stand-in for the policy's random generator.
    fixed_rng = mock.MagicMock()
    policy._rng = fixed_rng

    # Phase 1: 0.8 < 0.9 and 0.8 < 0.85, so the random policy should act.
    policy._rng.rand.return_value = 0.8
    for _ in range(2):
      policy.action(step)
      self.random_policy.action.assert_called_with(step)
    self.assertEqual(2, self.random_policy.action.call_count)
    self.assertEqual(0, self.greedy_policy.action.call_count)

    # Phase 2: epsilon moves through [0.8 .. 0.4]; the fixed draw of 0.8 is
    # no longer below it, so the greedy policy should act.
    for _ in range(8):
      policy.action(step)
      self.greedy_policy.action.assert_called_with(step, policy_state=())
    self.assertEqual(2, self.random_policy.action.call_count)
    self.assertEqual(8, self.greedy_policy.action.call_count)

    # Phase 3: 0.399 < 0.4, so the random policy should act again.
    policy._rng.rand.return_value = 0.399
    self.random_policy.reset_mock()
    for _ in range(5):
      policy.action(step)
      self.random_policy.action.assert_called_with(step)
    self.assertEqual(5, self.random_policy.action.call_count)
    # The greedy policy must not have been called any further.
    self.assertEqual(8, self.greedy_policy.action.call_count)
Ejemplo n.º 2
0
  def test_logged_retry_on_retriable_http_error_reraises_error(self):
    """An HttpError raised by the wrapped callable reaches the caller."""
    failing_func = mock.MagicMock(
        side_effect=errors.HttpError(mock.MagicMock(), _BYTES_ERROR_MESSAGE))
    wrapped = retry_utils.logged_retry_on_retriable_http_error(failing_func)

    with self.assertRaises(errors.HttpError):
      wrapped()
 def __init__(self, data_generator=None, table_generator=None, fields=None):
   """Stores the inputs and, when tables are requested, a mocked service.

   Args:
     data_generator: Stored unchanged on ``self.data_generator``.
     table_generator: Stored unchanged on ``self.table_generator``. When
       truthy, ``self.service`` is created as a mock whose
       ``tables().list().execute`` is ``self.service_tables_list``.
     fields: Stored unchanged on ``self.fields``.
   """
   self.project_id = 'test_project'
   self.fields = fields
   self.table_generator = table_generator
   self.data_generator = data_generator
   if not table_generator:
     # Without a table generator ``self.service`` is left undefined.
     return

   # Build the ``service.tables().list().execute`` chain one link at a time.
   self.service = mock.MagicMock()
   self.service.tables = mock.MagicMock()
   self.service.tables.return_value = mock.MagicMock()
   self.service.tables().list = mock.MagicMock()
   self.service.tables().list.return_value = mock.MagicMock()
   self.service.tables().list().execute = self.service_tables_list
Ejemplo n.º 4
0
 def setUp(self):
     """Creates mocked greedy and random policies that expose empty specs."""
     super(EpsilonGreedyPolicyTest, self).setUp()
     self.greedy_policy = mock.MagicMock()
     self.random_policy = mock.MagicMock()
     # Both mocks get the same spec attributes; greedy first, then random.
     for fake_policy in (self.greedy_policy, self.random_policy):
         fake_policy.time_step_spec = ts.time_step_spec()
         fake_policy.action_spec = ()
         fake_policy.info_spec = ()
         fake_policy.policy_state_spec = ()
     # Only the random policy has a concrete action result configured.
     self.random_policy.action.return_value = policy_step.PolicyStep(0, ())
Ejemplo n.º 5
0
    def _setup_mocks(self):
        """Creates a TrainEval trainer and swaps its collaborators for mocks."""
        self.trainer = train_eval_atari.TrainEval(
            self.get_temp_dir(), 'Pong-v0', terminal_on_life_loss=True)

        # Environment mock whose first sub-environment reports no game over.
        self.trainer._env = mock.MagicMock()
        self.trainer._env.envs[0].game_over = False

        self.trainer._replay_buffer = mock.MagicMock()

        # The collect policy always answers with action 1.
        self.trainer._collect_policy = mock.MagicMock()
        fixed_step = policy_step.PolicyStep(action=1)
        self.trainer._collect_policy.action.return_value = fixed_step

        self.observer = mock.MagicMock()
        self.metric_observers = [self.observer]
Ejemplo n.º 6
0
    def setUp(self, mocked_hook):
        """Builds a regular hook and an error hook with mocked connections.

        Args:
            mocked_hook: Mock whose return value is replaced below;
                presumably injected by a `mock.patch` decorator that is not
                visible in this snippet.
        """
        super(BigqueryHookTest, self).setUp()

        # `autospec` is an option of `mock.patch`/`create_autospec`, not of
        # `MagicMock`; passed to `MagicMock` it merely sets an attribute named
        # `autospec`, so it is not passed here.
        mocked_hook.return_value = mock.MagicMock(bigquery_conn_id='test_conn')

        # Patch the class attribute and the module constant through patchers
        # registered for cleanup, so the replacements are undone after each
        # test instead of leaking into other test cases. `create=True` keeps
        # the plain-assignment behaviour of working even if the name is absent.
        get_field_patcher = mock.patch.object(
            bq_hook.BigQueryHook, '_get_field', create=True,
            return_value='test_project')
        get_field_patcher.start()
        self.addCleanup(get_field_patcher.stop)

        page_size_patcher = mock.patch.object(
            bq_hook, '_DEFAULT_PAGE_SIZE', 1, create=True)
        page_size_patcher.start()
        self.addCleanup(page_size_patcher.stop)

        self.hook = bq_hook.BigQueryHook(
            conn_id='test_conn',
            dataset_id='test_dataset',
            table_id='test_table',
        )
        self.hook.get_conn = mock.MagicMock()
        self.hook.get_conn.return_value = mock.MagicMock()
        # NOTE(review): this sets `cursor` on the `get_conn` mock itself, not
        # on the connection it returns — confirm this is what the tests need.
        self.hook.get_conn.cursor = mock.MagicMock()

        self.error_hook = bq_hook.BigQueryHook(
            conn_id='test_conn',
            dataset_id='test_dataset',
            table_id='error_table',
        )
        self.error_hook.get_conn = mock.MagicMock()
        self.error_hook.get_conn.return_value = mock.MagicMock()
        self.error_hook.get_conn.cursor = mock.MagicMock()

        self.fields = [{
            'name': 'a',
            'type': 'STRING'
        }, {
            'name': 'b',
            'type': 'STRING'
        }]
Ejemplo n.º 7
0
  def test_logged_retry_on_retriable_http_error_no_retry_on_status(
      self, status):
    """The wrapped callable is invoked exactly once for the given status."""
    failing_func = mock.MagicMock(
        side_effect=errors.HttpError(mock.MagicMock(status=status),
                                     _BYTES_ERROR_MESSAGE))
    wrapped = retry_utils.logged_retry_on_retriable_http_error(failing_func)

    # Only the number of invocations matters here, so an HttpError raised
    # by the call is ignored.
    try:
      wrapped()
    except errors.HttpError:
      pass

    failing_func.assert_called_once()
Ejemplo n.º 8
0
  def test_logged_retry_on_retriable_http_error_retry_on_status(self, status):
    """The wrapped callable is invoked the maximum number of retry times."""
    failing_func = mock.MagicMock(
        side_effect=errors.HttpError(mock.MagicMock(status=status),
                                     _BYTES_ERROR_MESSAGE))
    wrapped = retry_utils.logged_retry_on_retriable_http_error(failing_func)

    # Only the number of invocations matters here, so an HttpError raised
    # by the call is ignored.
    try:
      wrapped()
    except errors.HttpError:
      pass

    self.assertEqual(failing_func.call_count,
                     retry_utils._RETRY_UTILS_MAX_RETRIES)
Ejemplo n.º 9
0
    def test_tpu_strategy(self, mock_tpu_cluster_resolver,
                          mock_experimental_connect_to_cluster,
                          mock_initialize_tpu_system, mock_tpu_strategy):
        """get_strategy(tpu=...) resolves, connects, initializes and returns.

        The four mock arguments stand in for the TPU resolver class, the
        cluster-connect call, the TPU-system initializer and the strategy
        class respectively.
        """
        fake_resolver = mock.MagicMock()
        expected_strategy = mock.MagicMock()
        mock_tpu_cluster_resolver.return_value = fake_resolver
        mock_tpu_strategy.return_value = expected_strategy

        result = strategy_utils.get_strategy(tpu='bns_address', use_gpu=False)

        mock_tpu_cluster_resolver.assert_called_with(tpu='bns_address')
        mock_experimental_connect_to_cluster.assert_called_with(fake_resolver)
        mock_initialize_tpu_system.assert_called_with(fake_resolver)
        self.assertIs(result, expected_strategy)
Ejemplo n.º 10
0
  def setUp(self):
    """Setup function for each unit test.

    Creates a payload builder, sample events of different sizes, and a
    GoogleAnalyticsHook whose `_send_http_request` is replaced by a mock that
    returns a successful response (`ok=True`, `status=200`).
    """

    super(GoogleAnalyticsHookTest, self).setUp()

    self.test_tracking_id = 'UA-12323-4'
    self.payload_builder = ga_hook.PayloadBuilder(self.test_tracking_id)

    self.event_test_data = {
        'ec': 'ClientID',
        'ea': 'test_event_action',
        'el': '20190423',
        'ev': 1,
        'cid': '12345.456789'
    }
    self.small_event = {
        'cid': '12345.67890',
        'ec': 'ClientID',
        'ea': 'test_event_action',
        'el': '20190423',
        'ev': 1,
        'z': '1558517072202080'
    }
    # Both of the below are approx 4K of data
    self.medium_event = {**self.small_event, 'ea': 'x' * 3800}
    self.utf8_event = {**self.small_event, 'ea': b'\xf0\xa9\xb8\xbd' * 320}

    self.test_hook = ga_hook.GoogleAnalyticsHook(self.test_tracking_id,
                                                 self.event_test_data)

    # `autospec` is an option of `mock.patch`/`create_autospec`; passing it to
    # `MagicMock` only creates an attribute called `autospec` and does not
    # restrict the mock's signature, so it is not passed here.
    self.test_hook._send_http_request = mock.MagicMock()
    self.test_hook._send_http_request.return_value = mock.Mock(ok=True)
    self.test_hook._send_http_request.return_value.status = 200
Ejemplo n.º 11
0
    def test_render_kwargs(self):
        """render_kwargs given to GymWrapper are forwarded to gym's render()."""
        cartpole_env = gym.spec('CartPole-v1').make()

        def _gym_render_mock(mode='rgb_array', **kwargs):
            # Accept anything and render nothing.
            del mode, kwargs  # unused
            return None

        cartpole_env.render = mock.MagicMock(side_effect=_gym_render_mock)

        # The usages below are hypothetical (mocked): the real CartPole
        # environment does not support kwargs for `render`.
        size_kwargs = {'width': 96, 'height': 128}
        env = gym_wrapper.GymWrapper(cartpole_env, render_kwargs=size_kwargs)
        env.render()
        cartpole_env.render.assert_called_with(
            'rgb_array', width=96, height=128)
        env.render(mode='human')
        cartpole_env.render.assert_called_with('human', width=96, height=128)

        # Corner case: when the kwargs themselves contain `mode`, the call
        # receives multiple values for argument 'mode' and must raise.
        cartpole_env.render.reset_mock()
        render_kwargs = {'mode': 'human', 'width': 96, 'height': 128}
        env = gym_wrapper.GymWrapper(cartpole_env, render_kwargs=render_kwargs)
        with self.assertRaisesRegex(TypeError, 'multiple values'):
            env.render()
Ejemplo n.º 12
0
    def test_wrapped_method_propagation(self):
        """Every wrapper method is forwarded exactly once to the wrapped env."""
        mock_env = mock.MagicMock()
        env = tf_wrappers.TFEnvironmentBaseWrapper(mock_env)

        # Methods that are called without arguments, checked one by one.
        no_arg_methods = (
            'time_step_spec',
            'action_spec',
            'observation_spec',
            'batched',
            'batch_size',
            'current_time_step',
            'reset',
        )
        for method_name in no_arg_methods:
            getattr(env, method_name)()
            self.assertEqual(1, getattr(mock_env, method_name).call_count)

        # step() additionally has to pass its action through unchanged.
        env.step(0)
        self.assertEqual(1, mock_env.step.call_count)
        mock_env.step.assert_called_with(0)

        env.render()
        self.assertEqual(1, mock_env.render.call_count)
Ejemplo n.º 13
0
 def test_obs_dtype(self):
     """The observation from reset() has the dtype the spec announces."""
     cartpole_env = gym.spec('CartPole-v1').make()
     cartpole_env.render = mock.MagicMock()
     env = gym_wrapper.GymWrapper(cartpole_env)
     first_step = env.reset()
     spec_dtype = env.observation_spec().dtype
     self.assertEqual(spec_dtype, first_step.observation.dtype)
Ejemplo n.º 14
0
 def test_wrap_gym_render_kwargs(self):
   """wrap_env forwards render_kwargs to the wrapped gym env's render()."""
   gym_env = gym.make('CartPole-v1')
   size_kwargs = {'width': 96, 'height': 128}
   env = suite_gym.wrap_env(gym_env, render_kwargs=size_kwargs)
   gym_env.render = mock.MagicMock()
   # render_kwargs should be passed to the underlying gym env's render().
   env.render()
   gym_env.render.assert_called_with('rgb_array', width=96, height=128)
Ejemplo n.º 15
0
    def test_execute_request_retries_on_service_unavailable_http_error(self):
        """A 503 HttpError on the first attempt leads to a second attempt."""
        mock_request = mock.Mock(http.HttpRequest)
        unavailable = errors.HttpError(mock.MagicMock(status=503), b'')
        # First execute() raises the 503 error, the second one returns None.
        mock_request.execute.side_effect = [unavailable, None]

        utils.execute_request(mock_request)

        self.assertEqual(mock_request.execute.call_count, 2)
 def testActionAlwaysRandom(self):
     """With epsilon set to 1 only the random policy is ever asked to act."""
     num_actions = 5
     policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
         self.greedy_policy, 1, random_policy=self.random_policy)
     step = mock.MagicMock()
     for _ in range(num_actions):
         policy.action(step)
     self.random_policy.action.assert_called_with(step)
     self.assertEqual(num_actions, self.random_policy.action.call_count)
     self.assertEqual(0, self.greedy_policy.action.call_count)
Ejemplo n.º 17
0
 def _get_mock_env_episode(self):
     """Returns a mock env whose step() yields one four-step episode.

     Successive step() calls return FIRST, MID, MID and LAST time steps with
     rewards 2, 3, 5, 7, discount 1 and observations [0] .. [3].
     """
     step_types = (ts.StepType.FIRST, ts.StepType.MID, ts.StepType.MID,
                   ts.StepType.LAST)
     rewards = (2, 3, 5, 7)
     episode = [
         ts.TimeStep(step_type, reward, 1, [index])
         for index, (step_type, reward) in enumerate(zip(step_types, rewards))
     ]
     mock_env = mock.MagicMock()
     mock_env.step.side_effect = episode
     return mock_env
Ejemplo n.º 18
0
 def test_load_gym_render_kwargs(self):
   """suite_gym.load forwards render_kwargs to the gym env's render()."""
   size_kwargs = {'width': 96, 'height': 128}
   env = suite_gym.load('CartPole-v1', render_kwargs=size_kwargs)
   gym_env = env.gym
   self.assertIsInstance(gym_env, gym.Env)
   gym_env.render = mock.MagicMock()
   # render_kwargs should be passed to the underlying gym env's render().
   env.render()
   gym_env.render.assert_called_with('rgb_array', width=96, height=128)
Ejemplo n.º 19
0
  def testKlPenaltyLoss(self):
    """kl_penalty_loss evaluates to 7.0 with the two KL terms mocked.

    `kl_cutoff_loss` and `adaptive_kl_loss` are replaced by mocks returning
    3.0 and 4.0; the expected value of 7.0 is presumably their sum.
    """
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        fc_layer_params=None)
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        tf.compat.v1.train.AdamOptimizer(),
        actor_net=actor_net,
        value_net=value_net,
        kl_cutoff_factor=5.0,
        adaptive_kl_target=0.1,
        kl_cutoff_coef=100,
    )

    # Replace both KL components with constants.
    agent.kl_cutoff_loss = mock.MagicMock(
        return_value=tf.constant(3.0, dtype=tf.float32))
    agent.adaptive_kl_loss = mock.MagicMock(
        return_value=tf.constant(4.0, dtype=tf.float32))

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_steps = ts.restart(observations, batch_size=2)
    old_distribution_params = {
        'loc': tf.constant([1.0, 1.0], dtype=tf.float32),
        'scale': tf.constant([1.0, 1.0], dtype=tf.float32),
    }
    dummy_net = DummyActorNet(self._obs_spec, self._action_spec)
    current_policy_distribution, unused_network_state = dummy_net(
        time_steps.observation, time_steps.step_type, ())
    weights = tf.ones_like(time_steps.discount)

    expected_kl_penalty_loss = 7.0

    loss_tensor = agent.kl_penalty_loss(
        time_steps,
        old_distribution_params,
        current_policy_distribution,
        weights)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    loss_value = self.evaluate(loss_tensor)
    self.assertEqual(expected_kl_penalty_loss, loss_value)
Ejemplo n.º 20
0
 def _get_mock_env_step(self):
   """Returns a mock env configured for one reset() and one step().

   observation_spec() can be called three times, each call returning a
   fresh bounded int32 spec of shape (3,) with bounds [-10, 10].
   """
   mock_env = mock.MagicMock()
   mock_env.observation_spec.side_effect = [
       array_spec.BoundedArraySpec((3,), np.int32, -10, 10)
       for _ in range(3)
   ]
   reset_step = ts.TimeStep(ts.StepType.MID, 5, 1, [3, 5, 2])
   mock_env.reset.side_effect = [reset_step]
   next_step = ts.TimeStep(ts.StepType.MID, 5, 1, [1, 2, 3])
   mock_env.step.side_effect = [next_step]
   return mock_env
Ejemplo n.º 21
0
def get_mock_env(action_spec, observation_spec, step_return):
    """Returns a mocked environment exposing fixed specs and a fixed step.

    Args:
        action_spec: Value returned by `env.action_spec()`.
        observation_spec: Value returned by `env.observation_spec()`; it is
            also passed to `ts.time_step_spec` to build the value returned by
            `env.time_step_spec()`.
        step_return: Value returned by `env.step(...)` and by
            `env.step.reset()`.

    Returns:
        A `mock.MagicMock` with the attributes above replaced by plain
        callables.
    """
    env = mock.MagicMock()
    time_step_spec = ts.time_step_spec(observation_spec)

    env.observation_spec = lambda: observation_spec
    env.time_step_spec = lambda: time_step_spec
    env.action_spec = lambda: action_spec

    def _step(*unused_args, **unused_kwargs):
        # Accept (and ignore) an action so that `env.step(action)` works as
        # well as the argument-less `env.step()`; a zero-argument callable
        # would raise TypeError as soon as an action is passed.
        return step_return

    env.step = _step
    # NOTE(review): this attaches `reset` to the step callable, not to the
    # environment (`env.reset` stays a MagicMock) — confirm whether
    # `env.reset` was intended.
    env.step.reset = lambda: step_return
    return env
Ejemplo n.º 22
0
 def _get_mock_env_episode(self):
   """Returns a mock env whose step() yields one five-step episode.

   Successive step() calls return FIRST, MID, MID, MID and LAST time steps
   with rewards 1, 2, 3, 5, 7, discount 1 and observations [0] .. [4].
   """
   step_types = (ts.StepType.FIRST, ts.StepType.MID, ts.StepType.MID,
                 ts.StepType.MID, ts.StepType.LAST)
   # In practice, the first reward would be 0, but test with a reward of 1.
   rewards = (1, 2, 3, 5, 7)
   episode = [
       ts.TimeStep(step_type, reward, 1, [index])
       for index, (step_type, reward) in enumerate(zip(step_types, rewards))
   ]
   mock_env = mock.MagicMock()
   mock_env.step.side_effect = episode
   return mock_env
  def testActionSelection(self):
    """A draw below epsilon picks the random policy, above it the greedy one."""
    policy = py_epsilon_greedy_policy.EpsilonGreedyPolicy(
        self.greedy_policy, 0.9, random_policy=self.random_policy)
    step = mock.MagicMock()
    # Deterministic stand-in for the policy's random generator.
    fixed_rng = mock.MagicMock()
    policy._rng = fixed_rng

    # 0.8 < 0.9, so the random policy should be used.
    policy._rng.rand.return_value = 0.8
    policy.action(step)
    self.random_policy.action.assert_called_with(step)
    self.assertEqual(1, self.random_policy.action.call_count)
    self.assertEqual(0, self.greedy_policy.action.call_count)

    # 0.91 > 0.9, so the greedy policy should be used.
    policy._rng.rand.return_value = 0.91
    policy.action(step)
    self.greedy_policy.action.assert_called_with(step, policy_state=())
    self.assertEqual(1, self.random_policy.action.call_count)
    self.assertEqual(1, self.greedy_policy.action.call_count)
Ejemplo n.º 24
0
 def test_method_propagation(self):
     """render, seed and close on the wrapper reach the gym env once each."""
     cartpole_env = gym.spec('CartPole-v1').make()
     # Replace the three methods under test with mocks before wrapping.
     for method_name in ('render', 'seed', 'close'):
         setattr(cartpole_env, method_name, mock.MagicMock())
     env = gym_wrapper.GymWrapper(cartpole_env)

     def assert_forwarded_once(method_name):
         forwarded = getattr(cartpole_env, method_name)
         self.assertEqual(1, forwarded.call_count)

     env.render()
     assert_forwarded_once('render')

     env.seed(0)
     assert_forwarded_once('seed')
     cartpole_env.seed.assert_called_with(0)

     env.close()
     assert_forwarded_once('close')
Ejemplo n.º 25
0
    def setUp(self, mocked_hook):
        """Creates a JSON GCS hook and patches its connection and helpers.

        Args:
            mocked_hook: Mock whose return value is replaced below;
                presumably injected by a `mock.patch` decorator that is not
                visible in this snippet.
        """
        super(JSONGoogleCloudStorageHookTest, self).setUp()
        # Every patcher started below is stopped again after the test.
        self.addCleanup(mock.patch.stopall)

        # `autospec` is an option of `mock.patch`/`create_autospec`; passed to
        # `MagicMock` it only sets an attribute named `autospec`, so it is
        # not passed here.
        mocked_hook.return_value = mock.MagicMock(gcp_conn_id='test_conn')
        self.gcs_hook = gcs_hook.GoogleCloudStorageHook(
            bucket='bucket',
            content_type=gcs_hook.BlobContentTypes.JSON.name,
            prefix='')

        self.mocked_conn = mock.patch.object(base_hook.GoogleCloudStorageHook,
                                             'get_conn',
                                             autospec=True).start()
        self.mocked_conn.return_value.objects = mock.MagicMock()
        self.mocked_list = mock.patch.object(base_hook.GoogleCloudStorageHook,
                                             'list',
                                             autospec=True).start()

        self.patched_chunk_generator = mock.patch.object(
            gcs_hook.GoogleCloudStorageHook,
            '_gcs_blob_chunk_generator',
            autospec=True).start()
Ejemplo n.º 26
0
 def test_wrapped_method_propagation(self):
   """reset, step, seed, render and close are each forwarded exactly once."""
   mock_env = mock.MagicMock()
   env = wrappers.PyEnvironmentBaseWrapper(mock_env)

   def assert_forwarded_once(method_name):
     self.assertEqual(1, getattr(mock_env, method_name).call_count)

   env.reset()
   assert_forwarded_once('reset')

   env.step(0)
   assert_forwarded_once('step')
   mock_env.step.assert_called_with(0)

   env.seed(0)
   assert_forwarded_once('seed')
   mock_env.seed.assert_called_with(0)

   env.render()
   assert_forwarded_once('render')

   env.close()
   assert_forwarded_once('close')
Ejemplo n.º 27
0
  def test_logged_retry_on_retriable_http_airflow_exception_retry_on_status(
      self, status):
    """The wrapped callable is invoked the maximum number of retry times."""
    failing_func = mock.MagicMock(
        side_effect=exceptions.AirflowException(f'{status}:Error'))
    wrapped = (
        retry_utils.logged_retry_on_retriable_http_airflow_exception(
            failing_func))

    # Only the number of invocations matters here, so an AirflowException
    # raised by the call is ignored.
    try:
      wrapped()
    except exceptions.AirflowException:
      pass

    self.assertEqual(failing_func.call_count,
                     retry_utils._RETRY_UTILS_MAX_RETRIES)
Ejemplo n.º 28
0
 def test_wrapped_method_propagation(self):
     """reset, step, seed, render and close are each forwarded exactly once."""
     mock_env = mock.MagicMock()
     env = alf_wrappers.AlfEnvironmentBaseWrapper(mock_env)

     def assert_forwarded_once(method_name):
         self.assertEqual(1, getattr(mock_env, method_name).call_count)

     env.reset()
     assert_forwarded_once('reset')

     # The action is handed over as a zero-dimensional int64 array.
     action = np.array(0, dtype=np.int64)
     env.step(action)
     assert_forwarded_once('step')
     mock_env.step.assert_called_with(0)

     env.seed(0)
     assert_forwarded_once('seed')
     mock_env.seed.assert_called_with(0)

     env.render()
     assert_forwarded_once('render')

     env.close()
     assert_forwarded_once('close')
    def setUpClass(cls):
        """Creates the class-level fixtures shared by all tests.

        Sets up a sample phrase with constant embeddings, a mocked model
        that returns a constant tensor, the expected output shapes, the
        KeywordClustering instance under test and a sample data frame.
        """
        super(KeywordClusteringTest, cls).setUpClass()

        def _constant_embedding(x):
            # One row of fifty 0.5 values per element of the input.
            return tf.constant(0.5, shape=(len(x), 50))

        cls.phrase = "the hello world"
        cls.phrase_embedding = np.full((2, 50), 0.5)
        cls.phrase_embedding_avg = np.full(50, 0.5)
        cls.model = mock.MagicMock(side_effect=_constant_embedding)

        cls.cluster_output_shape = (11, 5)
        cls.cluster_description_output_shape = (2, 2)

        cls.kw_clustering = keyword_clustering.KeywordClustering(
            model=cls.model)

        # The JSON payload sits on the second line of the resource file.
        raw_text = importlib_resources.read_text(preprocess_data,
                                                 "example_cluster_df.txt")
        json_line = raw_text.split("\n")[1]
        cls.test_df = pd.read_json(json_line)
Ejemplo n.º 30
0
    def test_resets_after_limit(self):
        """Stepping past the limit flags game over; the next step resets."""
        step_limit = 5
        base_env = mock.MagicMock()
        wrapped_env = atari_wrappers.AtariTimeLimit(base_env, step_limit)

        # The underlying game itself never reports game over.
        base_env.gym.game_over = False
        base_env.reset.return_value = ts.restart(1)
        base_env.step.return_value = ts.transition(2, 0)
        action = 1

        # One step more than the limit.
        for _ in range(step_limit + 1):
            wrapped_env.step(action)

        self.assertTrue(wrapped_env.game_over)
        self.assertEqual(1, base_env.reset.call_count)

        # A further step clears the flag and resets the base env once more.
        wrapped_env.step(action)
        self.assertFalse(wrapped_env.game_over)
        self.assertEqual(2, base_env.reset.call_count)