Example #1
def EnvFactory(env_name):
    parts = env_name.split(':')
    if len(parts) > 2:
        raise ValueError('Incorrect environment name %s' % env_name)

    if parts[0] == '2048':
        env = game2048.Game2048()
    else:
        env = gym.make(parts[0])

    if len(parts) == 2:
        for letter in parts[1]:
            if letter == 'L':
                env = atari_wrappers.EpisodicLifeEnv(env)
            elif letter == 'N':
                env = atari_wrappers.NoopResetEnv(env, noop_max=30)
            elif letter == 'S':
                env = atari_wrappers.MaxAndSkipEnv(env, skip=4)
            elif letter == 'X':
                env = atari_wrappers.StackAndSkipEnv(env, skip=3)
            elif letter == 'F':
                env = atari_wrappers.FireResetEnv(env)
            elif letter == 'C':
                env = atari_wrappers.ClippedRewardsWrapper(env)
            elif letter == 'P':
                env = atari_wrappers.ProcessFrame84(env)
            else:
                raise ValueError('Unexpected wrapper code %s' % letter)
    return env
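
A minimal usage sketch, assuming the gym, game2048 and atari_wrappers modules imported by the surrounding file are available; the environment id and the wrapper-code string are illustrative:

# Illustrative call: Breakout with noop resets (N), frame skipping (S),
# episodic lives (L), fire-on-reset (F), reward clipping (C) and 84x84 frames (P).
env = EnvFactory('BreakoutNoFrameskip-v4:NSLFCP')
obs = env.reset()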
Example #2
def PrimaryAtariWrap(env, clip_rewards=True):
    assert 'NoFrameskip' in env.spec.id

    # This wrapper holds the same action for <skip> frames and outputs
    # the pixel-wise maximum of the last 2 frames (to handle blinking
    # in some envs)
    env = atari_wrappers.MaxAndSkipEnv(env, skip=4)

    # This wrapper sends done=True when each life is lost
    # (not only after all the lives given by the game rules are gone).
    # It should make it easier for the agent to understand that losing is bad.
    env = atari_wrappers.EpisodicLifeEnv(env)

    # This wrapper launches the ball when an episode starts.
    # Without it the agent has to learn this action, too.
    # Actually it could, but learning would take longer.
    env = atari_wrappers.FireResetEnv(env)

    # This wrapper transforms rewards to {-1, 0, 1} according to their sign
    if clip_rewards:
        env = atari_wrappers.ClipRewardEnv(env)

    # This wrapper is yours :)
    env = PreprocessAtariObs(env)
    return env
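
A quick usage sketch using the classic 4-tuple gym step API; any NoFrameskip Atari id satisfies the assert, Breakout is used here purely for illustration:

import gym

env = PrimaryAtariWrap(gym.make('BreakoutNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())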
Example #3
def PrimaryAtariWrap(env,
                     clip_rewards=True,
                     frame_skip=True,
                     fire_reset_event=False,
                     episodic_life=False,
                     width=44,
                     height=44,
                     margins=[1, 1, 1, 1],
                     n_frames=4,
                     reward_scale=0):

    # This wrapper holds the same action for <skip> frames and outputs
    # the pixel-wise maximum of the last 2 frames (to handle blinking
    # in some envs)
    if frame_skip:
        env = atari_wrappers.MaxAndSkipEnv(env, skip=4)

    # This wrapper sends done=True when each life is lost
    # (not only after all the lives given by the game rules are gone).
    # It should make it easier for the agent to understand that losing is bad.
    if episodic_life:
        env = atari_wrappers.EpisodicLifeEnv(env)

    # This wrapper launches the ball when an episode starts.
    # Without it the agent has to learn this action, too.
    # Actually it could, but learning would take longer.
    if fire_reset_event:
        env = atari_wrappers.FireResetEnv(env)

    # This wrapper transforms rewards to {-1, 0, 1} according to their sign
    if clip_rewards:
        env = atari_wrappers.ClipRewardEnv(env)
    if reward_scale != 0:
        env = atari_wrappers.RewardScale(env, reward_scale)

    # This wrapper is yours :)
    env = atari_wrappers.PreprocessAtariObs(env,
                                            height=height,
                                            width=width,
                                            margins=margins)

    env = atari_wrappers.FrameBuffer(env,
                                     n_frames=n_frames,
                                     dim_order='pytorch')

    return env
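
A usage sketch with the optional wrappers switched on; the environment id and the expected observation shape are illustrative assumptions, since the exact shape depends on how PreprocessAtariObs crops and converts frames:

import gym

env = PrimaryAtariWrap(gym.make('BreakoutNoFrameskip-v4'),
                       clip_rewards=True,
                       frame_skip=True,
                       fire_reset_event=True,
                       episodic_life=True)
# With dim_order='pytorch' the frame buffer is channel-first,
# so something close to (n_frames, 44, 44) is expected here.
print(env.observation_space.shape)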
Example #4
	args = parser.parse_args()

	ENVIRONMENT = args.env
	MODEL_FOLDER = args.model_folder
	MODEL_NAME = args.model_name
	CHECKPOINT = args.checkpoint
	TRAINING_INFO = args.training_info
	NETWORK = args.network
	CHECKPOINT_STEPS = args.checkpoint_steps
	additional_training_steps = args.add_training_steps


	########## LOAD ENVIRONMENT AND BUILD NETWORK ##########
	env = gym.make(ENVIRONMENT)
	env = atari_wrappers.wrap_deepmind(env, frame_stack=True, clip_rewards=True)
	env = atari_wrappers.MaxAndSkipEnv(env, skip=3)
	#env = atari_wrappers.CenteredScaledFloatFrame(env)
	
	optimizer = tf.train.RMSPropOptimizer(learning_rate=0.00025, momentum=0.95, epsilon=0.01)
	
	#optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
	use_target_network = True #False if NETWORK.startswith('double') else True
	use_double_dqn = NETWORK.startswith('double')
	
	net = dqn.get_network(type=NETWORK,
	                      input_shape=env.observation_space.shape,
	                      num_actions=env.action_space.n,
	                      use_target_network=use_target_network,
	                      use_double_dqn=use_double_dqn,
	                      optimizer=optimizer)

	
	#################### TRAINING AGENT ####################
	saver = tf.train.Saver(max_to_keep=10)
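
The argument parser itself is omitted from this fragment; a hypothetical sketch that would produce the attributes read above (argument names inferred from those attributes, defaults purely illustrative) could look like this:

import argparse

# Hypothetical parser, not part of the original script.
parser = argparse.ArgumentParser()
parser.add_argument('--env', default='BreakoutNoFrameskip-v4')
parser.add_argument('--model_folder', default='./models')
parser.add_argument('--model_name', default='dqn')
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--training_info', default=None)
parser.add_argument('--network', default='dqn')
parser.add_argument('--checkpoint_steps', type=int, default=100000)
parser.add_argument('--add_training_steps', type=int, default=0)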