Example #1
import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import AtariWrapper
from reinforceflow.core import EGreedyPolicy, RMSProp
from reinforceflow.models import DeepQModel
from reinforceflow.trainers.replay_trainer import ReplayTrainer
# NOTE: the import paths for AtariWrapper, RMSProp and DeepQModel are inferred
# from Example #2 and may differ in your reinforceflow version.

env_name = 'Pong-v0'  # placeholder; the source was truncated before the env name
env = AtariWrapper(env_name,
                   action_repeat=4,
                   obs_stack=4,
                   new_width=84,
                   new_height=84,
                   to_gray=True,
                   noop_action=[1, 0, 0, 0],
                   start_action=[0, 1, 0, 0],
                   clip_rewards=True)

test_env = AtariWrapper(env_name,
                        action_repeat=4,
                        obs_stack=4,
                        new_width=84,
                        new_height=84,
                        to_gray=True,
                        start_action=[0, 1, 0, 0])

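# Double DQN (use_double=True) with the Nature-paper CNN and no dueling head;
# epsilon is annealed from 1.0 to 0.1 over 4M steps.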
agent = DeepQ(env=env,
              use_double=True,
              model=DeepQModel(nature_arch=True, dueling=False),
              optimizer=RMSProp(7e-4, decay=0.99, epsilon=0.1),
              policy=EGreedyPolicy(1.0, 0.1, 4000000),
              targetfreq=10000)

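# Train for 80M steps, evaluating periodically on the separate test_env.
# NOTE: backPropagationReplay is not defined in this truncated example;
# Example #4 shows how it is constructed.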
trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=80000000,
                        replay=backPropagationReplay,
                        logdir='tmp/%s/moving_average_bias' % env_name,
                        logfreq=900,
                        test_env=test_env,
                        test_render=False)
trainer.train()
Example #2
try:
    import reinforceflow
except ImportError:
    import os.path
    import sys
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
    import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import Vectorize
from reinforceflow.core import EGreedyPolicy, ProportionalReplay
from reinforceflow.core import Adam
from reinforceflow.models import FullyConnected
from reinforceflow.trainers.replay_trainer import ReplayTrainer
reinforceflow.set_random_seed(555)

env_name = 'CartPole-v0'
env = Vectorize(env_name)
policy = EGreedyPolicy(eps_start=1.0, eps_final=0.4, anneal_steps=300000)

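# DQN with a fully-connected Q-network, suitable for CartPole's
# low-dimensional observations.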
agent = DeepQ(env,
              model=FullyConnected(),
              optimizer=Adam(0.0001),
              targetfreq=10000,
              policy=policy)

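# ProportionalReplay implements prioritized (proportional) sampling; the
# positional arguments (30000, 32, 32) are assumed to be capacity, batch size
# and a warm-up size.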
trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=300000,
                        replay=ProportionalReplay(30000, 32, 32),
                        logdir='/tmp/rf/DeepQ/%s' % env_name,
                        logfreq=10)
trainer.train()
Example #3
import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import AtariWrapper
from reinforceflow.core import EGreedyPolicy, RMSProp
from reinforceflow.models import DeepQModel
from reinforceflow.trainers.replay_trainer import ReplayTrainer
# NOTE: the import paths for AtariWrapper, RMSProp and DeepQModel are inferred
# from Example #2 and may differ in your reinforceflow version.

env_name = 'Pong-v0'  # placeholder; the source was truncated before the env name
env = AtariWrapper(env_name,
                   action_repeat=4,
                   obs_stack=4,
                   new_width=84,
                   new_height=84,
                   to_gray=True,
                   noop_action=[1, 0, 0, 0],
                   start_action=[0, 1, 0, 0],
                   clip_rewards=True)

test_env = AtariWrapper(env_name,
                        action_repeat=4,
                        obs_stack=4,
                        new_width=84,
                        new_height=84,
                        to_gray=True,
                        start_action=[0, 1, 0, 0])

agent = DeepQ(env=env,
              use_double=True,
              model=DeepQModel(nature_arch=True, dueling=False),
              optimizer=RMSProp(7e-4, decay=0.99, epsilon=0.1),
              policy=EGreedyPolicy(1.0, 0.1, 4000000),
              targetfreq=10000)

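# Same setup as Example #1, differing only in the log directory and logfreq.
# NOTE: backPropagationReplay is not defined in this truncated example;
# Example #4 shows how it is constructed.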
trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=80000000,
                        replay=backPropagationReplay,
                        logdir='tmp/%s/backpropagation' % env_name,
                        logfreq=1800,
                        test_env=test_env,
                        test_render=False)
trainer.train()
Example #4
import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import Vectorize
from reinforceflow.core import EGreedyPolicy, Adam
from reinforceflow.models import FullyConnected
from reinforceflow.trainers.replay_trainer import ReplayTrainer
# NOTE: the import path for the replay classes is assumed, by analogy with
# ProportionalReplay in Example #2.
from reinforceflow.core import BackPropagationReplay, WindowedBackPropagationReplay

env_name = 'CartPole-v0'
env = Vectorize(env_name)
policy = EGreedyPolicy(eps_start=1.0, eps_final=0.4, anneal_steps=300000)

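# Pin the agent to the CPU; CartPole's tiny network gains little from a GPU.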
agent = DeepQ(env,
              device='/cpu:0',
              model=FullyConnected(),
              optimizer=Adam(0.0001),
              targetfreq=10000,
              policy=policy)

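# NOTE: backPropagationReplay is constructed here but unused below (the trainer
# takes the windowed variant), and moving_average_accumulator is not defined
# in the source snippet.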
backPropagationReplay = BackPropagationReplay(30000,
                                              32,
                                              0.,
                                              moving_average_accumulator,
                                              32,
                                              beta=20)

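# Windowed variant of the back-propagation replay; the argument meanings
# (capacity, batch size, window length, ...) are assumed.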
windowedPropagationReplay = WindowedBackPropagationReplay(30000, 32, 100, 32)

trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=300000,
                        replay=windowedPropagationReplay,
                        logdir='tmp/rf/DeepQ/%s' % env_name,
                        logfreq=10)

if __name__ == '__main__':
    trainer.train()
Example #5
import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import AtariWrapper
from reinforceflow.core import EGreedyPolicy, ProportionalReplay, RMSProp
from reinforceflow.models import DeepQModel
from reinforceflow.trainers.replay_trainer import ReplayTrainer
# NOTE: the import paths for AtariWrapper, RMSProp and DeepQModel are inferred
# from Example #2 and may differ in your reinforceflow version.

env_name = 'Pong-v0'  # placeholder; the source was truncated before the env name
env = AtariWrapper(env_name,
                   action_repeat=4,
                   obs_stack=4,
                   new_width=84,
                   new_height=84,
                   to_gray=True,
                   noop_action=[1, 0, 0, 0],
                   start_action=[0, 1, 0, 0],
                   clip_rewards=True)

test_env = AtariWrapper(env_name,
                        action_repeat=4,
                        obs_stack=4,
                        new_width=84,
                        new_height=84,
                        to_gray=True,
                        start_action=[0, 1, 0, 0])

agent = DeepQ(env,
              use_double=True,
              model=DeepQModel(nature_arch=True, dueling=False),
              optimizer=RMSProp(7e-4, decay=0.99, epsilon=0.1),
              policy=EGreedyPolicy(1.0, 0.1, 4000000),
              targetfreq=10000)

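# Same agent as Examples #1 and #3, but with a prioritized (proportional)
# replay buffer of 400k capacity instead of the back-propagation buffer.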
trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=80000000,
                        replay=ProportionalReplay(400000, 32, 32),
                        logdir='tmp/%s/proportional2' % env_name,
                        logfreq=1800,
                        test_env=test_env,
                        test_render=False)
trainer.train()
Example #6

import reinforceflow
from reinforceflow.agents import DeepQ
from reinforceflow.envs import AtariWrapper
from reinforceflow.core import EGreedyPolicy, ExperienceReplay, RMSProp
from reinforceflow.models import DeepQModel
from reinforceflow.trainers.replay_trainer import ReplayTrainer
# NOTE: the import paths are inferred from Example #2 and may differ in your
# reinforceflow version.

env_name = 'Pong-v0'  # placeholder; the source was truncated before the env name
env = AtariWrapper(env_name,
                   action_repeat=4,
                   obs_stack=4,
                   new_width=84,
                   new_height=84,
                   to_gray=True,
                   noop_action=[1, 0, 0, 0],
                   start_action=[0, 1, 0, 0],
                   clip_rewards=True)

test_env = AtariWrapper(env_name,
                        action_repeat=4,
                        obs_stack=4,
                        new_width=84,
                        new_height=84,
                        to_gray=True,
                        start_action=[0, 1, 0, 0])

agent = DeepQ(env=env,
              model=DeepQModel(nature_arch=True, dueling=False),
              use_double=True,
              restore_from=None,
              optimizer=RMSProp(7e-4, decay=0.99, epsilon=0.1),
              policy=EGreedyPolicy(1.0, 0.1, 4000000),
              targetfreq=10000)

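# Vanilla DQN baseline: uniform ExperienceReplay instead of the prioritized or
# back-propagation variants used in the other examples.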
trainer = ReplayTrainer(env=env,
                        agent=agent,
                        maxsteps=80000000,
                        replay=ExperienceReplay(400000, 32, 32),
                        logdir='tmp/%s/vanilla' % env_name,
                        logfreq=1800,
                        render=False,
                        test_env=test_env,
                        test_render=False)
trainer.train()