def test_integration_with_keras():
    """Smoke-test one training epoch of SupervisedTrainer on a KerasNetwork.

    Nothing is asserted: the test passes as long as building the trainer,
    adding a single all-zeros episode and running ``train_epoch`` raise no
    exception.
    """
    transition_type = collections.namedtuple('TestTransition', ['observation'])

    obs_shape = (4, )
    n_transitions = 10
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=obs_shape),
        output=data.TensorSignature(shape=(1, )),
    )

    # The buffer capacity equals the episode length, so the single episode
    # added below fits into it.
    trainer = supervised.SupervisedTrainer(
        network_signature=signature,
        target=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=n_transitions,
    )

    observations = np.zeros((n_transitions, ) + obs_shape)
    episode = data.Episode(
        transition_batch=transition_type(observation=observations),
        return_=123,
        solved=False,
    )
    trainer.add_episode(episode)

    network = keras.KerasNetwork(network_signature=signature)
    trainer.train_epoch(network)
def test_model_valid(model_fn, input_shape, output_shape):
    """Check that a KerasNetwork built from ``model_fn`` yields the expected shape.

    A zero batch of shape ``(7,) + input_shape`` is fed to ``predict``; the
    result must have shape ``(7,) + output_shape``.
    """
    batch_size = 7
    network = keras_networks.KerasNetwork(
        model_fn=model_fn,
        input_shape=input_shape,
    )

    batch = np.zeros((batch_size, ) + input_shape)
    predictions = network.predict(batch)

    expected_shape = (batch_size, ) + output_shape
    assert predictions.shape == expected_shape
def test_model_valid(model_fn, input_shape, output_shape):
    """Check that a KerasNetwork built from ``model_fn`` honours its signature.

    The network is declared with input shape ``input_shape`` and output shape
    ``output_shape``. Predicting on a zero batch of 7 inputs must return an
    array of shape ``(7,) + output_shape``.
    """
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=input_shape),
        output=data.TensorSignature(shape=output_shape),
    )
    network = keras_networks.KerasNetwork(
        model_fn=model_fn,
        network_signature=signature,
    )

    batch_size = 7
    batch = np.zeros((batch_size, ) + input_shape)
    predictions = network.predict(batch)

    expected_shape = (batch_size, ) + output_shape
    assert predictions.shape == expected_shape
def test_integration_with_keras():
    """Smoke-test one training epoch of SupervisedTrainer on a KerasNetwork.

    No assertions are made; the test only requires that setup, adding one
    all-zeros episode and ``train_epoch`` complete without raising.
    """
    n_transitions = 10
    obs_shape = (4, )

    # The buffer capacity equals the episode length, so the single episode
    # added below fits into it.
    trainer = supervised.SupervisedTrainer(
        input_shape=obs_shape,
        target_fn=supervised.target_solved,
        batch_size=2,
        n_steps_per_epoch=3,
        replay_buffer_capacity=n_transitions,
    )

    # _TestTransition is defined elsewhere in the file (not visible here).
    transitions = _TestTransition(
        observation=np.zeros((n_transitions, ) + obs_shape),
    )
    trainer.add_episode(
        data.Episode(transition_batch=transitions, return_=123, solved=False)
    )

    network = keras.KerasNetwork(input_shape=obs_shape)
    trainer.train_epoch(network)
def keras_mlp():
    """Return a KerasNetwork constructed with only ``input_shape=(13,)``."""
    shape = (13, )
    return keras_networks.KerasNetwork(input_shape=shape)
def keras_mlp():
    """Return a KerasNetwork whose signature maps (13,) inputs to (1,) outputs."""
    signature = data.NetworkSignature(
        input=data.TensorSignature(shape=(13, )),
        output=data.TensorSignature(shape=(1, )),
    )
    return keras_networks.KerasNetwork(network_signature=signature)