Пример #1
0
def get_model_network_signature(observation_space, action_space):
    """Builds the network signature for a trainable environment model.

    The input is shaped like an observation whose last axis is widened by
    ``action_space.n`` extra channels, which hold the one-hot encoded action.
    The output is a dict with the next observation, the reward and the done
    flag, all as float32 tensors.

    Args:
        observation_space (gym.Space): Environment observation space.
        action_space (gym.Discrete): Environment action space.

    Returns:
        NetworkSignature: Signature of the network.
    """
    obs_shape = observation_space.shape
    # Every action gets its own one-hot layer appended to the channel axis.
    n_channels = obs_shape[-1] + action_space.n
    input_sig = data.TensorSignature((*obs_shape[:-1], n_channels),
                                     dtype=np.float32)

    output_sig = {
        'next_observation': data.TensorSignature(shape=obs_shape,
                                                 dtype=np.float32),
        'reward': data.TensorSignature(shape=(1, ), dtype=np.float32),
        'done': data.TensorSignature(shape=(1, ), dtype=np.float32),
    }
    return data.NetworkSignature(input=input_sig, output=output_sig)
Пример #2
0
def test_integration_with_keras():
    """Smoke-tests a single SupervisedTrainer epoch on a KerasNetwork.

    Nothing is asserted; the test passes as long as no step raises.
    """
    transition_cls = collections.namedtuple('TestTransition', ['observation'])

    episode_length = 10
    observation_shape = (4, )
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=observation_shape),
        output=data.TensorSignature(shape=(1, )),
    )

    trainer = supervised.SupervisedTrainer(
        network_signature=signature,
        target=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=episode_length,
    )

    # One episode of all-zero observations fills the replay buffer.
    observations = np.zeros((episode_length, ) + observation_shape)
    episode = data.Episode(
        transition_batch=transition_cls(observation=observations),
        return_=123,
        solved=False,
    )
    trainer.add_episode(episode)

    trainer.train_epoch(keras.KerasNetwork(network_signature=signature))
Пример #3
0
def test_model_valid(model_fn, input_shape, output_shape):
    """Checks that the network's prediction has the declared output shape.

    A KerasNetwork is built from ``model_fn`` and fed a zero batch; the
    prediction must have shape ``(batch,) + output_shape``.
    """
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=input_shape),
        output=data.TensorSignature(shape=output_shape),
    )
    network = keras_networks.KerasNetwork(model_fn=model_fn,
                                          network_signature=signature)

    n_examples = 7
    zero_batch = np.zeros((n_examples, ) + input_shape)
    prediction = network.predict(zero_batch)

    expected_shape = (n_examples, ) + output_shape
    assert prediction.shape == expected_shape
Пример #4
0
    def network_signature(self, observation_space, action_space):
        """Returns the signature of the network used by this object.

        The output is a vector of length
        ``space_utils.max_size(action_space)``. The input is the observation
        signature, paired with an extra single-element tensor when
        ``self._inject_log_temperature`` is set.
        """
        observation_sig = space_utils.signature(observation_space)
        input_sig = (
            (observation_sig, data.TensorSignature(shape=(1,)))
            if self._inject_log_temperature
            else observation_sig
        )

        output_sig = data.TensorSignature(
            shape=(space_utils.max_size(action_space),)
        )
        return data.NetworkSignature(input=input_sig, output=output_sig)
Пример #5
0
    def network_signature(self, observation_space, action_space):
        """Returns the signature of the network used by this object.

        The input is the observation signature. The output is a
        single-element tensor; when ``self._use_policy`` is set it becomes a
        pair of that tensor and a vector of length
        ``space_utils.max_size(action_space)``.
        """
        n_actions = space_utils.max_size(action_space)
        single_sig = data.TensorSignature(shape=(1, ))
        if self._use_policy:
            output_sig = (single_sig,
                          data.TensorSignature(shape=(n_actions, )))
        else:
            output_sig = single_sig

        return data.NetworkSignature(
            input=space_utils.signature(observation_space),
            output=output_sig,
        )
def test_multiple_targets():
    """Checks the data stream of a trainer configured with two targets.

    The network has two outputs and the trainer two targets. A mock network
    asserts that the stream yields exactly one all-zero batch matching the
    (input, output) signature pair.
    """
    transition_cls = collections.namedtuple('TestTransition',
                                            ['observation', 'agent_info'])

    # Two outputs, matched one-to-one by the two targets below.
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=(1, )),
        output=(
            data.TensorSignature(shape=(1, )),
            data.TensorSignature(shape=(2, )),
        ),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=signature,
        target=(supervised.target_solved, supervised.target_qualities),
        batch_size=1,
        n_steps_per_epoch=1,
        replay_buffer_capacity=1,
    )

    transitions = transition_cls(
        observation=np.zeros((1, 1)),
        agent_info={'qualities': np.zeros((1, 2))},
    )
    trainer.add_episode(
        data.Episode(transition_batch=transitions, return_=123, solved=False))

    class TestNetwork(core.DummyNetwork):
        """Mock class."""
        def train(self,
                  data_stream,
                  n_steps,
                  epoch,
                  validation_data_stream=None):
            batches = list(data_stream())
            expected_batch = testing.zero_pytree(
                (signature.input, signature.output), shape_prefix=(1, ))
            np.testing.assert_equal(batches, [expected_batch])

            return {}

    trainer.train_epoch(TestNetwork(signature))
 def __init__(self, inputs, outputs, metrics=None):
     """Initializes the parent with single-element input and output
     signatures and stores the given inputs and outputs on the instance.
     """
     unit_sig = data.TensorSignature(shape=(1, ))
     signature = data.NetworkSignature(input=unit_sig, output=unit_sig)
     super().__init__(network_signature=signature, metrics=metrics)
     self._inputs = inputs
     self._outputs = outputs
Пример #8
0
 def network_signature(self, observation_space, action_space):
     """Returns a signature mapping an observation to a scalar value.

     The action space is not used.
     """
     del action_space
     observation_sig = space_utils.signature(observation_space)
     scalar_sig = data.TensorSignature(shape=(1,))
     return data.NetworkSignature(input=observation_sig, output=scalar_sig)
Пример #9
0
 def network_signature(self, observation_space, action_space):
     """Returns a dict of network signatures keyed by request type.

     ``data.AgentRequest`` maps an observation to a single-element tensor.
     ``data.ModelRequest`` maps an observation/action dict to a dict with
     the next observation, the reward and the done flag.
     """
     agent_sig = data.NetworkSignature(
         input=space.signature(observation_space),
         output=data.TensorSignature(shape=(1,)),
     )

     model_input_sig = {
         'observation': space.signature(observation_space),
         'action': data.TensorSignature(
             shape=(space.max_size(action_space),)
         ),
     }
     model_output_sig = {
         'next_observation': space.signature(observation_space),
         'reward': data.TensorSignature(shape=(1,)),
         'done': data.TensorSignature(shape=(1,)),
     }
     model_sig = data.NetworkSignature(
         input=model_input_sig, output=model_output_sig
     )

     return {data.AgentRequest: agent_sig, data.ModelRequest: model_sig}
Пример #10
0
    def network_signature(self, observation_space, action_space):
        """Returns the signature of the network used by this object.

        The input is the observation signature. The output is a vector of
        length ``space_utils.max_size(action_space)``, or a pair of two such
        vectors when ``self._use_policy`` is set.
        """
        per_action_sig = data.TensorSignature(
            shape=(space_utils.max_size(action_space), ))
        output_sig = (
            (per_action_sig, per_action_sig)
            if self._use_policy
            else per_action_sig
        )
        return data.NetworkSignature(
            input=space_utils.signature(observation_space),
            output=output_sig,
        )
Пример #11
0
def signature(space):
    """Returns a SpaceSignature of elements of the given space.

    A ``gym.spaces.Tuple`` is mapped recursively to a tuple holding the
    signature of each subspace. Any other space yields one TensorSignature
    built from its shape and dtype.
    """
    if not isinstance(space, gym.spaces.Tuple):
        return data.TensorSignature(shape=space.shape, dtype=space.dtype)
    return tuple(map(signature, space))
Пример #12
0
"""Tests for alpacka.trainers.replay_buffers."""

import collections

import numpy as np
import pytest

from alpacka import data
from alpacka.trainers import replay_buffers

# Single-field transition type with its datapoint signature. The empty shape
# means each stored transition holds a single number.
_TestTransition = collections.namedtuple('_TestTransition', ['test_field'])
_test_datapoint_sig = _TestTransition(
    test_field=data.TensorSignature(shape=()))

# Two-field transition type; both fields are single numbers as well.
_TestPairTransition = collections.namedtuple('_TestPairTransition', ['a', 'b'])
_test_pair_datapoint_sig = _TestPairTransition(
    a=data.TensorSignature(shape=()),
    b=data.TensorSignature(shape=()))


def test_uniform_samples_added_transition():
    """Checks that a buffer holding one transition samples that transition."""
    replay_buffer = replay_buffers.UniformReplayBuffer(_test_datapoint_sig,
                                                       capacity=10)
    transitions = _TestTransition(np.array([123]))
    replay_buffer.add(transitions)

    sampled = replay_buffer.sample(batch_size=1)
    assert sampled == transitions

 def __init__(self, request_type):
     """Initializes the parent with single-element input and output
     signatures and remembers the given request type.
     """
     unit_sig = data.TensorSignature(shape=(1, ))
     signature = data.NetworkSignature(input=unit_sig, output=unit_sig)
     super().__init__(network_signature=signature)
     self._request_type = request_type
Пример #14
0
 def params_signature(action_space):
     """Returns a vector signature of length
     ``space_utils.max_size(action_space)``.
     """
     n_params = space_utils.max_size(action_space)
     return data.TensorSignature(shape=(n_params, ))
Пример #15
0
def keras_mlp():
    """Returns a KerasNetwork with a 13-element input and a single-element
    output signature.
    """
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=(13, )),
        output=data.TensorSignature(shape=(1, )),
    )
    return keras_networks.KerasNetwork(network_signature=signature)
Пример #16
0
def signature(space):
    """Returns a TensorSignature of elements of the given space.

    The signature takes its shape and dtype from the space.
    """
    shape, dtype = space.shape, space.dtype
    return data.TensorSignature(shape=shape, dtype=dtype)
Пример #17
0
def signature(space):
    """Builds a TensorSignature with the shape and dtype of the given space."""
    space_shape = space.shape
    space_dtype = space.dtype
    return data.TensorSignature(shape=space_shape, dtype=space_dtype)
Пример #18
0
 def network_signature(observation_space, action_space):
     """Returns a signature mapping an observation to a single-element
     tensor.

     The action space is not used.
     """
     del action_space
     observation_sig = space_utils.signature(observation_space)
     output_sig = data.TensorSignature(shape=(1, ))
     return data.NetworkSignature(input=observation_sig, output=output_sig)
Пример #19
0
 def network_signature(self, observation_space, action_space):
     """Returns a signature mapping an observation to a pair of outputs.

     The pair holds a single-element tensor and the signature given by
     ``self.distribution.params_signature(action_space)``.
     """
     observation_sig = space_utils.signature(observation_space)
     single_sig = data.TensorSignature(shape=(1, ))
     params_sig = self.distribution.params_signature(action_space)
     return data.NetworkSignature(
         input=observation_sig,
         output=(single_sig, params_sig),
     )
Пример #20
0
 def network_signature(observation_space, action_space):
     """Returns a signature mapping an observation to a vector of length
     ``action_space.n``.
     """
     observation_sig = space_utils.signature(observation_space)
     per_action_sig = data.TensorSignature(shape=(action_space.n, ))
     return data.NetworkSignature(input=observation_sig,
                                  output=per_action_sig)