Example #1

        # Excerpt begins mid-test: `content` is computed earlier in the
        # source file this snippet was taken from.
        self.assertAllEqual(content[2], np.array([[2], [3]], dtype=np.int32))

    def _assert_nested_variable_updated(
            self,
            variables: types.NestedVariable,
            check_nest_seq_types: bool = True) -> None:
        # Prepare the expected content of the variables.
        expected_values = (tf.constant(0, dtype=tf.int64, shape=()), {
            'var1': (tf.constant([1, 1], dtype=tf.float64, shape=(2, )), ),
            'var2':
            tf.constant([[2], [3]], dtype=tf.int32, shape=(2, 1))
        })
        flat_expected_values = tf.nest.flatten(expected_values)

        # Assert that the variables have the same content as the expected
        # values, meaning the two nested structures must match.
        self.assertIsNone(
            nest_utils.assert_same_structure(variables,
                                             expected_values,
                                             check_types=check_nest_seq_types))
        # The values in `variables` must be equal (or close, depending on the
        # component type) to the expected ones.
        flat_variables = tf.nest.flatten(variables)
        self.assertAllEqual(flat_variables[0], flat_expected_values[0])
        self.assertAllClose(flat_variables[1], flat_expected_values[1])
        self.assertAllEqual(flat_variables[2], flat_expected_values[2])


if __name__ == '__main__':
    multiprocessing.handle_test_main(test_utils.main)
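The assertion pattern above generalizes: first check that two nests share a
structure, then compare the flattened leaves pairwise. A minimal standalone
sketch of the same idea using plain `tf.nest` instead of TF-Agents'
`nest_utils` wrapper (the tensors are the expected values from the test):

import tensorflow as tf

expected = (tf.constant(0, dtype=tf.int64),
            {'var1': (tf.constant([1.0, 1.0], dtype=tf.float64),),
             'var2': tf.constant([[2], [3]], dtype=tf.int32)})
actual = (tf.constant(0, dtype=tf.int64),
          {'var1': (tf.constant([1.0, 1.0], dtype=tf.float64),),
           'var2': tf.constant([[2], [3]], dtype=tf.int32)})

# Raises if the two nests are shaped differently; check_types also compares
# the sequence types, mirroring `check_nest_seq_types` above.
tf.nest.assert_same_structure(actual, expected, check_types=True)

# Flattening yields leaves in a deterministic order, so pairwise checks
# line up one-to-one.
for got, want in zip(tf.nest.flatten(actual), tf.nest.flatten(expected)):
    tf.debugging.assert_near(tf.cast(got, tf.float64),
                             tf.cast(want, tf.float64))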
Example #2

    # Tail of a mock-environment class whose opening lines are truncated in
    # this excerpt; only its trailing _step override is shown.
    def _step(self, action):
        return ()


class MockEnvironmentCrashInStep(random_py_environment.RandomPyEnvironment):
    """Raise an error after specified number of steps in an episode."""
    def __init__(self, crash_at_step):
        super(MockEnvironmentCrashInStep,
              self).__init__(array_spec.ArraySpec((3, 3), np.float32),
                             array_spec.BoundedArraySpec([1],
                                                         np.float32,
                                                         minimum=-1.0,
                                                         maximum=1.0),
                             episode_end_probability=0,
                             min_duration=crash_at_step + 1,
                             max_duration=crash_at_step + 1)
        self._crash_at_step = crash_at_step
        self._steps = 0

    def _step(self, *args, **kwargs):
        transition = super(MockEnvironmentCrashInStep,
                           self)._step(*args, **kwargs)
        self._steps += 1
        if self._steps == self._crash_at_step:
            raise RuntimeError()
        return transition


if __name__ == '__main__':
    multiprocessing.handle_test_main(tf.test.main)
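A usage sketch for the mock above (a hypothetical test body, assuming numpy
is imported as in the class definition and the standard TF-Agents
`PyEnvironment.reset()`/`step()` API): the environment behaves normally
until the configured step, at which point the injected RuntimeError
surfaces through `step()`.

env = MockEnvironmentCrashInStep(crash_at_step=3)
env.reset()
action = np.array([0.0], dtype=np.float32)  # matches the BoundedArraySpec above
env.step(action)
env.step(action)
try:
    env.step(action)  # third step reaches crash_at_step and raises
except RuntimeError:
    print('environment crashed as configured')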
Example #3

            # Excerpt begins mid-call: the keyword arguments below are the
            # tail of a train_eval-style invocation whose opening lines are
            # not shown in this snippet.
            direction_fc=1,
            adversary_env_rnn=False,
            adv_actor_fc_layers=(2, ),
            adv_value_fc_layers=(2, ),
            adv_lstm_size=(2, ),
            adv_conv_filters=2,
            adv_conv_kernel=3,
            adv_timestep_fc=1,
            num_train_steps=3,
            collect_episodes_per_iteration=2,
            num_parallel_envs=2,
            replay_buffer_capacity=401,
            num_epochs=2,
            num_eval_episodes=1,
            eval_interval=10,
            train_checkpoint_interval=500,
            policy_checkpoint_interval=500,
            log_interval=500,
            summary_interval=500,
            debug_summaries=False,
            summarize_grads_and_vars=False)
        train_exists = tf.io.gfile.exists(os.path.join(root_dir, 'train'))
        self.assertTrue(train_exists)
        saved_policies = tf.io.gfile.listdir(
            os.path.join(root_dir, 'policy_saved_model'))
        self.assertGreaterEqual(len(saved_policies), 1)


if __name__ == '__main__':
    system_multiprocessing.handle_test_main(tf.test.main)
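The final assertions are a reusable smoke-test pattern: after a training
run, check that the checkpoint directory exists and that at least one
policy was exported. The same check standalone, with a hypothetical
root_dir:

import os
import tensorflow as tf

root_dir = '/tmp/train_eval_smoke_test'  # hypothetical output directory
assert tf.io.gfile.exists(os.path.join(root_dir, 'train'))
saved_policies = tf.io.gfile.listdir(
    os.path.join(root_dir, 'policy_saved_model'))
assert len(saved_policies) >= 1, 'expected at least one exported policy'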
Example #4
    def run_tests(self):
        # Import absl inside run, where dependencies have been loaded already.
        from absl import app  # pylint: disable=g-import-not-at-top

        def main(_):
            # pybullet imports multiprocessing in their setup.py, which causes an
            # issue when we import multiprocessing.pool.dummy down the line because
            # the PYTHONPATH has changed.
            for module in [
                    'multiprocessing', 'multiprocessing.pool',
                    'multiprocessing.dummy', 'multiprocessing.pool.dummy'
            ]:
                if module in sys.modules:
                    del sys.modules[module]
            # Reimport multiprocessing to avoid spurious error printouts. See
            # https://bugs.python.org/issue15881.
            import multiprocessing as _  # pylint: disable=g-import-not-at-top
            import tensorflow as tf  # pylint: disable=g-import-not-at-top

            # Cap every GPU at 1GB of memory. By default TensorFlow allocates
            # all GPU memory during initialization, so the process running the
            # bulk of the unit tests would otherwise hold it all, and the tests
            # in run_separately would fail with out-of-memory errors because
            # they run as subprocesses of the process holding the GPU memory.
            gpus = tf.config.experimental.list_physical_devices('GPU')
            for gpu in gpus:
                tf.config.set_logical_device_configuration(
                    gpu,
                    [tf.config.LogicalDeviceConfiguration(memory_limit=1024)])

            run_separately = load_test_list('test_individually.txt')
            broken_tests = load_test_list(FLAGS.broken_tests)

            test_loader = TestLoader(exclude_list=run_separately +
                                     broken_tests)
            test_suite = test_loader.discover('tf_agents', pattern='*_test.py')
            stderr = StderrWrapper()
            result = unittest.TextTestResult(stderr,
                                             descriptions=True,
                                             verbosity=2)
            test_suite.run(result)

            external_test_failures = []

            for test in run_separately:
                filename = 'tf_agents/%s.py' % test.replace('.', '/')
                try:
                    subprocess.check_call([sys.executable, filename])
                except subprocess.CalledProcessError as e:
                    external_test_failures.append(e)

            result.printErrors()

            for failure in external_test_failures:
                stderr.writeln(str(failure))

            final_output = (
                'Tests run: {} grouped and {} external.  '.format(
                    result.testsRun, len(run_separately)) +
                'Errors: {}  Failures: {}  External failures: {}.'.format(
                    len(result.errors), len(result.failures),
                    len(external_test_failures)))

            header = '=' * len(final_output)
            stderr.writeln(header)
            stderr.writeln(final_output)
            stderr.writeln(header)

            if result.wasSuccessful() and not external_test_failures:
                return 0
            else:
                return 1

        # Run inside absl.app.run to ensure flags parsing is done.
        from tf_agents.system import system_multiprocessing as multiprocessing  # pylint: disable=g-import-not-at-top
        return multiprocessing.handle_test_main(lambda: app.run(main))
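Stripped of the project-local helpers (TestLoader, StderrWrapper and
load_test_list are defined elsewhere in the TF-Agents setup code), the core
discover-and-run flow is plain unittest. A minimal sketch, assuming the
tf_agents package directory is importable from the working directory:

import unittest

loader = unittest.TestLoader()
suite = loader.discover('tf_agents', pattern='*_test.py')
result = unittest.TextTestRunner(verbosity=2).run(suite)
print('Tests run: {}  Errors: {}  Failures: {}'.format(
    result.testsRun, len(result.errors), len(result.failures)))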
Example #5
  def run_tests(self):
    # Import absl inside run, where dependencies have been loaded already.
    from absl import app  # pylint: disable=g-import-not-at-top

    def main(_):
      # pybullet imports multiprocessing in their setup.py, which causes an
      # issue when we import multiprocessing.pool.dummy down the line because
      # the PYTHONPATH has changed.
      for module in [
          'multiprocessing', 'multiprocessing.pool', 'multiprocessing.dummy',
          'multiprocessing.pool.dummy'
      ]:
        if module in sys.modules:
          del sys.modules[module]
      # Reimport multiprocessing to avoid spurious error printouts. See
      # https://bugs.python.org/issue15881.
      import multiprocessing as _  # pylint: disable=g-import-not-at-top

      run_separately = load_test_list('test_individually.txt')
      broken_tests = load_test_list('broken_tests.txt')

      test_loader = TestLoader(blacklist=run_separately + broken_tests)
      test_suite = test_loader.discover('tf_agents', pattern='*_test.py')
      stderr = StderrWrapper()
      result = unittest.TextTestResult(stderr, descriptions=True, verbosity=2)
      test_suite.run(result)

      external_test_failures = []

      for test in run_separately:
        filename = 'tf_agents/%s.py' % test.replace('.', '/')
        try:
          subprocess.check_call([sys.executable, filename])
        except subprocess.CalledProcessError as e:
          external_test_failures.append(e)

      result.printErrors()

      for failure in external_test_failures:
        stderr.writeln(str(failure))

      final_output = (
          'Tests run: {} grouped and {} external.  '.format(
              result.testsRun, len(run_separately)) +
          'Errors: {}  Failures: {}  External failures: {}.'.format(
              len(result.errors),
              len(result.failures),
              len(external_test_failures)))

      header = '=' * len(final_output)
      stderr.writeln(header)
      stderr.writeln(final_output)
      stderr.writeln(header)

      if result.wasSuccessful() and not external_test_failures:
        return 0
      else:
        return 1

    # Run inside absl.app.run to ensure flags parsing is done.
    from tf_agents.system import system_multiprocessing as multiprocessing  # pylint: disable=g-import-not-at-top
    return multiprocessing.handle_test_main(lambda: app.run(main))
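The run-separately loop shared by both variants boils down to "one
interpreter per test file", which keeps each listed test's GPU and global
state isolated from the main suite. A standalone sketch with a hypothetical
module list:

import subprocess
import sys

failures = []
for test in ['environments.some_heavy_test']:  # hypothetical module name
    filename = 'tf_agents/%s.py' % test.replace('.', '/')
    try:
        subprocess.check_call([sys.executable, filename])
    except subprocess.CalledProcessError as e:
        failures.append(e)
print('%d external failures' % len(failures))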
Example #6

        # Excerpt begins mid-test: `execute_pickled_fn`, `serialized_get_xval`
        # and `_XVAL` are defined earlier in the source file this snippet was
        # taken from.
        ctx = multiprocessing.get_context()

        # Local function should easily access _XVAL
        local_queue = ctx.SimpleQueue()
        execute_pickled_fn(serialized_get_xval, local_queue)
        self.assertFalse(local_queue.empty())
        self.assertEqual(local_queue.get(), 2)

        # Remote function can access new _XVAL since part of running it
        # is serializing the state via XValStateSaver (passed to handle_test_main
        # below).
        remote_queue = ctx.SimpleQueue()
        p = ctx.Process(target=execute_pickled_fn,
                        args=(serialized_get_xval, remote_queue))
        p.start()
        p.join()
        self.assertFalse(remote_queue.empty())
        self.assertEqual(remote_queue.get(), 2)

    def testPool(self):
        ctx = multiprocessing.get_context()
        p = ctx.Pool(3)
        x = 1
        values = p.map(x.__add__, [3, 4, 5, 6, 6])
        self.assertEqual(values, [4, 5, 6, 7, 7])


if __name__ == '__main__':
    multiprocessing.handle_test_main(test_utils.main,
                                     extra_state_savers=[XValStateSaver()])
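The remote-queue pattern in the first test above works the same with the
standard library alone; the XValStateSaver only matters for propagating
module globals like _XVAL into the child process. A minimal sketch without
the TF-Agents machinery:

import multiprocessing

def _worker(queue):
    # Compute in the child process and ship the result back to the parent.
    queue.put(1 + 1)

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    queue = ctx.SimpleQueue()
    p = ctx.Process(target=_worker, args=(queue,))
    p.start()
    p.join()
    assert queue.get() == 2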