Ejemplo n.º 1
0
    def network_signature(self, observation_space, action_space):
        """Describe the network's input and output tensor signatures.

        The input signature mirrors ``observation_space``. The output is a
        shape-``(1,)`` tensor. When ``self._use_policy`` is truthy, the output
        is instead a pair: that shape-``(1,)`` tensor followed by a tensor
        with one entry per action (width ``space_utils.max_size(action_space)``).

        Args:
            observation_space: Space describing the network's input.
            action_space: Space whose maximum size sets the per-action width.

        Returns:
            data.NetworkSignature for the network.
        """
        action_count = space_utils.max_size(action_space)
        scalar_sig = data.TensorSignature(shape=(1, ))

        if not self._use_policy:
            return data.NetworkSignature(
                input=space_utils.signature(observation_space),
                output=scalar_sig,
            )

        per_action_sig = data.TensorSignature(shape=(action_count, ))
        return data.NetworkSignature(
            input=space_utils.signature(observation_space),
            output=(scalar_sig, per_action_sig),
        )
Ejemplo n.º 2
0
 def network_signature(self, observation_space, action_space):
     """Return a signature mapping observations to a single shape-(1,) tensor.

     ``action_space`` does not influence the result and is discarded.
     """
     del action_space  # Unused: the output width is fixed at one.
     observation_sig = space_utils.signature(observation_space)
     scalar_sig = data.TensorSignature(shape=(1,))
     return data.NetworkSignature(input=observation_sig, output=scalar_sig)
Ejemplo n.º 3
0
 def network_signature(self, observation_space, action_space):
     """Return network signatures keyed by request type.

     Returns:
         dict with two entries:
           * ``data.AgentRequest``: the observation signature in, a
             shape-(1,) tensor out.
           * ``data.ModelRequest``: a dict with ``'observation'`` and an
             ``'action'`` vector (width ``space.max_size(action_space)``) in,
             and a dict with ``'next_observation'``, ``'reward'`` and
             ``'done'`` out.
     """
     agent_sig = data.NetworkSignature(
         input=space.signature(observation_space),
         output=data.TensorSignature(shape=(1,)),
     )

     model_inputs = {
         'observation': space.signature(observation_space),
         'action': data.TensorSignature(
             shape=(space.max_size(action_space),)
         ),
     }
     model_outputs = {
         'next_observation': space.signature(observation_space),
         'reward': data.TensorSignature(shape=(1,)),
         'done': data.TensorSignature(shape=(1,)),
     }
     model_sig = data.NetworkSignature(input=model_inputs, output=model_outputs)

     return {data.AgentRequest: agent_sig, data.ModelRequest: model_sig}
Ejemplo n.º 4
0
    def network_signature(self, observation_space, action_space):
        """Return the signature: observations in, per-action vector(s) out.

        Each output tensor has shape ``(space_utils.max_size(action_space),)``.
        When ``self._use_policy`` is truthy, the output is a 2-tuple holding
        that same signature object twice. Otherwise it is the single tensor.
        """
        width = space_utils.max_size(action_space)
        per_action = data.TensorSignature(shape=(width, ))
        outputs = (
            (per_action, per_action) if self._use_policy else per_action
        )
        return data.NetworkSignature(
            input=space_utils.signature(observation_space),
            output=outputs,
        )
Ejemplo n.º 5
0
    def network_signature(self, observation_space, action_space):
        """Return the network signature for observation input, action output.

        When ``self._inject_log_temperature`` is truthy, the input is the pair
        ``(observation signature, shape-(1,) tensor)``. Otherwise it is the
        observation signature alone. The output is a tensor of shape
        ``(space_utils.max_size(action_space),)``.
        """
        observation_sig = space_utils.signature(observation_space)
        input_sig = (
            (observation_sig, data.TensorSignature(shape=(1,)))
            if self._inject_log_temperature
            else observation_sig
        )
        output_sig = data.TensorSignature(
            shape=(space_utils.max_size(action_space),)
        )
        return data.NetworkSignature(input=input_sig, output=output_sig)
Ejemplo n.º 6
0
def test_signature_for_tuples():
    """Check that nested gym Tuple spaces map to nested tuples of signatures."""
    observation_space = Tuple((
        Tuple((Box(np.array([-2, -2]), np.array([2, 2])), Discrete(3))),
        Tuple((Discrete(4), Discrete(5))),
    ))
    actual = signature(observation_space)

    # A 2-element Box becomes a float32 vector; each Discrete is an int64
    # scalar.
    box_sig = TensorSignature(shape=(2, ), dtype=np.dtype('float32'))
    int_scalar_sig = TensorSignature(shape=(), dtype=np.dtype('int64'))
    expected = (
        (box_sig, int_scalar_sig),
        (int_scalar_sig, int_scalar_sig),
    )
    assert actual == expected
Ejemplo n.º 7
0
 def network_signature(self, observation_space, action_space):
     """Return the signature: observations in, distribution parameters out.

     The output signature comes from
     ``self.distribution.params_signature(action_space)``.
     """
     observation_sig = space_utils.signature(observation_space)
     params_sig = self.distribution.params_signature(action_space)
     return data.NetworkSignature(input=observation_sig, output=params_sig)
Ejemplo n.º 8
0
 def network_signature(self, observation_space, action_space):
     """Return the signature: observations in, (scalar, params) pair out.

     The output is a tuple: a shape-(1,) tensor followed by the signature
     from ``self.distribution.params_signature(action_space)``.
     """
     observation_sig = space_utils.signature(observation_space)
     scalar_sig = data.TensorSignature(shape=(1, ))
     params_sig = self.distribution.params_signature(action_space)
     return data.NetworkSignature(
         input=observation_sig,
         output=(scalar_sig, params_sig),
     )
Ejemplo n.º 9
0
 def network_signature(observation_space, action_space):
     """Return a signature mapping observations to a shape-(1,) tensor.

     ``action_space`` is accepted but not used.
     """
     del action_space  # Not needed: the output width is fixed.
     observation_sig = space_utils.signature(observation_space)
     scalar_sig = data.TensorSignature(shape=(1, ))
     return data.NetworkSignature(input=observation_sig, output=scalar_sig)
Ejemplo n.º 10
0
 def network_signature(observation_space, action_space):
     """Return the signature: observations in, one entry per action out.

     Args:
         observation_space: Space describing the network input.
         action_space: Must expose ``.n`` (e.g. a gym ``Discrete`` space).
             The output tensor has shape ``(action_space.n,)``.

     Returns:
         data.NetworkSignature for the network.
     """
     return data.NetworkSignature(
         input=space_utils.signature(observation_space),
         output=data.TensorSignature(shape=(action_space.n, )),
     )