コード例 #1
0
def test_gray_scale_observation(env_id, keep_dim):
    """Check that ``GrayScaleObservation`` matches ALE's native grayscale.

    Two 84x84 ``AtariPreprocessing`` environments are built from ``env_id``.
    One asks for grayscale observations directly. The other produces RGB
    observations that are converted by ``GrayScaleObservation``. Both are
    seeded identically, and their first observations must agree to within
    one gray level.

    Args:
        env_id: Gym id of an Atari environment.
        keep_dim: Forwarded to ``GrayScaleObservation``. When true, the wrapped
            observation keeps a trailing channel axis of size 1.
    """
    from contextlib import closing

    gray_env = AtariPreprocessing(gym.make(env_id),
                                  screen_size=84,
                                  grayscale_obs=True)
    rgb_env = AtariPreprocessing(gym.make(env_id),
                                 screen_size=84,
                                 grayscale_obs=False)
    wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)

    # Release the emulators even when an assertion fails. Gym wrappers forward
    # close() to the env they wrap, so closing wrapped_env also closes rgb_env.
    with closing(gray_env), closing(wrapped_env):
        assert rgb_env.observation_space.shape[-1] == 3

        seed = 0
        gray_env.seed(seed)
        wrapped_env.seed(seed)

        gray_obs = gray_env.reset()
        wrapped_obs = wrapped_env.reset()

        if keep_dim:
            assert wrapped_env.observation_space.shape[-1] == 1
            assert len(wrapped_obs.shape) == 3
            # Drop the singleton channel so the shape matches ALE's 2-D frame.
            wrapped_obs = wrapped_obs.squeeze(-1)
        else:
            assert len(wrapped_env.observation_space.shape) == 2
            assert len(wrapped_obs.shape) == 2

        # ALE gray scale is slightly different, but no more than by one shade.
        # Cast to int32 so that uint8 differences cannot wrap around.
        assert np.allclose(gray_obs.astype("int32"),
                           wrapped_obs.astype("int32"),
                           atol=1)
コード例 #2
0
def test_atari_preprocessing_grayscale(env_fn):
    """Compare ``AtariPreprocessing`` resizing and grayscale output against OpenCV.

    Three environments are built from ``env_fn``:

    * ``env1``: raw frames, expected to be 210x160x3.
    * ``env2``: preprocessed to 84x84 grayscale.
    * ``env3``: preprocessed to 84x84 RGB.

    All three use ``frame_skip=1`` and ``noop_max=0`` so that, after identical
    seeding, every ``reset`` returns the same underlying frame.

    Args:
        env_fn: Zero-argument factory returning a fresh, unwrapped Atari env.
    """
    import cv2
    from contextlib import closing

    env1 = env_fn()
    env2 = AtariPreprocessing(env_fn(),
                              screen_size=84,
                              grayscale_obs=True,
                              frame_skip=1,
                              noop_max=0)
    env3 = AtariPreprocessing(env_fn(),
                              screen_size=84,
                              grayscale_obs=False,
                              frame_skip=1,
                              noop_max=0)

    # Close every env even if an assertion below fails. Previously they were
    # only closed when all assertions passed.
    with closing(env1), closing(env2), closing(env3):
        env1.seed(0)
        env2.seed(0)
        env3.seed(0)
        obs1 = env1.reset()
        obs2 = env2.reset()
        obs3 = env3.reset()
        assert obs1.shape == (210, 160, 3)
        assert obs2.shape == (84, 84)
        assert obs3.shape == (84, 84, 3)
        # The RGB output must be exactly an INTER_AREA resize of the raw frame.
        assert np.allclose(
            obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))
        obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)
        # The edges of the numbers do not render quite the same in grayscale,
        # so those rows are ignored.
        assert np.allclose(obs2[10:38], obs3_gray[10:38])
        # The paddle also does not render quite the same.
        assert np.allclose(obs2[44:], obs3_gray[44:])
コード例 #3
0
def _make_env(args, seed):
    """Build and seed one environment for ``args.domain_type``.

    Used for both the training and the evaluation env, so the two are
    constructed identically.

    Args:
        args: Parsed options. Reads ``domain_type`` and ``env_name``. For
            'atari' it also reads ``frame_skip``, ``image_size`` and
            ``frame_stack``.
        seed: Seed applied to the env and, for gym/atari, its action space.

    Returns:
        The seeded environment.

    Raises:
        ValueError: If ``args.domain_type`` is not 'gym', 'dmc' or 'atari'.
    """
    if args.domain_type == 'gym':
        # openai gym
        env = gym.make(args.env_name)
        env.seed(seed)
        env.action_space.seed(seed)
        return env

    if args.domain_type == 'dmc':
        # deepmind control suite; env_name has the form "<domain>/<task>"
        names = args.env_name.split('/')
        return dmc2gym.make(domain_name=names[0], task_name=names[1], seed=seed)

    if args.domain_type == 'atari':
        # openai gym Atari with preprocessing and frame stacking
        env = gym.make(args.env_name)
        env = AtariPreprocessing(env, frame_skip=args.frame_skip, screen_size=args.image_size, grayscale_newaxis=True)
        env = FrameStack(env, args.frame_stack)
        # Set on the outermost wrapper for both train and test envs (the
        # original set it below FrameStack for the test env only).
        # NOTE(review): presumably read by Basic_trainer as the episode cap.
        env._max_episode_steps = 10000
        env.seed(seed)
        env.action_space.seed(seed)
        return env

    raise ValueError("Unknown domain_type: {!r} (expected 'gym', 'dmc' or 'atari')".format(args.domain_type))


def main(args):
    """Seed everything, build train/test envs, pick a DQN variant and train.

    Args:
        args: Parsed command-line options. Reads ``cpu_only``,
            ``random_seed``, ``domain_type`` and ``env_name`` here, plus the
            env-specific fields read by ``_make_env``. The whole object is
            also forwarded to the algorithm and ``Basic_trainer``.

    Raises:
        ValueError: If ``args.domain_type`` is not 'gym', 'dmc' or 'atari'.
    """
    if args.cpu_only:
        # set_visible_devices(cpu, device_type='CPU') only configures CPU
        # devices and leaves GPUs visible. Hide the GPUs explicitly instead.
        # TF requires this before the GPUs are initialized.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # random seed setting: a non-positive seed means "draw one at random"
    if args.random_seed <= 0:
        random_seed = np.random.randint(1, 9999)
    else:
        random_seed = args.random_seed

    tf.random.set_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    # env setting: train and evaluation envs share the same seed
    env = _make_env(args, random_seed)
    test_env = _make_env(args, random_seed)

    # NOTE(review): `.n` exists only on discrete action spaces; confirm the
    # 'dmc' path actually yields one.
    action_dim = env.action_space.n
    # NOTE(review): min_action == max_action == 1 looks suspicious; kept as-is
    # because Basic_trainer's use of these values is not visible here.
    max_action = 1
    min_action = 1

    # Compare with `==`: `is` on strings tests identity and can be False
    # for equal strings (e.g. values parsed from the command line).
    if args.domain_type == 'atari':
        # Image observations: pass the full (stacked) observation shape.
        state_dim = env.observation_space.shape
        algorithm = ImageDQN(state_dim, action_dim, args)
    else:
        # 'gym' / 'dmc': presumably 1-D observations; use their length.
        state_dim = env.observation_space.shape[0]
        algorithm = DQN(state_dim, action_dim, args)

    print("Training of", env.unwrapped.spec.id)
    print("Algorithm:", algorithm.name)
    print("State dim:", state_dim)
    print("Action dim:", action_dim)

    trainer = Basic_trainer(env, test_env, algorithm, max_action, min_action, args)
    trainer.run()