Esempio n. 1
0
    def test_create_episode_dataset(self):
        """Checks episode splitting on a seven-step dataset with terminals.

        The input marks steps 1, 3 and 6 as terminal and has no timeouts.
        The expected output keeps all seven steps, carries a discount of 0.0
        on each terminal step (1.0 elsewhere), and reports episodes starting
        at indices 0, 2 and 4.
        """
        expected = dict(
            states=np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.],
                             [9., 10.], [11., 12.], [13., 14.]]),
            actions=np.array([[1.], [2.], [3.], [4.], [5.], [6.], [7.]]),
            rewards=np.array([[0.], [1.], [0.], [1.], [0.], [0.], [1.]]),
            discounts=np.array([1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]),
            episode_start_index=np.array([0, 2, 4]),
        )

        raw_dataset = {
            'observations': [[1., 2.], [3., 4.], [5., 6.], [7., 8.],
                             [9., 10.], [11., 12.], [13., 14.]],
            'actions': [[1.], [2.], [3.], [4.], [5.], [6.], [7.]],
            'rewards': [[0.], [1.], [0.], [1.], [0.], [0.], [1.]],
            'terminals': [False, True, False, True, False, False, True],
            'timeouts': [False] * 7,
            # NOTE(review): eight entries here versus seven steps in every
            # other field -- confirm whether the mismatch is intentional.
            'infos/goal': [[0.]] * 8,
        }

        actual = create_episode_dataset(raw_dataset, exclude_timeouts=True)

        # NOTE(review): the values are NumPy arrays, so this presumably
        # relies on the enclosing test case providing an array-aware
        # assertDictEqual -- confirm against the class definition.
        self.assertDictEqual(actual, expected)
Esempio n. 2
0
def main(_):
    """Converts the dataset of ``FLAGS.env_name`` into TFRecord file(s).

    The environment's dataset is split into episodes, and the episodes are
    divided into ``FLAGS.replicas`` (default 1) contiguous index ranges that
    are written under ``<FLAGS.root_dir>/<FLAGS.env_name>``. When
    ``FLAGS.replica_id`` is set, only that replica's range is written, in the
    current process; otherwise one ``context.Process`` is started per replica
    and all of them are joined before returning.

    Each range spans ``num_episodes // num_replicas + 1`` episodes, clipped
    to the number of episodes at the upper end, so trailing replicas may be
    handed a short or empty range.

    Args:
        _: Unused positional argument.
    """
    logging.set_verbosity(logging.INFO)

    env = gym.make(FLAGS.env_name)
    raw_dataset = env.get_dataset()
    output_dir = os.path.join(FLAGS.root_dir, FLAGS.env_name)

    episodes = dataset_utils.create_episode_dataset(raw_dataset,
                                                    FLAGS.exclude_timeouts)
    episode_count = len(episodes['episode_start_index'])
    logging.info('Found %d episodes, %s total steps.', episode_count,
                 len(episodes['states']))

    data_spec = dataset_utils.create_collect_data_spec(
        episodes, use_trajectories=FLAGS.use_trajectories)
    logging.info('Collect data spec %s', data_spec)

    replica_count = FLAGS.replicas or 1
    episodes_per_replica = episode_count // replica_count + 1

    def writer_kwargs(replica, file_name):
        """Builds the writer arguments for one replica's episode range."""
        first_episode = replica * episodes_per_replica
        # Only the upper bound is clipped to the number of episodes.
        last_episode = min((replica + 1) * episodes_per_replica,
                           episode_count)
        return dict(dataset_dict=episodes,
                    collect_data_spec=data_spec,
                    dataset_path=os.path.join(output_dir, file_name),
                    start_episode=first_episode,
                    end_episode=last_episode,
                    use_trajectories=FLAGS.use_trajectories)

    # A specific replica was requested: write just that range, in-process.
    if FLAGS.replica_id is not None:
        shard_name = '%s_%d.tfrecord' % (FLAGS.env_name, FLAGS.replica_id)
        file_utils.write_samples_to_tfrecord(
            **writer_kwargs(FLAGS.replica_id, shard_name))
        return

    # No replica requested: fan every range out to its own process.
    context = multiprocessing.get_context()
    workers = []
    for replica in range(replica_count):
        # A single replica gets an un-suffixed file name.
        if replica_count == 1:
            shard_name = '%s.tfrecord' % FLAGS.env_name
        else:
            shard_name = '%s_%d.tfrecord' % (FLAGS.env_name, replica)
        worker = context.Process(
            target=file_utils.write_samples_to_tfrecord,
            kwargs=writer_kwargs(replica, shard_name))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()
Esempio n. 3
0
    def test_create_episode_dataset_exclude_timeout(self):
        """Checks that a timed-out step is dropped with exclude_timeouts.

        The input has two non-terminal steps, the second flagged as a
        timeout. Only the first step is expected in the output, with a
        discount of 1.0 and a single episode starting at index 0.
        """
        # The timeout step is expected to be absent from every field.
        expected = dict(
            states=np.array([[1., 2.]]),
            actions=np.array([[1.]]),
            rewards=np.array([[0.]]),
            discounts=np.array([1.0]),
            episode_start_index=np.array([0]),
        )

        raw_dataset = {
            'observations': [[1., 2.], [3., 4.]],
            'actions': [[1.], [2.]],
            'rewards': [[0.], [0.]],
            'terminals': [False] * 2,
            'timeouts': [False, True],
            'infos/goal': [[10., 10.], [10., 10.]],
        }

        actual = create_episode_dataset(raw_dataset, exclude_timeouts=True)

        # NOTE(review): the values are NumPy arrays, so this presumably
        # relies on the enclosing test case providing an array-aware
        # assertDictEqual -- confirm against the class definition.
        self.assertDictEqual(actual, expected)
Esempio n. 4
0
    def test_create_episode_dataset_from_terminal(self):
        """Checks that a one-step episode that is already terminal is kept.

        The only step is terminal (not a timeout) and its observation equals
        the goal entry. The step is expected in the output with a discount
        of 0.0 and a single episode starting at index 0.
        """
        # The episode begins at the goal; the step must not be discarded.
        expected = dict(
            states=np.array([[10., 10.]]),
            actions=np.array([[5.]]),
            rewards=np.array([[1.]]),
            discounts=np.array([0.0]),
            episode_start_index=np.array([0]),
        )

        raw_dataset = {
            'observations': [[10., 10.]],
            'actions': [[5.]],
            'rewards': [[1.]],
            'terminals': [True],
            'timeouts': [False],
            'infos/goal': [[10., 10.]],
        }

        actual = create_episode_dataset(raw_dataset, exclude_timeouts=True)

        # NOTE(review): the values are NumPy arrays, so this presumably
        # relies on the enclosing test case providing an array-aware
        # assertDictEqual -- confirm against the class definition.
        self.assertDictEqual(actual, expected)