def environment(game):
    """Atari environment."""
    env = atari_lib.create_atari_environment(game_name=game,
                                             sticky_actions=True)
    env = AtariDopamineWrapper(env)
    env = wrappers.FrameStackingWrapper(env, num_frames=4)
    return wrappers.SinglePrecisionWrapper(env)
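A minimal usage sketch for the helper above (assumed: atari_lib comes from Dopamine's discrete_domains module, wrappers from Acme, and AtariDopamineWrapper is defined alongside this helper rather than in a library):

# Assumed imports for the snippet above.
from acme import wrappers
from dopamine.discrete_domains import atari_lib

env = environment('Pong')
timestep = env.reset()
# With 4-frame stacking, observations are expected to be 84x84x4.
print(timestep.observation.shape)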
Example #2
    def testCreateAtariEnvironment(self, mock_gym_make, mock_atari_lib):
        class MockGymEnv(object):
            def __init__(self, env_name):
                self.env = 'gym({})'.format(env_name)

        def fake_make_env(name):
            return MockGymEnv(name)

        mock_gym_make.side_effect = fake_make_env
        # pylint: disable=unnecessary-lambda
        mock_atari_lib.side_effect = lambda x: 'atari({})'.format(x)
        # pylint: enable=unnecessary-lambda
        game_name = 'Test'
        env = atari_lib.create_atari_environment(game_name)
        self.assertEqual('atari(gym(TestNoFrameskip-v0))', env)
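The excerpt omits the mock.patch decorators that supply the two mock arguments. A plausible reconstruction (decorators apply bottom-up, so the gym.make patch, listed last, binds to the first argument):

# Sketch of the assumed decorator stack and imports.
from unittest import mock
import gym

@mock.patch.object(atari_lib, 'AtariPreprocessing')
@mock.patch.object(gym, 'make')
def testCreateAtariEnvironment(self, mock_gym_make, mock_atari_lib):
    ...  # body as in the excerpt above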
Example #3
def testDefaultDQNConfig(self):
    """Verify the default DQN configuration."""
    run_experiment.load_gin_configs(
        ['dopamine/agents/dqn/configs/dqn.gin'], [])
    agent = run_experiment.create_agent(
        self.test_session(),
        atari_lib.create_atari_environment(game_name='Pong'))
    self.assertEqual(agent.gamma, 0.99)
    self.assertEqual(agent.update_horizon, 1)
    self.assertEqual(agent.min_replay_history, 20000)
    self.assertEqual(agent.update_period, 4)
    self.assertEqual(agent.target_update_period, 8000)
    self.assertEqual(agent.epsilon_train, 0.01)
    self.assertEqual(agent.epsilon_eval, 0.001)
    self.assertEqual(agent.epsilon_decay_period, 250000)
    self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
    self.assertEqual(agent._replay.memory._batch_size, 32)
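The second argument of load_gin_configs accepts gin bindings, so individual defaults can be overridden without editing the config file. A minimal sketch (binding targets assume the configurable names used in dqn.gin):

run_experiment.load_gin_configs(
    ['dopamine/agents/dqn/configs/dqn.gin'],
    ['DQNAgent.gamma = 0.95', 'DQNAgent.update_horizon = 3'])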
Example #4
def testDefaultRainbowConfig(self):
    """Verify the default Rainbow configuration."""
    run_experiment.load_gin_configs(
        ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
    agent = run_experiment.create_agent(
        self.test_session(),
        atari_lib.create_atari_environment(game_name='Pong'))
    self.assertEqual(agent._num_atoms, 51)
    support = self.evaluate(agent._support)
    self.assertEqual(min(support), -10.)
    self.assertEqual(max(support), 10.)
    self.assertEqual(len(support), 51)
    self.assertEqual(agent.gamma, 0.99)
    self.assertEqual(agent.update_horizon, 3)
    self.assertEqual(agent.min_replay_history, 20000)
    self.assertEqual(agent.update_period, 4)
    self.assertEqual(agent.target_update_period, 8000)
    self.assertEqual(agent.epsilon_train, 0.01)
    self.assertEqual(agent.epsilon_eval, 0.001)
    self.assertEqual(agent.epsilon_decay_period, 250000)
    self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
    self.assertEqual(agent._replay.memory._batch_size, 32)
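The support assertions follow from how the C51 support is built: 51 atoms spread evenly over [-10, 10], so adjacent atoms are 20 / 50 = 0.4 apart. The same values can be reproduced independently of the agent:

import numpy as np

support = np.linspace(-10., 10., 51)  # 51 atoms, spacing 0.4
assert min(support) == -10. and max(support) == 10. and len(support) == 51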
Example #5

def main(unused_argv):
    _ = unused_argv
    tf.disable_eager_execution()
    logging.set_verbosity(logging.INFO)
    gin_files = FLAGS.gin_files
    gin_bindings = FLAGS.gin_bindings
    gin.parse_config_files_and_bindings(gin_files,
                                        bindings=gin_bindings,
                                        skip_unknown=False)

    paths = list(pathlib.Path(FLAGS.checkpoint_dir).parts)
    run_number = paths[-1].split('_')[-1]
    save_dir = osp.join(pathlib.Path(*paths), 'coherence',
                        f'batch_size_{FLAGS.batch_size}')
    ckpt_dir = osp.join(FLAGS.checkpoint_dir, 'checkpoints')
    if gfile.Exists(save_dir):
        gfile.DeleteRecursively(save_dir)
    gfile.MakeDirs(save_dir)
    logging.info('Checkpoint directory: %s', ckpt_dir)
    logging.info('Save coherence computation in directory: %s', save_dir)

    logging.info('Game: %s', FLAGS.game)
    environment = atari_lib.create_atari_environment(game_name=FLAGS.game,
                                                     sticky_actions=True)

    agent = create_agent(environment)

    checkpoints = get_checkpoints(ckpt_dir)

    replay_dir = create_game_replay_dir(FLAGS.game, run_number)
    logging.info('Replay dir: %s', replay_dir)

    replay_batch_size = 256
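    # Draw FLAGS.batch_size states in num_batches chunks of replay_batch_size.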
    num_batches = max(FLAGS.batch_size // replay_batch_size, 1)
    if FLAGS.debug_mode:
        states = [np.random.rand(replay_batch_size, 84, 84, 4)] * num_batches
        next_states = [np.random.rand(replay_batch_size, 84, 84, 4)
                       ] * num_batches
    else:
        data_replay = fixed_replay_buffer.FixedReplayBuffer(
            data_dir=replay_dir,
            replay_suffix=None,  # Set to an index to load one specific buffer of the 50.
            observation_shape=NATURE_DQN_OBSERVATION_SHAPE,
            stack_size=NATURE_DQN_STACK_SIZE,
            update_horizon=1,
            replay_capacity=FLAGS.replay_capacity,
            batch_size=replay_batch_size,
            gamma=FLAGS.gamma,
            observation_dtype=NATURE_DQN_DTYPE.as_numpy_dtype)
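        # Load FLAGS.num_buffers of the logged replay buffers into memory.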
        data_replay.reload_buffer(FLAGS.num_buffers)

        states = []
        next_states = []
        for _ in range(num_batches):
            transitions = data_replay.sample_transition_batch()
            states.append(transitions[0])
            next_states.append(transitions[3])

    checkpoint_every = FLAGS.checkpoint_every
    max_checkpoints = int(len(checkpoints) // checkpoint_every)

    coherences = []
    all_norms = []
    feature_matrices = []
    for mdx in range(max_checkpoints + 1):
        checkpoint_num = mdx * checkpoint_every - 1
        logging.info('Checkpoint %d', checkpoint_num)
        # Checkpoint -1 corresponds to a random agent.
        if checkpoint_num >= 0:
            reload_checkpoint(agent, checkpoints[checkpoint_num])

        if FLAGS.residual_td:
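            # Residual TD features: phi(s) - gamma * phi(s').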
            feature_matrix = get_features(agent, states)
            feature_matrix -= FLAGS.gamma * get_features(agent, next_states)
        else:
            feature_matrix = get_features(agent, states)
        if mdx <= 5:
            feature_matrices.append(feature_matrix)
        try:
            coherence, norms = calculate_coherence(feature_matrix)
            logging.info('Coherence: %0.2f', coherence)
            all_norms.append(norms)
            coherences.append(coherence)
        except Exception as e:  # pylint:disable=broad-except
            logging.info('Exception %s for checkpoint %d', e, checkpoint_num)
            continue

    prefix = 'residual_' if FLAGS.residual_td else ''

    logging.info('Number of checkpoints: %d', len(checkpoints))
    with gfile.Open(osp.join(save_dir, f'{prefix}coherence.npy'), 'wb') as f:
        np.save(f, coherences, allow_pickle=True)

    with gfile.Open(osp.join(save_dir, f'{prefix}norms.npy'), 'wb') as f:
        np.save(f, all_norms, allow_pickle=True)

    with gfile.Open(osp.join(save_dir, f'{prefix}features.npy'), 'wb') as f:
        np.save(f, feature_matrices, allow_pickle=True)
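The script depends on helpers defined elsewhere (create_agent, get_checkpoints, get_features, reload_checkpoint, create_game_replay_dir, calculate_coherence). For orientation, a sketch of calculate_coherence under the common definition of coherence as the largest absolute cosine similarity between distinct feature columns, assuming a 2-D array of shape (num_states, num_features); the actual helper may differ:

import numpy as np

def calculate_coherence(feature_matrix):
    """Sketch: max |cosine similarity| between distinct feature columns."""
    features = np.asarray(feature_matrix)
    norms = np.linalg.norm(features, axis=0)
    normalized = features / np.maximum(norms, 1e-12)
    gram = normalized.T @ normalized
    np.fill_diagonal(gram, 0.)  # ignore self-similarity
    return np.abs(gram).max(), norms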
Example #6
def testCreateAtariEnvironmentWithoutGameName(self):
  with self.assertRaises(AssertionError):
    atari_lib.create_atari_environment()
Example #7
def create_atari_env_fn():
  """Creates the appropriate Atari environment."""
  return atari_lib.create_atari_environment(FLAGS.game)
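The factory reads FLAGS.game; a minimal flag definition it presupposes (the default value here is an assumption):

from absl import flags

flags.DEFINE_string('game', 'Pong', 'Name of the Atari game to create.')
FLAGS = flags.FLAGS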