コード例 #1
0
 def testInternalAttributeIsUsed(self):
     """Checks that a thread reads back its own value, not another thread's."""
     local_cls = threading_utils.local_attributes(['attr'])(_DummyClass)
     instance = local_cls()
     # Value written from the mocked 'thread'.
     with test_utils.mock_thread('thread'):
         instance.attr = 'internal-value'
     # Value written from the current thread; it must not be seen by 'thread'.
     instance.attr = 'dummy_value'
     with test_utils.mock_thread('thread'):
         self.assertEqual(instance.attr, 'internal-value')
コード例 #2
0
 def testCallableAttribute(self):
     """Checks that method calls go to the calling thread's own attribute."""
     local_cls = threading_utils.local_attributes(['attr'])(_DummyClass)
     instance = local_cls()
     with test_utils.mock_thread('thread'):
         instance.attr = test.mock.Mock()
     instance.attr = test.mock.Mock()
     with test_utils.mock_thread('thread'):
         instance.attr.method()
         instance.attr.method.assert_called_once()
     # The mock stored by the current thread must remain untouched.
     instance.attr.method.assert_not_called()
コード例 #3
0
 def testMultipleAttributes(self):
     """Tests the class decorator with multiple local attributes.

     Values set for `attr1` and `attr2` in one mocked thread must be unaffected
     by writes to the same attributes from the current thread.
     """
     MockClass = threading_utils.local_attributes(['attr1',
                                                   'attr2'])(_DummyClass)
     obj = MockClass()
     with test_utils.mock_thread('thread'):
         obj.attr1 = 10
         obj.attr2 = 20
     # Overwrite both local attributes from the current thread only. Only
     # `attr1`/`attr2` are decorated, so no other attribute is touched here.
     obj.attr1 = 'dummy_value'
     obj.attr2 = 'dummy_value'
     with test_utils.mock_thread('thread'):
         self.assertEqual(obj.attr1, 10)
         self.assertEqual(obj.attr2, 20)
コード例 #4
0
 def testLocalValueOverDefault(self):
     """Checks that an explicitly set local value wins over the default."""
     local_cls = threading_utils.local_attributes(['attr'])(_DummyClass)
     instance = local_cls()
     default_factory = test.mock.Mock()
     threading_utils.initialize_local_attributes(instance,
                                                 attr=default_factory)
     with test_utils.mock_thread('thread'):
         instance.attr = 'internal-value'
     instance.attr = 'dummy_value'
     with test_utils.mock_thread('thread'):
         self.assertEqual(instance.attr, 'internal-value')
     # 'thread' already had its own value, so the default is never invoked.
     default_factory.assert_not_called()
コード例 #5
0
 def testMultiThreads(self):
     """Checks that each thread keeps its own copy of a local attribute."""
     local_cls = threading_utils.local_attributes(['attr'])(_DummyClass)
     instance = local_cls()
     per_thread = (('thread_1', 1), ('thread_2', 2))
     # Write phase: every mocked thread stores its own value.
     for thread_name, value in per_thread:
         with test_utils.mock_thread(thread_name):
             instance.attr = value
     # Read phase: every mocked thread sees only the value it stored.
     for thread_name, value in per_thread:
         with test_utils.mock_thread(thread_name):
             self.assertEqual(instance.attr, value)
コード例 #6
0
 def testActionIsNotDefined(self):
     """Tests that reading `action` in a thread that never set it raises."""
     # Use the session as a context manager so it is closed after the test,
     # matching the other agent tests.
     with tf.Session() as sess:
         agent = dqn_agent.DQNAgent(sess, 3, observation_shape=(2, 2))
         # `action` is set for the current thread only.
         agent.action = 'dummy-value'
         with test_utils.mock_thread('thread'):
             # `assertRaisesRegexp` is a deprecated alias (removed in Python
             # 3.12); `assertRaisesRegex` is the canonical name.
             with self.assertRaisesRegex(
                     AttributeError,
                     'Local value for attribute `action` has not been set.*'):
                 _ = agent.action
コード例 #7
0
 def testLocalVariablesSet(self, variable_name, expected_value):
     """Tests that a thread which never assigned a local variable sees its expected value.

     Args:
       variable_name: name of the agent attribute to check (presumably
         supplied by a parameterization decorator — confirm).
       expected_value: value the attribute should have in a thread that did
         not assign it.
     """
     # Close the session when done, consistent with the other agent tests.
     with tf.Session() as sess:
         agent = dqn_agent.DQNAgent(sess,
                                    3,
                                    observation_shape=(2, 2),
                                    stack_size=4)
         # Overwrite the value in the current thread only.
         setattr(agent, variable_name, 'dummy-value')
         with test_utils.mock_thread('thread'):
             self.assertAllEqual(getattr(agent, variable_name), expected_value)
コード例 #8
0
 def testDefaultValueIsUsed(self):
     """Checks that a thread with no value of its own falls back to the default."""
     instance = _DummyClass()

     def default_factory():
         return 'default-value'

     threading_utils.initialize_local_attributes(instance,
                                                 attr=default_factory)
     instance.attr = 'dummy_value'
     with test_utils.mock_thread('thread'):
         self.assertEqual(instance.attr, 'default-value')
コード例 #9
0
 def testMultiThreadsMultipleAttributes(self):
     """Checks per-thread isolation when a class has several local attributes."""
     local_cls = threading_utils.local_attributes(['attr1',
                                                   'attr2'])(_DummyClass)
     instance = local_cls()
     per_thread = (('thread_1', 1, 2), ('thread_2', 3, 4))
     # Write phase: each mocked thread sets both attributes.
     for thread_name, first, second in per_thread:
         with test_utils.mock_thread(thread_name):
             instance.attr1 = first
             instance.attr2 = second
     # Read phase: each mocked thread reads back only its own values.
     for thread_name, first, second in per_thread:
         with test_utils.mock_thread(thread_name):
             self.assertEqual(instance.attr1, first)
             self.assertEqual(instance.attr2, second)
コード例 #10
0
 def testMultipleDefaultValuesAreUsed(self):
     """Checks that the helper installs a separate default for each attribute."""
     instance = _DummyClass()
     defaults = {'attr1': lambda: 3, 'attr2': lambda: 4}
     threading_utils.initialize_local_attributes(instance, **defaults)
     instance.attr1 = 'dummy_value'
     instance.attr2 = 'dummy_value'
     with test_utils.mock_thread('thread'):
         self.assertEqual(instance.attr1, 3)
         self.assertEqual(instance.attr2, 4)
コード例 #11
0
    def testLocalValues(self):
        """Checks that per-episode agent values are kept separately per thread."""
        obs_shape = (2, 2)
        with tf.Session() as sess:
            agent = dqn_agent.DQNAgent(sess, 3, observation_shape=obs_shape)
            sess.run(tf.global_variables_initializer())

            def snapshot():
                # Episode-related values as seen from the current thread.
                return (agent._observation, agent._last_observation,
                        agent.state)

            with test_utils.mock_thread('baseline-thread'):
                agent.begin_episode(observation=np.zeros(obs_shape),
                                    training=False)
                baseline = snapshot()

            # This thread also takes a step, so its values should diverge.
            with test_utils.mock_thread('different-thread'):
                agent.begin_episode(observation=np.zeros(obs_shape),
                                    training=False)
                agent.step(reward=10,
                           observation=np.ones(obs_shape),
                           training=False)
                diverged = snapshot()

            # This thread repeats exactly what the baseline thread did.
            with test_utils.mock_thread('identical-thread'):
                agent.begin_episode(observation=np.zeros(obs_shape),
                                    training=False)
                repeated = snapshot()

            # Same inputs as the baseline: every value must match.
            for base_val, same_val in zip(baseline, repeated):
                self.assertTrue(np.all(base_val == same_val))

            # Extra step compared to the baseline: every value must differ.
            for base_val, other_val in zip(baseline, diverged):
                self.assertTrue(np.any(base_val != other_val))
コード例 #12
0
 def testAddMultipleThreadsNodeNotAdded(self):
     """Checks that a terminal add in one thread ignores another thread's data.

     With contiguous trajectories enabled, the transition added from the
     current thread must not be counted together with the terminal transition
     added from a different mocked thread.
     """
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=1,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
         use_contiguous_trajectories=True)
     # Fresh buffer: cursor at the start and no pending trajectory.
     self.assertEqual(memory.cursor(), 0)
     self.assertEqual(len(memory._trajectory), 0)
     blank_obs = np.zeros(OBSERVATION_SHAPE)
     # Non-terminal transition from the current thread.
     memory.add(blank_obs, 0, 0, 0)
     # Terminal transition from a separate mocked thread.
     with test_utils.mock_thread('other-thread'):
         memory.add(blank_obs, 0, 0, 1)
     # Only the terminal transition has been counted.
     self.assertEqual(memory.add_count, 1)
コード例 #13
0
    def testBundling(self):
        """Checks that unbundling a checkpoint updates only the calling thread's values."""
        with tf.Session() as sess:
            agent = dqn_agent.DQNAgent(sess, 3, observation_shape=(2, 2))
            sess.run(tf.global_variables_initializer())
            agent.state = 'state_val'
            bundle = agent.bundle_and_checkpoint(self.get_temp_dir(),
                                                 iteration_number=10)
            # The current thread's `state` is included in the bundle.
            self.assertIn('state', bundle)
            self.assertEqual(bundle['state'], 'state_val')
            bundle['state'] = 'new_state_val'

            # Restoring from another thread changes that thread's `state`...
            with test_utils.mock_thread('other-thread'):
                agent.unbundle(self.get_temp_dir(),
                               iteration_number=10,
                               bundle_dictionary=bundle)
                self.assertEqual(agent.state, 'new_state_val')
            # ...while the current thread keeps its original value.
            self.assertEqual(agent.state, 'state_val')
コード例 #14
0
    def testEnvironmentInitializationPerThread(self):
        """Tests that a new environment is created for a new thread.

        In the synchronous model, `create_environment_fn` is called only once,
        at runner initialization. In the asynchronous model,
        `create_environment_fn` is called for each new iteration.
        """
        environment_fn = _get_mock_environment_fn()
        runner = self._get_runner(create_agent_fn=test.mock.MagicMock(),
                                  create_environment_fn=environment_fn,
                                  num_iterations=1,
                                  training_steps=1,
                                  evaluation_steps=0,
                                  num_simultaneous_iterations=1)

        # Environment called once in init.
        environment_fn.assert_called_once()
        # Run one experiment in another mocked thread and one in the current
        # thread; each is configured with a single iteration.
        with test_utils.mock_thread('other-thread'):
            runner.run_experiment()
        runner.run_experiment()
        # Presumably 1 call at init plus 1 per iteration in each of the two
        # runs.
        self.assertEqual(environment_fn.call_count, 3)
コード例 #15
0
 def testGetInternalName(self):
     """Checks the format of the thread-specific internal attribute name."""
     expected_name = '__attr_123'
     with test_utils.mock_thread(123):
         internal_name = threading_utils._get_internal_name('attr')
         self.assertEqual(internal_name, expected_name)