Ejemplo n.º 1
0
def test_integration_with_keras():
    """Smoke-test a single SupervisedTrainer epoch with a KerasNetwork.

    Nothing is asserted: the test passes as long as building the trainer,
    adding one episode and training for one epoch raise no exception.
    """
    transition_cls = collections.namedtuple('TestTransition', ['observation'])

    episode_length = 10
    observation_shape = (4, )
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=observation_shape),
        output=data.TensorSignature(shape=(1, )),
    )
    trainer = supervised.SupervisedTrainer(
        network_signature=signature,
        target=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=episode_length,
    )

    # One unsolved episode of all-zero observations; its leading dimension
    # equals the replay buffer capacity passed to the trainer above.
    observations = np.zeros((episode_length, ) + observation_shape)
    episode = data.Episode(
        transition_batch=transition_cls(observation=observations),
        return_=123,
        solved=False,
    )
    trainer.add_episode(episode)

    trainer.train_epoch(keras.KerasNetwork(network_signature=signature))
Ejemplo n.º 2
0
def test_model_valid(model_fn, input_shape, output_shape):
    """Check that a network built from `model_fn` predicts the expected shape.

    A KerasNetwork is constructed from `model_fn` and `input_shape`, fed a
    batch of zeros, and its prediction must have shape
    `(batch, ) + output_shape`.
    """
    net = keras_networks.KerasNetwork(
        model_fn=model_fn,
        input_shape=input_shape,
    )
    # Leading batch dimension shared by the input and the expected output.
    batch_dim = (7, )
    prediction = net.predict(np.zeros(batch_dim + input_shape))
    assert prediction.shape == batch_dim + output_shape
Ejemplo n.º 3
0
def test_model_valid(model_fn, input_shape, output_shape):
    """Check that a network built from `model_fn` predicts the expected shape.

    The KerasNetwork is constructed from `model_fn` and a NetworkSignature
    describing `input_shape` and `output_shape`; a batch of zeros is fed to
    it and the prediction must have shape `(batch, ) + output_shape`.
    """
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=input_shape),
        output=data.TensorSignature(shape=output_shape),
    )
    net = keras_networks.KerasNetwork(
        model_fn=model_fn,
        network_signature=signature,
    )
    # Leading batch dimension shared by the input and the expected output.
    batch_dim = (7, )
    prediction = net.predict(np.zeros(batch_dim + input_shape))
    assert prediction.shape == batch_dim + output_shape
Ejemplo n.º 4
0
def test_integration_with_keras():
    """Smoke-test a single SupervisedTrainer epoch with a KerasNetwork.

    Nothing is asserted: the test passes as long as building the trainer,
    adding one episode and training for one epoch raise no exception.
    """
    episode_length = 10
    observation_shape = (4, )
    trainer = supervised.SupervisedTrainer(
        input_shape=observation_shape,
        target_fn=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=episode_length,
    )

    # One unsolved episode of all-zero observations; its leading dimension
    # equals the replay buffer capacity passed to the trainer above.
    observations = np.zeros((episode_length, ) + observation_shape)
    episode = data.Episode(
        transition_batch=_TestTransition(observation=observations),
        return_=123,
        solved=False,
    )
    trainer.add_episode(episode)

    trainer.train_epoch(keras.KerasNetwork(input_shape=observation_shape))
Ejemplo n.º 5
0
def keras_mlp():
    """Return a KerasNetwork constructed with an input shape of (13, )."""
    input_shape = (13, )
    return keras_networks.KerasNetwork(input_shape=input_shape)
Ejemplo n.º 6
0
def keras_mlp():
    """Return a KerasNetwork built from a (13, ) -> (1, ) network signature."""
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=(13, )),
        output=data.TensorSignature(shape=(1, )),
    )
    return keras_networks.KerasNetwork(network_signature=signature)