import pickle

import torch
from torch import nn

from garage.envs import GymEnv
from garage.torch.q_functions import DiscreteCNNQFunction

# DummyBoxEnv ships with garage's test fixtures.
from tests.fixtures.envs.dummy import DummyBoxEnv


def test_forward(batch_size, hidden_channels, kernel_sizes, strides):
    env_spec = GymEnv(DummyBoxEnv(obs_dim=(3, 10, 10))).spec
    obs_dim = env_spec.observation_space.shape
    obs = torch.zeros((batch_size, ) + obs_dim, dtype=torch.float32)

    qf = DiscreteCNNQFunction(env_spec=env_spec,
                              kernel_sizes=kernel_sizes,
                              strides=strides,
                              mlp_hidden_nonlinearity=None,
                              cnn_hidden_nonlinearity=None,
                              hidden_channels=hidden_channels,
                              hidden_sizes=hidden_channels,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_,
                              is_image=False)

    # All-zero observations through all-ones weights (and default zero biases)
    # must produce all-zero Q-values.
    output = qf(obs)
    expected_output = torch.zeros(output.shape)

    assert output.shape == (batch_size, env_spec.action_space.flat_dim)
    assert torch.eq(output, expected_output).all()


def test_is_pickleable(batch_size, hidden_channels, kernel_sizes, strides):
    env_spec = GymEnv(DummyBoxEnv(obs_dim=(3, 10, 10))).spec
    obs_dim = env_spec.observation_space.shape
    obs = torch.ones((batch_size, ) + obs_dim, dtype=torch.float32)

    qf = DiscreteCNNQFunction(env_spec=env_spec,
                              kernel_sizes=kernel_sizes,
                              strides=strides,
                              mlp_hidden_nonlinearity=None,
                              cnn_hidden_nonlinearity=None,
                              hidden_channels=hidden_channels,
                              hidden_sizes=hidden_channels,
                              hidden_w_init=nn.init.ones_,
                              output_w_init=nn.init.ones_,
                              is_image=False)

    # A pickle round-trip should preserve the network, so both copies must
    # produce identical outputs for the same observations.
    output1 = qf(obs)
    p = pickle.dumps(qf)
    qf_pickled = pickle.loads(p)
    output2 = qf_pickled(obs)

    assert torch.eq(output1, output2).all()
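

# The tests above are excerpted without their pytest parametrize decorators,
# which supply batch_size, hidden_channels, kernel_sizes and strides. A
# minimal sketch of such a parametrized check follows; the shapes below are
# illustrative, not the grid used upstream.
import pytest


@pytest.mark.parametrize(
    'batch_size, hidden_channels, kernel_sizes, strides',
    [
        (32, (32, ), (1, ), (1, )),      # one 1x1 conv layer
        (64, (32, 64), (3, 3), (2, 2)),  # two 3x3 conv layers, stride 2
    ])
def test_output_shape(batch_size, hidden_channels, kernel_sizes, strides):
    env_spec = GymEnv(DummyBoxEnv(obs_dim=(3, 10, 10))).spec
    obs = torch.ones((batch_size, ) + env_spec.observation_space.shape)
    qf = DiscreteCNNQFunction(env_spec=env_spec,
                              kernel_sizes=kernel_sizes,
                              strides=strides,
                              hidden_channels=hidden_channels,
                              hidden_sizes=(32, ),
                              is_image=False)
    # One Q-value per discrete action, for every observation in the batch.
    assert qf(obs).shape == (batch_size, env_spec.action_space.flat_dim)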
Example #3
import gym
import numpy as np
import psutil
import torch

from garage.envs import GymEnv
from garage.envs.wrappers import (ClipReward, EpisodicLife, FireReset,
                                  Grayscale, MaxAndSkip, Noop, Resize,
                                  StackFrames)
from garage.experiment.deterministic import set_seed
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import DQN
from garage.torch.policies import DiscreteQFArgmaxPolicy
from garage.torch.q_functions import DiscreteCNNQFunction
from garage.trainer import Trainer


def dqn_atari(ctxt=None,
              env=None,
              seed=24,
              n_workers=psutil.cpu_count(logical=False),
              max_episode_length=None,
              **kwargs):
    """Train DQN with PongNoFrameskip-v4 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        env (str): Name of the Atari environment, e.g. 'PongNoFrameskip-v4'.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_workers (int): Number of workers to use. Defaults to the number of
            CPU cores available.
        max_episode_length (int): Max length of an episode. If None, defaults
            to the time limit specific to the environment. Used by
            integration tests.
        kwargs (dict): hyperparameters to be saved to variant.json.

    """
    assert n_workers > 0
    assert env is not None
    # Standard DeepMind-style Atari preprocessing pipeline.
    env = gym.make(env)
    env = Noop(env, noop_max=30)
    env = MaxAndSkip(env, skip=4)
    env = EpisodicLife(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireReset(env)
    env = Grayscale(env)
    env = Resize(env, 84, 84)
    env = ClipReward(env)
    env = StackFrames(env, 4, axis=0)
    env = GymEnv(env, max_episode_length=max_episode_length, is_image=True)
    set_seed(seed)
    trainer = Trainer(ctxt)

    # `hyperparams` is a module-level dict of experiment settings; an
    # illustrative version is sketched after this function.
    n_epochs = hyperparams['n_epochs']
    steps_per_epoch = hyperparams['steps_per_epoch']
    sampler_batch_size = hyperparams['sampler_batch_size']
    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
    replay_buffer = PathBuffer(
        capacity_in_transitions=hyperparams['buffer_size'])

    qf = DiscreteCNNQFunction(
        env_spec=env.spec,
        image_format='NCHW',
        hidden_channels=hyperparams['hidden_channels'],
        kernel_sizes=hyperparams['kernel_sizes'],
        strides=hyperparams['strides'],
        hidden_w_init=(
            lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
        hidden_sizes=hyperparams['hidden_sizes'])

    policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
    exploration_policy = EpsilonGreedyPolicy(
        env_spec=env.spec,
        policy=policy,
        total_timesteps=num_timesteps,
        max_epsilon=hyperparams['max_epsilon'],
        min_epsilon=hyperparams['min_epsilon'],
        decay_ratio=hyperparams['decay_ratio'])

    sampler = LocalSampler(agents=exploration_policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker,
                           n_workers=n_workers)

    algo = DQN(env_spec=env.spec,
               policy=policy,
               qf=qf,
               exploration_policy=exploration_policy,
               replay_buffer=replay_buffer,
               sampler=sampler,
               steps_per_epoch=steps_per_epoch,
               qf_lr=hyperparams['lr'],
               clip_gradient=hyperparams['clip_gradient'],
               discount=hyperparams['discount'],
               min_buffer_size=hyperparams['min_buffer_size'],
               n_train_steps=hyperparams['n_train_steps'],
               target_update_freq=hyperparams['target_update_freq'],
               buffer_batch_size=hyperparams['buffer_batch_size'])

    # Default to CPU with a single torch thread; switch to GPU if available.
    set_gpu_mode(False)
    torch.set_num_threads(1)
    if torch.cuda.is_available():
        set_gpu_mode(True)
        algo.to()

    trainer.setup(algo, env)

    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
    env.close()
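

# An illustrative module-level `hyperparams` dict and launch glue for the
# function above. The values here are placeholders rather than tuned
# settings, and the launch assumes garage's `wrap_experiment` decorator,
# which injects the ExperimentContext as `ctxt`.
from garage import wrap_experiment

hyperparams = dict(n_epochs=100,
                   steps_per_epoch=20,
                   sampler_batch_size=500,
                   lr=1e-4,
                   discount=0.99,
                   max_epsilon=1.0,
                   min_epsilon=0.01,
                   decay_ratio=0.1,
                   buffer_size=int(1e4),
                   min_buffer_size=int(1e3),
                   n_train_steps=125,
                   target_update_freq=2,
                   buffer_batch_size=32,
                   clip_gradient=10,
                   hidden_sizes=(512, ),
                   hidden_channels=(32, 64, 64),
                   kernel_sizes=(8, 4, 3),
                   strides=(4, 2, 1))

if __name__ == '__main__':
    # Extra keyword arguments land in **kwargs and are saved to variant.json.
    wrap_experiment(dqn_atari)(env='PongNoFrameskip-v4',
                               seed=24,
                               **hyperparams)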