コード例 #1
0
def test_CNN(preprocessed_images):
    """Check that ``CNN`` maps a batch of inputs to one Q-value per action.

    ``preprocessed_images`` is presumably a pytest fixture defined elsewhere.
    This test treats its first dimension as the batch size and its last
    dimension as the agent history length (frame stack size). That layout is
    an assumption; TODO confirm against the fixture.
    """
    batch_size = preprocessed_images.shape[0]
    history_length = preprocessed_images.shape[-1]
    nb_actions = 10  # arbitrary action count; the test only checks the output width
    cnn = CNN(history_length, nb_actions)
    q_values = cnn(preprocessed_images)
    # isinstance is the idiomatic type check and also accepts Tensor subclasses.
    assert isinstance(q_values, torch.Tensor)
    # torch.Size is a tuple subclass, so it compares directly against a tuple.
    assert tuple(q_values.shape) == (batch_size, nb_actions)
コード例 #2
0
ファイル: conftest.py プロジェクト: Lyp02/Deep_Q_learning
def Q():
    '''
    Build a CNN Q-network and place it on the best available device.

    Network dimensions come from ``pytest.agent_history_length`` and
    ``pytest.nb_actions``. These attributes must have been attached to the
    ``pytest`` module elsewhere (presumably during conftest setup; confirm).
    The network goes to CUDA when ``torch.cuda.is_available()`` is true,
    and to the CPU otherwise.
    '''
    history_len = pytest.agent_history_length
    action_count = pytest.nb_actions
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = CNN(history_len, action_count)
    return network.to(target_device)
コード例 #3
0
def Q():
    '''
    Build a CNN Q-network sized for the configured Gym environment.

    The action count is read from the action space of
    ``gym.make(pytest.env_name)``. That environment is closed as soon as the
    count is known, so it does not hold resources for the rest of the
    session. ``pytest.env_name`` and ``pytest.agent_history_length`` must
    have been attached to the ``pytest`` module elsewhere (presumably during
    conftest setup; confirm). The network is placed on CUDA when available,
    otherwise on the CPU.
    '''
    agent_history_length = pytest.agent_history_length
    # The env is only needed to query its action count. Close it right away
    # instead of leaving it alive until garbage collection.
    env = gym.make(pytest.env_name)
    try:
        nb_actions = env.action_space.n
    finally:
        env.close()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return CNN(agent_history_length, nb_actions).to(device)
コード例 #4
0
def main(env_name):
    """Train a deep-Q agent on the Gym environment ``env_name`` from images.

    Steps:
    1. Wrap the environment in ``SkipFrames``, passing the history length
       minus one and the ``preprocess`` function. Presumably these are the
       number of frames to skip and a per-frame preprocessor; confirm
       against ``SkipFrames``.
    2. Build a ``CNN`` Q-network sized by the history length and the
       wrapped env's action count.
    3. Hand both to ``train_deepq`` with ``input_as_images=True``.
    """
    history_len = 4
    wrapped_env = SkipFrames(gym.make(env_name), history_len - 1, preprocess)
    q_net = CNN(history_len, wrapped_env.action_space.n)
    train_deepq(
        env=wrapped_env,
        env_name=env_name,
        Q_network=q_net,
        input_as_images=True,
    )
コード例 #5
0
def test_get_training_data(replay_memory):
    """Check the types and batch sizes of the minibatch from ``get_training_data``.

    ``replay_memory`` is presumably a fixture yielding a 3-tuple
    ``(nb_actions, nb_timesteps, memory)``. The unpacking below requires
    exactly three items in that order.
    """
    nb_actions, nb_timesteps, memory = replay_memory
    Q_hat = CNN(agent_history_length=nb_timesteps, nb_actions=nb_actions)
    batch_size = 32
    # Passed positionally as the 4th argument; presumably the discount
    # factor. Confirm against get_training_data's signature.
    gamma = 0.99
    phi_t_training, actions_training, y = get_training_data(
        Q_hat, memory, batch_size, gamma)

    assert isinstance(phi_t_training, torch.Tensor)
    assert phi_t_training.shape[0] == batch_size

    assert isinstance(actions_training, list)
    assert len(actions_training) == batch_size
    for action in actions_training:
        # Each sampled action must be a valid index into the action space.
        assert 0 <= action < nb_actions

    assert isinstance(y, torch.Tensor)
    assert y.shape[0] == batch_size
コード例 #6
0
import gym
from deepq.wrapper_gym import KFrames
from deepq.deepq import train_deepq
from deepq.neural_nets import CNN
from deepq.utils import preprocess

# Script: train a deep-Q agent on Pong. KFrames wraps the env with
# AGENT_HISTORY_LENGTH; presumably it stacks that many frames (confirm
# against deepq.wrapper_gym).
AGENT_HISTORY_LENGTH = 4
ENV_NAME = "PongNoFrameskip-v4"

env = gym.make(ENV_NAME)
# Derive the action count from the environment instead of hard-coding 6.
# That way the Q-network's output width cannot drift from the real action
# space. Read it from the unwrapped env so it does not depend on how
# KFrames exposes ``action_space``.
NB_ACTIONS = int(env.action_space.n)
env = KFrames(env, AGENT_HISTORY_LENGTH)
Q_network = CNN(AGENT_HISTORY_LENGTH, NB_ACTIONS)

train_deepq(
    env=env,
    name='Pong',
    nb_actions=NB_ACTIONS,
    Q_network=Q_network,
    preprocess_fn=preprocess,
    tensorboard_freq=5,
    demo_tensorboard=True,
)