def network_signature(self, observation_space, action_space):
    """Build the network signature from the environment spaces.

    The input always mirrors the observation space. The output is a
    single shape-(1,) tensor unless `self._use_policy` is truthy; in that
    case it is a pair (shape-(1,) tensor, one entry per action).

    Args:
        observation_space: Environment observation space.
        action_space: Environment action space; its size is taken with
            `space_utils.max_size`.

    Returns:
        data.NetworkSignature describing the network.
    """
    # Computed unconditionally, as before, so an unsupported action space
    # fails the same way regardless of the policy flag.
    action_count = space_utils.max_size(action_space)
    if not self._use_policy:
        return data.NetworkSignature(
            input=space_utils.signature(observation_space),
            output=data.TensorSignature(shape=(1,)),
        )
    scalar_sig = data.TensorSignature(shape=(1,))
    per_action_sig = data.TensorSignature(shape=(action_count,))
    return data.NetworkSignature(
        input=space_utils.signature(observation_space),
        output=(scalar_sig, per_action_sig),
    )
def test_integration_with_keras():
    """Smoke test: one SupervisedTrainer epoch on a KerasNetwork must not fail."""
    TestTransition = collections.namedtuple('TestTransition', ['observation'])

    num_transitions = 10
    observation_shape = (4,)
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=observation_shape),
        output=data.TensorSignature(shape=(1,)),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=signature,
        target=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        # Buffer sized to hold exactly the transitions added below.
        replay_buffer_capacity=num_transitions,
    )

    zero_observations = np.zeros((num_transitions,) + observation_shape)
    episode = data.Episode(
        transition_batch=TestTransition(observation=zero_observations),
        return_=123,
        solved=False,
    )
    trainer.add_episode(episode)

    trainer.train_epoch(keras.KerasNetwork(network_signature=signature))
def __init__(self, inputs, outputs, metrics=None):
    """Create the network with a trivial shape-(1,) input/output signature.

    Args:
        inputs: Stored on the instance as `self._inputs`.
        outputs: Stored on the instance as `self._outputs`.
        metrics: Forwarded unchanged to the base-class constructor.
    """
    unit_sig = data.TensorSignature(shape=(1,))
    trivial_signature = data.NetworkSignature(input=unit_sig, output=unit_sig)
    super().__init__(network_signature=trivial_signature, metrics=metrics)
    self._inputs = inputs
    self._outputs = outputs
def get_model_network_signature(observation_space, action_space):
    """Return the signature of a trainable environment-model network.

    Used in model-based experiments with a trainable model. The input is
    the observation with the action appended as extra channels on the last
    axis (actions are one-hot encoded as entire layers). The output is a
    dict with the predicted next observation, reward and done flag. All
    tensors are float32.

    Args:
        observation_space (gym.Space): Environment observation space.
        action_space (gym.Discrete): Environment action space.

    Returns:
        NetworkSignature: Signature of the network.
    """
    obs_shape = observation_space.shape
    # One extra channel per discrete action.
    stacked_channels = obs_shape[-1] + action_space.n
    model_input_shape = tuple(obs_shape[:-1]) + (stacked_channels,)

    def float32_sig(shape):
        return data.TensorSignature(shape=shape, dtype=np.float32)

    outputs = {
        'next_observation': float32_sig(observation_space.shape),
        'reward': float32_sig((1,)),
        'done': float32_sig((1,)),
    }
    return data.NetworkSignature(
        input=float32_sig(model_input_shape), output=outputs)
def network_signature(self, observation_space, action_space):
    """Signature of a network mapping an observation to a scalar value.

    Args:
        observation_space: Determines the input signature.
        action_space: Ignored; the output does not depend on it.

    Returns:
        data.NetworkSignature with a shape-(1,) output.
    """
    del action_space  # Not part of the signature.
    observation_input = space_utils.signature(observation_space)
    scalar_output = data.TensorSignature(shape=(1,))
    return data.NetworkSignature(input=observation_input, output=scalar_output)
def network_signature(self, observation_space, action_space):
    """Return one network signature per request type.

    - `data.AgentRequest`: observation -> shape-(1,) output.
    - `data.ModelRequest`: {'observation', 'action'} ->
      {'next_observation', 'reward', 'done'}, with the action given as a
      vector of length `space.max_size(action_space)`.

    Returns:
        dict mapping request type to data.NetworkSignature.
    """
    # Helpers build a fresh signature object for every slot.
    def observation_sig():
        return space.signature(observation_space)

    def scalar_sig():
        return data.TensorSignature(shape=(1,))

    agent_signature = data.NetworkSignature(
        input=observation_sig(),
        output=scalar_sig(),
    )
    action_length = space.max_size(action_space)
    model_signature = data.NetworkSignature(
        input={
            'observation': observation_sig(),
            'action': data.TensorSignature(shape=(action_length,)),
        },
        output={
            'next_observation': observation_sig(),
            'reward': scalar_sig(),
            'done': scalar_sig(),
        },
    )
    return {
        data.AgentRequest: agent_signature,
        data.ModelRequest: model_signature,
    }
def network_signature(self, observation_space, action_space):
    """Signature with one or two per-action output vectors.

    Each output has length `space_utils.max_size(action_space)`. When
    `self._use_policy` is truthy, the output is a pair of such vectors
    (the same signature object twice); otherwise it is a single vector.

    Returns:
        data.NetworkSignature with the observation space as input.
    """
    per_action = data.TensorSignature(
        shape=(space_utils.max_size(action_space),))
    output = (per_action, per_action) if self._use_policy else per_action
    return data.NetworkSignature(
        input=space_utils.signature(observation_space),
        output=output,
    )
def test_model_valid(model_fn, input_shape, output_shape):
    """A KerasNetwork built from `model_fn` must map a batch of inputs of
    shape `input_shape` to outputs of shape `output_shape`."""
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=input_shape),
        output=data.TensorSignature(shape=output_shape),
    )
    network = keras_networks.KerasNetwork(
        model_fn=model_fn, network_signature=signature)

    batch = 7
    predictions = network.predict(np.zeros((batch,) + input_shape))
    assert predictions.shape == (batch,) + output_shape
def network_signature(self, observation_space, action_space):
    """Signature mapping an observation to one output entry per action.

    If `self._inject_log_temperature` is truthy, the input becomes a pair
    (observation signature, shape-(1,) tensor). Otherwise the input is
    just the observation signature.

    Returns:
        data.NetworkSignature whose output has length
        `space_utils.max_size(action_space)`.
    """
    observation_sig = space_utils.signature(observation_space)
    input_sig = observation_sig
    if self._inject_log_temperature:
        # Extra scalar input; the flag name suggests it carries the log-temperature.
        input_sig = (observation_sig, data.TensorSignature(shape=(1,)))
    per_action_sig = data.TensorSignature(
        shape=(space_utils.max_size(action_space),))
    return data.NetworkSignature(input=input_sig, output=per_action_sig)
def test_multiple_targets():
    """With two targets, the trainer feeds the network a single batch of
    (input, (target_1, target_2)) matching the two-output signature."""
    TestTransition = collections.namedtuple(
        'TestTransition', ['observation', 'agent_info'])
    network_sig = data.NetworkSignature(
        input=data.TensorSignature(shape=(1,)),
        # Two outputs: shape (1,) and shape (2,).
        output=(
            data.TensorSignature(shape=(1,)),
            data.TensorSignature(shape=(2,)),
        ),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=network_sig,
        # One target per network output, in the same order.
        target=(supervised.target_solved, supervised.target_qualities),
        batch_size=1,
        n_steps_per_epoch=1,
        replay_buffer_capacity=1,
    )
    transitions = TestTransition(
        observation=np.zeros((1, 1)),
        agent_info={'qualities': np.zeros((1, 2))},
    )
    trainer.add_episode(
        data.Episode(transition_batch=transitions, return_=123, solved=False))

    class StreamCheckingNetwork(core.DummyNetwork):
        """Mock network that asserts on the exact data stream it receives."""

        def train(self, data_stream, n_steps, epoch,
                  validation_data_stream=None):
            received = list(data_stream())
            expected = [
                testing.zero_pytree(
                    (network_sig.input, network_sig.output),
                    shape_prefix=(1,),
                ),
            ]
            np.testing.assert_equal(received, expected)
            return {}

    trainer.train_epoch(StreamCheckingNetwork(network_sig))
def __init__(self, request_type):
    """Create the network with a trivial shape-(1,) input/output signature.

    Args:
        request_type: Stored on the instance as `self._request_type`.
    """
    unit_sig = data.TensorSignature(shape=(1,))
    super().__init__(
        network_signature=data.NetworkSignature(input=unit_sig, output=unit_sig)
    )
    self._request_type = request_type
def network_signature(self, observation_space, action_space):
    """Signature mapping observations to the parameters of `self.distribution`.

    The output signature is whatever
    `self.distribution.params_signature(action_space)` returns.
    """
    observation_sig = space_utils.signature(observation_space)
    distribution_params_sig = self.distribution.params_signature(action_space)
    return data.NetworkSignature(
        input=observation_sig, output=distribution_params_sig)
def network_signature(self, observation_space, action_space):
    """Signature with a scalar head plus distribution-parameter head.

    The output is a pair: a shape-(1,) tensor, followed by
    `self.distribution.params_signature(action_space)`.
    """
    observation_sig = space_utils.signature(observation_space)
    scalar_head = data.TensorSignature(shape=(1,))
    params_head = self.distribution.params_signature(action_space)
    return data.NetworkSignature(
        input=observation_sig, output=(scalar_head, params_head))
def keras_mlp():
    """Return a KerasNetwork with a shape-(13,) input and shape-(1,) output.

    No model_fn is passed, so the KerasNetwork default model is used
    (presumably an MLP, given this function's name).
    """
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=(13,)),
        output=data.TensorSignature(shape=(1,)),
    )
    return keras_networks.KerasNetwork(network_signature=signature)
def network_signature(observation_space, action_space):
    """Signature mapping an observation to a single shape-(1,) output.

    Args:
        observation_space: Determines the input signature.
        action_space: Ignored; the output does not depend on it.
    """
    del action_space  # Not part of the signature.
    scalar_output = data.TensorSignature(shape=(1,))
    return data.NetworkSignature(
        input=space_utils.signature(observation_space),
        output=scalar_output,
    )
def network_signature(observation_space, action_space):
    """Signature mapping an observation to one output entry per action.

    Args:
        observation_space: Environment observation space; determines the
            input signature.
        action_space: Environment action space; its size, as given by
            `space_utils.max_size`, is the output length.

    Returns:
        data.NetworkSignature with the observation space as input and a
        vector of per-action outputs.
    """
    # Use space_utils.max_size like the other signature builders instead of
    # reading `action_space.n` directly. It is assumed to return `.n` for
    # Discrete spaces (confirm against space_utils).
    n_actions = space_utils.max_size(action_space)
    return data.NetworkSignature(
        input=space_utils.signature(observation_space),
        output=data.TensorSignature(shape=(n_actions,)),
    )