def parse_encoded_spec_from_file(input_path):
    """Returns the tensor data spec stored at a path.

    Only the first record of the file is read. It is parsed as a serialized
    `StructuredValue` proto and converted with `tensor_spec.from_proto`.

    Args:
      input_path: The path to the TFRecord file which contains the spec.

    Returns:
      `TensorSpec` nested structure parsed from the TFRecord file.

    Raises:
      IOError: File at input path does not exist.
    """
    if not tf.io.gfile.exists(input_path):
        raise IOError('Could not find spec file at %s.' % input_path)
    dataset = tf.data.TFRecordDataset(input_path, buffer_size=1)
    dataset_iterator = eager_utils.dataset_iterator(dataset)
    signature_proto_string = eager_utils.get_next(dataset_iterator)
    if tf.executing_eagerly():
        serialized_signature = signature_proto_string.numpy()
    else:
        # In non-eager mode a session must be run in order to get the value.
        # `tf.Session` is not available in the TF2 top-level namespace, so
        # use the compat alias, matching the other `tf.compat.v1` calls here.
        with tf.compat.v1.Session() as sess:
            serialized_signature = sess.run(signature_proto_string)
    signature_proto = struct_pb2.StructuredValue.FromString(
        serialized_signature)
    return tensor_spec.from_proto(signature_proto)
# Example #2
# 0
 def testIteration(self):
     """Checks that successive `get_next` calls match the source array in order."""
     values = np.arange(100)
     dataset = tf.data.Dataset.from_tensor_slices(values)
     iterator = eager_utils.dataset_iterator(dataset)
     for expected in values:
         # Pull exactly one element per expected value.
         fetched = self.evaluate(eager_utils.get_next(iterator))
         self.assertEqual(np.array([expected]), fetched)
    def test_with_dynamic_step_driver(self):
        """Records steps from a `DynamicStepDriver` and reads one sample back.

        A `TFRecordObserver` writing to `self.dataset_path` is attached to a
        driver configured with `num_steps=10`. After the run, the first element
        loaded from the record file must be a `trajectory.Trajectory`.
        """
        py_env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        mock_policy = driver_test_utils.TFPolicyMock(
            tf_env.time_step_spec(), tf_env.action_spec())

        spec = trajectory.from_transition(
            tf_env.time_step_spec(),
            mock_policy.policy_step_spec,
            tf_env.time_step_spec())

        observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, spec)
        wrapped_observer = common.function(observer)
        step_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env, mock_policy, observers=[wrapped_observer], num_steps=10)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        first_step = self.evaluate(tf_env.reset())
        policy_state = mock_policy.get_initial_state(batch_size=1)
        run_driver = common.function(step_driver.run)
        self.evaluate(run_driver(first_step, policy_state))
        # Flush and close the observer before reading the file back.
        observer.flush()
        observer.close()

        records = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        first_record = self.evaluate(
            eager_utils.get_next(eager_utils.dataset_iterator(records)))
        self.assertIsInstance(first_record, trajectory.Trajectory)
    def test_with_py_driver(self):
        """Records steps from a `PyDriver` and reads one sample back.

        A `TFRecordObserver` created with `py_mode=True` is passed to a driver
        configured with `max_steps=10`. After the run, the first element loaded
        from the record file must be a `trajectory.Trajectory`.
        """
        py_env = driver_test_utils.PyEnvironmentMock()
        mock_policy = driver_test_utils.PyPolicyMock(
            py_env.time_step_spec(), py_env.action_spec())
        array_spec = trajectory.from_transition(
            py_env.time_step_spec(),
            mock_policy.policy_step_spec,
            py_env.time_step_spec())
        # The observer is given the spec after conversion via `from_spec`.
        spec = tensor_spec.from_spec(array_spec)

        observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, spec, py_mode=True)

        runner = py_driver.PyDriver(
            py_env, mock_policy, [observer], max_steps=10)
        first_step = py_env.reset()
        runner.run(first_step)
        # Flush and close the observer before reading the file back.
        observer.flush()
        observer.close()

        records = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)

        first_record = self.evaluate(
            eager_utils.get_next(eager_utils.dataset_iterator(records)))
        self.assertIsInstance(first_record, trajectory.Trajectory)
 def _observe(self):
     """Fetches the next (context, label) pair and returns the reshaped context.

     The current label is copied into `_previous_label`, then `_current_label`
     is overwritten with the new label reshaped to `[batch_size]`.

     Returns:
       The context reshaped to `[batch_size]` followed by the observation
       shape taken from the time step spec.
     """
     features, label = eager_utils.get_next(self._data_iterator)
     # Save the outgoing label before it is replaced below.
     self._previous_label.assign(self._current_label)
     batched_label = tf.reshape(label, shape=[self._batch_size])
     self._current_label.assign(batched_label)
     observation_shape = (
         [self._batch_size] + self._time_step_spec.observation.shape)
     return tf.reshape(features, shape=observation_shape)
 def _observe(self) -> types.NestedTensor:
   """Fetches the next (context, label) pair and returns the reshaped context.

   The current label is copied into `_previous_label`, then `_current_label`
   is overwritten with the new label cast to `_label_dtype` and reshaped to
   `[batch_size]`.

   Returns:
     The context reshaped to `[batch_size]` followed by the observation
     shape taken from the time step spec.
   """
   features, label = eager_utils.get_next(self._data_iterator)
   # Save the outgoing label before it is replaced below.
   self._previous_label.assign(self._current_label)
   typed_label = tf.cast(label, dtype=self._label_dtype)
   batched_label = tf.reshape(typed_label, shape=[self._batch_size])
   self._current_label.assign(batched_label)
   observation_shape = (
       [self._batch_size] + self._time_step_spec.observation.shape)
   return tf.reshape(features, shape=observation_shape)