コード例 #1
0
ファイル: client_env.py プロジェクト: tianhai123/-
 def render(self, mode="rgb_array"):
     render_response = ClientEnv.run_render(self._stub, mode=mode)
     if not render_response:
         return
     # Parse out the numpy array.
     return serialization.tensor_proto_to_numpy_array(
         render_response.observation.observation)
コード例 #2
0
ファイル: client_env.py プロジェクト: tianhai123/-
 def step(self, action):
     step_response = ClientEnv.run_step(self._stub, action)
     observation = self._maybe_squeeze_array(
         serialization.tensor_proto_to_numpy_array(
             step_response.observation.observation))
     info = {k: v for k, v in step_response.info.info_map.items()}
     return observation, step_response.reward, step_response.done, info
コード例 #3
0
ファイル: client_env.py プロジェクト: tianhai123/-
 def reset(self):
     # Run the RPC.
     reset_response_proto = ClientEnv.run_reset(self._stub)
     # Convert the TensorProto to numpy.
     obs_np = serialization.tensor_proto_to_numpy_array(
         reset_response_proto.observation.observation)
     return self._maybe_squeeze_array(obs_np)
コード例 #4
0
  def test_conversion(self):
    np_a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    obs = utils.numpy_array_to_observation(np_a)

    tp_a = obs.observation
    np_tp_a = utils.tensor_proto_to_numpy_array(tp_a)

    np.testing.assert_array_equal(np_a, np_tp_a)