def parse_encoded_spec_from_file(input_path):
  """Returns the tensor data spec stored at a path.

  Args:
    input_path: The path to the TFRecord file which contains the spec.

  Returns:
    `TensorSpec` nested structure parsed from the TFRecord file.

  Raises:
    IOError: File at input path does not exist.
  """
  if not tf.io.gfile.exists(input_path):
    raise IOError('Could not find spec file at %s.' % input_path)
  # The spec is stored as a single serialized StructuredValue record.
  dataset = tf.data.TFRecordDataset(input_path, buffer_size=1)
  dataset_iterator = eager_utils.dataset_iterator(dataset)
  signature_proto_string = eager_utils.get_next(dataset_iterator)
  if tf.executing_eagerly():
    signature_proto = struct_pb2.StructuredValue.FromString(
        signature_proto_string.numpy())
  else:
    # In non-eager mode a session must be run in order to get the value.
    # NOTE: `tf.Session` is not part of the TF2 API surface; use the
    # compat.v1 alias (consistent with the rest of this file).
    with tf.compat.v1.Session() as sess:
      signature_proto_string_value = sess.run(signature_proto_string)
      signature_proto = struct_pb2.StructuredValue.FromString(
          signature_proto_string_value)
  return tensor_spec.from_proto(signature_proto)
def testIteration(self):
  """Checks the iterator yields every element of the dataset, in order."""
  values = np.arange(100)
  dataset = tf.data.Dataset.from_tensor_slices(values)
  iterator = eager_utils.dataset_iterator(dataset)
  for expected in values:
    actual = self.evaluate(eager_utils.get_next(iterator))
    self.assertEqual(np.array([expected]), actual)
def test_with_dynamic_step_driver(self):
  """Records trajectories via DynamicStepDriver, then reads them back."""
  py_env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(py_env)
  mock_policy = driver_test_utils.TFPolicyMock(
      tf_env.time_step_spec(), tf_env.action_spec())
  traj_spec = trajectory.from_transition(
      tf_env.time_step_spec(),
      mock_policy.policy_step_spec,
      tf_env.time_step_spec())
  observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, traj_spec)
  step_driver = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      mock_policy,
      observers=[common.function(observer)],
      num_steps=10)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  first_time_step = self.evaluate(tf_env.reset())
  policy_state = mock_policy.get_initial_state(batch_size=1)
  # Drive the environment; each step is written to the TFRecord file.
  self.evaluate(
      common.function(step_driver.run)(first_time_step, policy_state))
  observer.flush()
  observer.close()
  # Read the file back and confirm a decoded sample is a Trajectory.
  readback = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  readback_iterator = eager_utils.dataset_iterator(readback)
  sample = self.evaluate(eager_utils.get_next(readback_iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def test_with_py_driver(self):
  """Records trajectories via PyDriver in py_mode, then reads them back."""
  py_env = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(
      py_env.time_step_spec(), py_env.action_spec())
  traj_spec = trajectory.from_transition(
      py_env.time_step_spec(),
      mock_policy.policy_step_spec,
      py_env.time_step_spec())
  # The observer encodes tensors, so convert the array spec to tensor specs.
  traj_spec = tensor_spec.from_spec(traj_spec)
  observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, traj_spec, py_mode=True)
  step_driver = py_driver.PyDriver(
      py_env, mock_policy, [observer], max_steps=10)
  first_time_step = py_env.reset()
  step_driver.run(first_time_step)
  observer.flush()
  observer.close()
  # Read the file back and confirm a decoded sample is a Trajectory.
  readback = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  readback_iterator = eager_utils.dataset_iterator(readback)
  sample = self.evaluate(eager_utils.get_next(readback_iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def _observe(self):
  """Pulls the next (context, label) pair and updates the label trackers."""
  observation, label = eager_utils.get_next(self._data_iterator)
  # Shift the label history before recording the new label.
  self._previous_label.assign(self._current_label)
  self._current_label.assign(
      tf.reshape(label, shape=[self._batch_size]))
  target_shape = [self._batch_size] + self._time_step_spec.observation.shape
  return tf.reshape(observation, shape=target_shape)
def _observe(self) -> types.NestedTensor:
  """Pulls the next (context, label) pair, casting the label and updating trackers."""
  observation, label = eager_utils.get_next(self._data_iterator)
  # Shift the label history before recording the new label.
  self._previous_label.assign(self._current_label)
  cast_label = tf.cast(label, dtype=self._label_dtype)
  self._current_label.assign(
      tf.reshape(cast_label, shape=[self._batch_size]))
  return tf.reshape(
      observation,
      shape=[self._batch_size] + self._time_step_spec.observation.shape)