Example #1
    def __init__(self,
                 file_paths: List[Iterable[str]],
                 num_workers: int,
                 num_modules: int,
                 runner: Callable[[str, str, int], Tuple[str, int]],
                 parser: Callable[[List[str]],
                                  Iterator[trajectory.Trajectory]],
                 use_stale_results=False,
                 max_unfinished_tasks=None,
                 overload_handler=default_overload_handler):
        super(LocalDataCollector, self).__init__()

        self._file_paths = file_paths
        self._num_modules = num_modules
        self._runner = runner
        self._parser = parser

        self._unfinished_work = []
        self._pool = multiprocessing.get_context('spawn').Pool(num_workers)

        self._max_unfinished_tasks = max_unfinished_tasks
        if not self._max_unfinished_tasks:
            self._max_unfinished_tasks = _UNFINISHED_WORK_RATIO * num_modules
        self._use_stale_results = use_stale_results

        self._default_policy_size_map = collections.defaultdict(lambda: None)
        self._overloaded_workers_handler = overload_handler
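The pattern worth noting in this constructor is building the Pool from an explicit 'spawn' context and holding unfinished work as a list of apply_async handles. A minimal, self-contained sketch of that pattern (hypothetical worker function, not part of the class above):

import multiprocessing

def _square(x):
    # Hypothetical worker; must live at module level so 'spawn' children can import it.
    return x * x

if __name__ == '__main__':
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(processes=4) as pool:
        unfinished_work = [pool.apply_async(_square, (i,)) for i in range(8)]
        print([r.get() for r in unfinished_work])  # squares of 0..7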
Example #2
    def testGinBindingsInOtherProcess(self):
        # Serialize a function that we will call in subprocesses.
        serialized_get_xval = pickle.dumps(get_xval)

        # get_xval reads _XVAL; set it to 2 and check that subprocesses
        # observe this value.
        global _XVAL
        _XVAL = 2

        ctx = multiprocessing.get_context()

        # A local call sees the updated _XVAL directly.
        local_queue = ctx.SimpleQueue()
        execute_pickled_fn(serialized_get_xval, local_queue)
        self.assertFalse(local_queue.empty())
        self.assertEqual(local_queue.get(), 2)

        # The remote function also sees the new _XVAL, because running it
        # serializes the state via XValStateSaver (passed to handle_test_main
        # below).
        remote_queue = ctx.SimpleQueue()
        p = ctx.Process(target=execute_pickled_fn,
                        args=(serialized_get_xval, remote_queue))
        p.start()
        p.join()
        self.assertFalse(remote_queue.empty())
        self.assertEqual(remote_queue.get(), 2)
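Stripped of the pickling and gin-binding specifics, the test exercises a simple pattern: hand a SimpleQueue to a child process and read the result back in the parent. A minimal sketch with a hypothetical worker:

import multiprocessing

def _worker(queue):
    # Hypothetical child-side work; ship the result back through the queue.
    queue.put(42)

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    queue = ctx.SimpleQueue()
    p = ctx.Process(target=_worker, args=(queue,))
    p.start()
    p.join()
    assert not queue.empty()
    print(queue.get())  # 42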
Example #3
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Initialize runner and file_suffix according to compile_task.
  if FLAGS.compile_task == 'inlining':
    runner = inlining_runner.InliningRunner(
        clang_path=FLAGS.clang_path, llvm_size_path=FLAGS.llvm_size_path)
    file_suffix = ['.bc', '.cmd']

  with open(os.path.join(FLAGS.data_path, 'module_paths'), 'r') as f:
    module_paths = [
        os.path.join(FLAGS.data_path, name.rstrip('\n')) for name in f
    ]
    file_paths = [
        tuple([p + suffix for suffix in file_suffix]) for p in module_paths
    ]

  # Subsample the modules if a sampling rate below 1 is given.
  if FLAGS.sampling_rate < 1:
    sampled_modules = int(len(file_paths) * FLAGS.sampling_rate)
    file_paths = random.sample(file_paths, k=sampled_modules)

  ctx = multiprocessing.get_context()
  pool = ctx.Pool(FLAGS.num_workers)

  index = 0
  total_successful_examples = 0
  with tf.io.TFRecordWriter(FLAGS.output_path) as file_writer:
    while index < len(file_paths):
      # Shard data collection and sink to tfrecord periodically to avoid OOM.
      next_index = min(index + _BATCH_SIZE, len(file_paths))
      sharded_file_paths = file_paths[index:next_index]
      index = next_index

      results = [
          pool.apply_async(runner.collect_data, (path, '', None))
          for path in sharded_file_paths
      ]

      # Wait until all jobs in this shard finish.
      waiting_time = 0
      while True:
        if sum([not r.ready() for r in results]) == 0:
          break
        logging.info('%d/%d: %d of %d modules finished in %d seconds.', index,
                     len(file_paths), sum([r.ready() for r in results]),
                     len(sharded_file_paths), waiting_time)
        time.sleep(1)
        waiting_time += 1

      # Write successful examples to tfrecord.
      successful_count = len(
          [file_writer.write(r.get()[0]) for r in results if r.successful()])
      logging.info('%d/%d: %d of %d modules succeeded.', index, len(file_paths),
                   successful_count, len(sharded_file_paths))
      total_successful_examples += successful_count

  logging.info('%d of %d modules succeeded in total.',
               total_successful_examples, len(file_paths))
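The collection loop relies on the standard AsyncResult protocol: apply_async returns handles exposing ready(), successful(), and get(). A compact sketch of the same wait-then-filter idea, with a hypothetical task function and with close()/join() standing in for the polling loop:

import multiprocessing

def _maybe_fail(x):
    # Hypothetical task: fail on multiples of 3, otherwise return a value.
    if x % 3 == 0:
        raise ValueError('simulated failure')
    return x * 10

if __name__ == '__main__':
    pool = multiprocessing.get_context().Pool(4)
    results = [pool.apply_async(_maybe_fail, (i,)) for i in range(10)]
    pool.close()
    pool.join()  # once join() returns, every result is ready()
    print([r.get() for r in results if r.successful()])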
Example #4
def main(_):
    logging.set_verbosity(logging.INFO)

    d4rl_env = gym.make(FLAGS.env_name)
    d4rl_dataset = d4rl_env.get_dataset()
    root_dir = os.path.join(FLAGS.root_dir, FLAGS.env_name)

    dataset_dict = dataset_utils.create_episode_dataset(
        d4rl_dataset, FLAGS.exclude_timeouts)
    num_episodes = len(dataset_dict['episode_start_index'])
    logging.info('Found %d episodes, %s total steps.', num_episodes,
                 len(dataset_dict['states']))

    collect_data_spec = dataset_utils.create_collect_data_spec(
        dataset_dict, use_trajectories=FLAGS.use_trajectories)
    logging.info('Collect data spec %s', collect_data_spec)

    num_replicas = FLAGS.replicas or 1
    interval_size = num_episodes // num_replicas + 1

    # If FLAGS.replica_id is set, only run that section of the dataset.
    # This is useful when distributing the replicas on Borg.
    if FLAGS.replica_id is not None:
        file_name = '%s_%d.tfrecord' % (FLAGS.env_name, FLAGS.replica_id)
        start_index = FLAGS.replica_id * interval_size
        end_index = min((FLAGS.replica_id + 1) * interval_size, num_episodes)
        file_utils.write_samples_to_tfrecord(
            dataset_dict=dataset_dict,
            collect_data_spec=collect_data_spec,
            dataset_path=os.path.join(root_dir, file_name),
            start_episode=start_index,
            end_episode=end_index,
            use_trajectories=FLAGS.use_trajectories)
    else:
        # Otherwise, parallelize with tf_agents.system.multiprocessing.
        jobs = []
        context = multiprocessing.get_context()

        for i in range(num_replicas):
            if num_replicas == 1:
                file_name = '%s.tfrecord' % FLAGS.env_name
            else:
                file_name = '%s_%d.tfrecord' % (FLAGS.env_name, i)
            dataset_path = os.path.join(root_dir, file_name)
            start_index = i * interval_size
            end_index = min((i + 1) * interval_size, num_episodes)
            kwargs = dict(dataset_dict=dataset_dict,
                          collect_data_spec=collect_data_spec,
                          dataset_path=dataset_path,
                          start_episode=start_index,
                          end_episode=end_index,
                          use_trajectories=FLAGS.use_trajectories)
            job = context.Process(target=file_utils.write_samples_to_tfrecord,
                                  kwargs=kwargs)
            job.start()
            jobs.append(job)

        for job in jobs:
            job.join()
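The else branch is a straightforward fan-out: start one Process per replica with keyword arguments, then join them all. A reduced sketch of that structure, using a hypothetical shard-writing function in place of file_utils.write_samples_to_tfrecord:

import multiprocessing

def _write_shard(shard_id, start_episode, end_episode):
    # Hypothetical stand-in for per-replica work such as writing one tfrecord file.
    print('shard %d covers episodes [%d, %d)' %
          (shard_id, start_episode, end_episode))

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    num_episodes, num_replicas = 10, 3
    interval_size = num_episodes // num_replicas + 1
    jobs = []
    for i in range(num_replicas):
        kwargs = dict(shard_id=i,
                      start_episode=i * interval_size,
                      end_episode=min((i + 1) * interval_size, num_episodes))
        job = ctx.Process(target=_write_shard, kwargs=kwargs)
        job.start()
        jobs.append(job)
    for job in jobs:
        job.join()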
Example #5
    def start(self, wait_to_start=True):
        """Start the process.

        Args:
          wait_to_start: Whether the call should wait for an env initialization.
        """
        mp_context = system_multiprocessing.get_context()
        self._conn, conn = mp_context.Pipe()
        self._process = mp_context.Process(target=self._worker, args=(conn, ))
        atexit.register(self.close)
        self._process.start()
        if wait_to_start:
            self.wait_start()
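start() assumes a worker that talks over one end of a Pipe while the parent keeps the other end and registers cleanup with atexit. A minimal stand-alone sketch of that shape, using the standard library directly and a hypothetical worker:

import atexit
import multiprocessing

def _worker(conn):
    # Hypothetical worker: echo one request back to the parent, then exit.
    conn.send(('ready', conn.recv()))
    conn.close()

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    parent_conn, child_conn = ctx.Pipe()
    process = ctx.Process(target=_worker, args=(child_conn,))
    atexit.register(process.terminate)  # mirrors the atexit cleanup above
    process.start()
    parent_conn.send('ping')
    print(parent_conn.recv())  # ('ready', 'ping')
    process.join()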
Example #6
    def __init__(self, file_paths: List[Iterable[str]], num_workers: int,
                 num_modules: int, runner: Callable[[str, str, int],
                                                    Tuple[str, int]],
                 parser: Callable[[List[str]],
                                  Iterator[trajectory.Trajectory]]):
        super(LocalDataCollector, self).__init__()

        self._file_paths = file_paths
        self._num_modules = num_modules
        self._runner = runner
        self._parser = parser

        ctx = multiprocessing.get_context('spawn')
        if num_workers == -1:
            num_workers = ctx.cpu_count()
        self._pool = ctx.Pool(num_workers)

        self._default_policy_size_map = collections.defaultdict(lambda: None)
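The only difference from Example #1 is the num_workers == -1 fallback: context objects re-export cpu_count(), so the pool can be sized to the machine. A small sketch, assuming the same 'spawn' context:

import multiprocessing

if __name__ == '__main__':
    ctx = multiprocessing.get_context('spawn')
    num_workers = ctx.cpu_count()  # same value as multiprocessing.cpu_count()
    with ctx.Pool(num_workers) as pool:
        print(pool.map(abs, [-1, -2, 3]))  # [1, 2, 3]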
Example #7
    def testPool(self):
        ctx = multiprocessing.get_context()
        p = ctx.Pool(3)
        x = 1
        values = p.map(x.__add__, [3, 4, 5, 6, 6])
        self.assertEqual(values, [4, 5, 6, 7, 7])
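The callable handed to Pool.map here is the bound method x.__add__. If a module-level, explicitly picklable callable is preferred (for example under the 'spawn' start method), the same mapping can be written with functools.partial; a sketch:

import functools
import multiprocessing
import operator

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    with ctx.Pool(3) as p:
        add_one = functools.partial(operator.add, 1)
        print(p.map(add_one, [3, 4, 5, 6, 6]))  # [4, 5, 6, 7, 7]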
Example #8
    def testPoolWithClause(self):
        ctx = multiprocessing.get_context()
        with ctx.Pool(3) as p:
            res = p.map(pickleable_sqr, [1, 2, 3, 4])
        self.assertEqual(res, [1, 4, 9, 16])
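Because Pool.__exit__ calls terminate(), results must be collected inside the with block, which the synchronous map() call above already guarantees. The same pattern outside a test class, with a hypothetical squaring function in place of pickleable_sqr:

import multiprocessing

def _sqr(x):
    # Hypothetical module-level function so pool workers can import it.
    return x * x

if __name__ == '__main__':
    ctx = multiprocessing.get_context()
    # map() blocks, so the results are in hand before the pool is torn down.
    with ctx.Pool(3) as p:
        res = p.map(_sqr, [1, 2, 3, 4])
    print(res)  # [1, 4, 9, 16]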