def test_conflicting_specs(self):
  """Loading TFRecord files with mismatched specs must raise IOError."""
  # Location for a second dataset whose spec disagrees with self.dataset_path.
  self.other_dataset_path = os.path.join(
      self.get_temp_dir(), "other_test_tfrecord_dataset.tfrecord")
  # A spec that differs from the one already written for self.dataset_path.
  self.other_data_spec = SimpleSpec(
      step_type=tf.TensorSpec(shape=(1), dtype=tf.int32, name="step_type"),
      value=tf.TensorSpec(shape=(1, 5), dtype=tf.float64, name="value"))
  example_encoding_dataset.encode_spec_to_file(
      self.other_dataset_path, self.other_data_spec)
  # Loading the two incompatible files together should be rejected.
  with self.assertRaises(IOError):
    example_encoding_dataset.load_tfrecord_dataset(
        [self.dataset_path, self.other_dataset_path])
def create_dataset_for_given_tfrecord(
    record_path,
    load_buffer_size=None,
    shuffle_buffer_size=1,
    num_parallel_reads=100,
    decoder=None,
):
  """Builds a tf.data pipeline over one TFRecord of tf_agents experience.

  TFRecords contain agent experience (as saved by the tf_agents endpoints).

  Args:
    record_path: Path to the TFRecord file.
    load_buffer_size: Read-buffer size passed to the loader (None = default).
    shuffle_buffer_size: Shuffle buffer size; values <= 1 disable shuffling.
    num_parallel_reads: Number of parallel reader workers.
    decoder: Optional custom decoder forwarded to the loader.

  Returns:
    A tf.data.Dataset of experience, prefetched with AUTOTUNE.
  """
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      record_path,
      buffer_size=load_buffer_size,
      as_trajectories=True,
      as_experience=True,
      add_batch_dim=False,
      compress_image=True,
      num_parallel_reads=num_parallel_reads,
      decoder=decoder,
  )
  # Only shuffle when a meaningful buffer size was requested.
  if shuffle_buffer_size > 1:
    dataset = dataset.shuffle(shuffle_buffer_size)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def test_with_dynamic_step_driver(self):
  """Records trajectories via DynamicStepDriver, then reads them back."""
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(
      tf_env.time_step_spec(), tf_env.action_spec())
  trajectory_spec = trajectory.from_transition(
      tf_env.time_step_spec(), policy.policy_step_spec,
      tf_env.time_step_spec())
  observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, trajectory_spec)
  driver = dynamic_step_driver.DynamicStepDriver(
      tf_env, policy, observers=[common.function(observer)], num_steps=10)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  time_step = self.evaluate(tf_env.reset())
  initial_policy_state = policy.get_initial_state(batch_size=1)
  self.evaluate(common.function(driver.run)(time_step, initial_policy_state))
  # Flush/close so the records are durable before reading them back.
  observer.flush()
  observer.close()
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def test_with_py_driver(self):
  """Records trajectories via PyDriver, then reads them back."""
  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(
      env.time_step_spec(), env.action_spec())
  # The py-mode observer still expects a tensor spec, so convert it.
  trajectory_spec = tensor_spec.from_spec(
      trajectory.from_transition(
          env.time_step_spec(), policy.policy_step_spec,
          env.time_step_spec()))
  observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, trajectory_spec, py_mode=True)
  driver = py_driver.PyDriver(env, policy, [observer], max_steps=10)
  driver.run(env.reset())
  # Flush/close so the records are durable before reading them back.
  observer.flush()
  observer.close()
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def create_single_tf_record_dataset(
    filename: Text,
    load_buffer_size: int = 0,
    shuffle_buffer_size: int = 10000,
    num_parallel_reads: Optional[int] = None,
    decoder: Optional[DecoderFnType] = None,
    reward_shift: float = 0.0,
    action_clipping: Optional[Tuple[float, float]] = None,
    use_trajectories: bool = True,
):
  """Create a TF dataset for a single TFRecord file.

  Args:
    filename: Path to a single TFRecord file.
    load_buffer_size: Number of bytes in the read buffer. 0 means no
      buffering.
    shuffle_buffer_size: Size of the buffer for shuffling items within a
      single TFRecord file.
    num_parallel_reads: Optional, number of parallel reads in the TFRecord
      dataset. If not specified, no parallelism.
    decoder: Optional, a custom decoder to use rather than using the default
      spec path.
    reward_shift: Value to add to the reward.
    action_clipping: Optional, (minimum, maximum) values to clip actions.
    use_trajectories: Whether to use trajectories. If false, use transitions.

  Returns:
    A TF.Dataset of (experience, info) pairs.
  """
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      filename,
      buffer_size=load_buffer_size,
      as_experience=use_trajectories,
      as_trajectories=use_trajectories,
      add_batch_dim=False,
      num_parallel_reads=num_parallel_reads,
      decoder=decoder,
  )

  def _to_experience_with_info(sample):
    # Pair each (shifted/clipped) sample with an empty dummy info field.
    experience = updated_sample(
        sample, reward_shift, action_clipping, use_trajectories)
    return experience, ()

  dataset = dataset.map(
      _to_experience_with_info,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.shuffle(shuffle_buffer_size).prefetch(
      tf.data.experimental.AUTOTUNE)
def test_load_tfrecord_dataset(self):
  """Loading succeeds for a written file and raises for a missing one."""
  # Ensure an example tfrecord file exists before attempting to load it.
  self.test_tfrecord_observer()
  example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2)
  example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_experience=True)
  # A nonexistent file must raise IOError.
  with self.assertRaises(IOError):
    example_encoding_dataset.load_tfrecord_dataset(["fake_file.tfrecord"])
def load_trajs(dataset_path, batch_size, img_pad):
  """Loads a trajectory dataset, augments images, and batches it.

  Args:
    dataset_path: Path to the TFRecord file of trajectories.
    batch_size: Batch size applied after shuffling/augmentation.
    img_pad: Padding amount forwarded to `image_aug`.

  Returns:
    A shuffled, repeating, batched tf.data.Dataset with image augmentation.
  """
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [dataset_path],
      buffer_size=int(1e6),
      as_experience=False,
      as_trajectories=True,
      add_batch_dim=False)
  logging.info('Traj dataset loaded from %s', dataset_path)
  dataset = dataset.shuffle(10000).repeat()
  dataset = dataset.map(
      lambda img: image_aug(img, img_pad),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
def load_episodes(dataset_path, img_pad, gamma=0.99):
  """Load episode data from a fixed dataset.

  Args:
    dataset_path: Path to the TFRecord file of episodes.
    img_pad: Padding amount forwarded to `transform_episodes`.
    gamma: Discount factor forwarded to `process_episode` for paths ending
      in 'episodes'.

  Returns:
    A shuffled, repeating tf.data.Dataset of transformed episodes.
  """
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [dataset_path],
      buffer_size=None,
      as_experience=False,
      add_batch_dim=False)
  logging.info('Episode dataset loaded from %s', dataset_path)
  # Only raw '...episodes' files need the discounting pre-processing step.
  if dataset_path.endswith('episodes'):
    dataset = dataset.map(
        lambda x, y: process_episode(x, y, gamma),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.shuffle(50).repeat()
  dataset = dataset.map(
      lambda x, y, z: transform_episodes(x, y, z, img_pad))
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  logging.info('Episode dataset processed.')
  return dataset