예제 #1
0
def main():
    """Train a TensorForce PPO learner on Gym ``CartPole-v1`` through FruitAPI.

    The learner class is obtained from ``TensorForcePlugin().get_learner()``.
    No FruitAPI policy network is supplied (``None`` is passed in its place).
    The training schedule and the TensorForce-style options are forwarded to
    ``AgentFactory.create`` as keyword arguments.
    """
    cartpole = GymEnvironment(env_name='CartPole-v1')

    # FruitAPI-level training schedule and checkpointing.
    schedule = dict(num_of_epochs=1, steps_per_epoch=5e4, log_dir='train/ppo_checkpoints',
                    checkpoint_frequency=5e4)

    # Options named after TensorForce agent parameters; presumably the plugin's
    # learner hands them to TensorForce -- confirm in TensorForcePlugin.
    ppo_options = dict(
        algorithm='ppo', network='auto',
        # Optimization
        batch_size=10, update_frequency=2, learning_rate=1e-3, subsampling_fraction=0.2,
        optimization_steps=5,
        # Reward estimation
        likelihood_ratio_clipping=0.2, discount=0.99, estimate_terminal=False,
        # Critic
        critic_network='auto',
        critic_optimizer=dict(optimizer='adam', multi_step=10, learning_rate=1e-3),
        preprocessing=None,
        # Exploration
        exploration=0.0, variable_noise=0.0,
        # Regularization
        l2_regularization=0.0, entropy_regularization=0.0,
        # TensorFlow / runtime
        name='agent', device=None, parallel_interactions=1, seed=None, execution=None,
        saver=None, summarizer=None, recorder=None,
    )

    learner_class = TensorForcePlugin().get_learner()
    ppo_agent = AgentFactory.create(learner_class, None, cartpole, **schedule, **ppo_options)

    ppo_agent.train()
예제 #2
0
def train_atari_sea_quest():
    """Train an A3C agent on the ALE game Seaquest.

    Configuration:
      - environment: ``ALEEnvironment.SEAQUEST`` with ``frame_skip=8``
      - network: ``AtariA3CConfig`` with ``initial_learning_rate=0.004``,
        wrapped in a ``PolicyNetwork`` with ``max_num_of_checkpoints=40``
      - schedule: 40 epochs x 1e6 steps, checkpoint every 1e6 steps,
        logs under ``./train/sea_quest/a3c_checkpoints``
    """
    sea_quest = ALEEnvironment(ALEEnvironment.SEAQUEST, frame_skip=8)
    policy = PolicyNetwork(AtariA3CConfig(sea_quest, initial_learning_rate=0.004),
                           max_num_of_checkpoints=40)

    trainer = AgentFactory.create(
        A3CLearner, policy, sea_quest,
        num_of_epochs=40,
        steps_per_epoch=1e6,
        checkpoint_frequency=1e6,
        log_dir='./train/sea_quest/a3c_checkpoints',
    )
    trainer.train()
예제 #3
0
def train_mc_grid_world():
    """Train an ``MCLearner`` on a non-rendered 8x9 GridWorld (stage 1).

    The agent starts at (2, 2) and the world is seeded with 100. No policy
    network is used (``network=None``). Training runs 1 epoch of 1e5 steps,
    checkpoints every 1e5 steps and logs under
    ``./train/grid_world/mc_checkpoints``.
    """
    grid = GridWorld(render=False, graphical_state=False, stage=1,
                     number_of_rows=8, number_of_columns=9, speed=1000, seed=100,
                     agent_start_x=2, agent_start_y=2)
    wrapped = FruitEnvironment(game_engine=grid)

    mc_agent = AgentFactory.create(
        MCLearner,
        network=None,
        environment=wrapped,
        checkpoint_frequency=1e5,
        num_of_epochs=1,
        steps_per_epoch=1e5,
        learner_report_frequency=10,
        log_dir='./train/grid_world/mc_checkpoints',
    )
    mc_agent.train()
예제 #4
0
def train_multi_objective_agent_deep_sea_treasure(env_size):
    """Train a multi-objective Q-learning agent (``MOQLearner``) on Deep Sea Treasure.

    Args:
        env_size: forwarded as ``width`` to ``DeepSeaTreasure``.
    """
    treasure_game = DeepSeaTreasure(width=env_size, seed=100, speed=1000)
    fruit_env = FruitEnvironment(treasure_game)

    # No policy network: MOQLearner is created with None in the network slot.
    moq_agent = AgentFactory.create(MOQLearner, None, fruit_env,
                                    num_of_epochs=2, steps_per_epoch=100000,
                                    checkpoint_frequency=5e4,
                                    log_dir='./train/deep_sea_treasure/moq_checkpoints')
    moq_agent.train()
예제 #5
0
def train_multi_objective_agent_mountain_car():
    """Train a multi-objective Q-learning agent (``MOQLearner``) on Mountain Car.

    The environment is wrapped in ``FruitEnvironment`` with default options;
    no extra flag is passed at that point. The learner receives
    ``is_linear=True`` and ``thresholds=[0.5, 0.3, 0.2]``, presumably one
    value per objective -- confirm against ``MOQLearner``.
    """
    mountain_car = MountainCar(graphical_state=False, frame_skip=1, render=False,
                               speed=1000, is_debug=False)
    env = FruitEnvironment(mountain_car)

    moq_kwargs = {
        'num_of_epochs': 30,
        'steps_per_epoch': 100000,
        'checkpoint_frequency': 1e5,
        'log_dir': './train/mountain_car/moq_checkpoints',
        'is_linear': True,
        'thresholds': [0.5, 0.3, 0.2],
    }
    learner = AgentFactory.create(MOQLearner, None, env, **moq_kwargs)

    learner.train()
예제 #6
0
def eval_mc_grid_world(load_model_path=None):
    """Evaluate a trained ``MCLearner`` on a rendered 8x9 GridWorld (stage 1).

    Args:
        load_model_path: checkpoint file to load into the learner. When
            ``None`` (the default), the checkpoint of one specific earlier
            training run is used
            (``./train/grid_world/mc_checkpoints_11-02-2019-02-29/checkpoint_100315.npy``).
            Pass another path to evaluate a different run without editing
            this function.
    """
    if load_model_path is None:
        # Timestamped output of a particular training run; kept as the default
        # so existing zero-argument calls behave exactly as before.
        load_model_path = ('./train/grid_world/mc_checkpoints_11-02-2019-02-29/'
                           'checkpoint_100315.npy')

    engine = GridWorld(render=True, graphical_state=False, stage=1,
                       number_of_rows=8, number_of_columns=9, speed=2, seed=100, agent_start_x=2, agent_start_y=2)

    environment = FruitEnvironment(game_engine=engine)

    # epsilon_annealing_start=0 is passed for evaluation; presumably this
    # disables exploration -- confirm against MCLearner.
    agent = AgentFactory.create(MCLearner, network=None, environment=environment, checkpoint_frequency=1e5,
                                num_of_epochs=1, steps_per_epoch=1e4, learner_report_frequency=50,
                                log_dir='./test/grid_world/mc_checkpoints',
                                load_model_path=load_model_path,
                                epsilon_annealing_start=0)

    agent.evaluate()
예제 #7
0
def train_tank_1_player_machine():
    """Train an A3C agent on single-player Tank Battle with no human control.

    Frames pass through ``AtariProcessor`` and rewards through
    ``TankBattleTotalRewardProcessor``. Episodes are capped at 10000 steps.
    Training runs 10 epochs x 1e6 steps, checkpoints every 5e5 steps and logs
    under ``./train/tank_battle/a3c_checkpoints``.
    """
    tank_options = dict(render=False, player1_human_control=False, player2_human_control=False,
                        two_players=False, speed=2000, frame_skip=5)
    engine = TankBattle(**tank_options)

    environment = FruitEnvironment(engine, max_episode_steps=10000,
                                   state_processor=AtariProcessor(),
                                   reward_processor=TankBattleTotalRewardProcessor())

    policy = PolicyNetwork(AtariA3CConfig(environment, initial_learning_rate=0.004),
                           max_num_of_checkpoints=20)

    a3c_agent = AgentFactory.create(A3CLearner, policy, environment,
                                    num_of_epochs=10, steps_per_epoch=1e6,
                                    checkpoint_frequency=5e5,
                                    log_dir='./train/tank_battle/a3c_checkpoints')
    a3c_agent.train()
예제 #8
0
def train_ale_environment():
    """Train an A3C agent on ALE Breakout.

    Uses ``AtariA3CConfig`` with ``initial_learning_rate=0.004`` and
    ``debug_mode=True``, a ``PolicyNetwork`` with ``max_num_of_checkpoints=40``,
    and 40 epochs x 1e6 steps with a checkpoint every 1e6 steps, logged under
    ``./train/breakout/a3c_checkpoints``.
    """
    breakout = ALEEnvironment(ALEEnvironment.BREAKOUT)

    a3c_config = AtariA3CConfig(breakout, initial_learning_rate=0.004, debug_mode=True)
    shared_net = PolicyNetwork(a3c_config, max_num_of_checkpoints=40)

    learner = AgentFactory.create(A3CLearner, shared_net, breakout,
                                  num_of_epochs=40, steps_per_epoch=1e6,
                                  checkpoint_frequency=1e6,
                                  log_dir='./train/breakout/a3c_checkpoints')
    learner.train()
예제 #9
0
def train_milk_1_milk_1_fix_robots_with_no_status():
    """Train a multi-agent A3C learner (``MAA3CLearner``) on Milk Factory.

    The factory has one milk robot, one fix robot and one milk, no human
    control and ``show_status=False``. Episodes are capped at 200 steps.
    Training runs 40 epochs x 1e5 steps, checkpoints every 1e5 steps and logs
    under ``./train/milk_factory/a3c_ma_2_checkpoints``.
    """
    factory_settings = {
        'render': False,
        'speed': 6000,
        'max_frames': 200,
        'frame_skip': 1,
        'number_of_milk_robots': 1,
        'number_of_fix_robots': 1,
        'number_of_milks': 1,
        'seed': None,
        'human_control': False,
        'error_freq': 0.03,
        'human_control_robot': 0,
        'milk_speed': 3,
        'debug': False,
        'action_combined_mode': False,
        'show_status': False,
    }
    factory = MilkFactory(**factory_settings)

    fruit_env = FruitEnvironment(factory, max_episode_steps=200, state_processor=AtariProcessor())

    ma_config = MAA3CConfig(fruit_env, initial_learning_rate=0.001, beta=0.001)
    policy = PolicyNetwork(ma_config, max_num_of_checkpoints=40)

    ma_agent = AgentFactory.create(MAA3CLearner, policy, fruit_env,
                                   num_of_epochs=40, steps_per_epoch=1e5,
                                   checkpoint_frequency=1e5,
                                   log_dir='./train/milk_factory/a3c_ma_2_checkpoints')
    ma_agent.train()
예제 #10
0
파일: a3c_test.py 프로젝트: vhtu/Fruit-API
def evaluate_ale_environment(load_model_path=None):
    """Evaluate a trained A3C agent on ALE Breakout with rendering enabled.

    Args:
        load_model_path: checkpoint to load into the policy network. When
            ``None`` (the default), the checkpoint of one specific earlier
            training run is used
            (``./train/breakout/a3c_checkpoints_10-23-2019-02-13/model-39030506``).
            Pass another path to evaluate a different run without editing
            this function.
    """
    if load_model_path is None:
        # Timestamped output of a particular training run; kept as the default
        # so existing zero-argument calls behave exactly as before.
        load_model_path = './train/breakout/a3c_checkpoints_10-23-2019-02-13/model-39030506'

    # Create an ALE for Breakout and enable rendering
    environment = ALEEnvironment(ALEEnvironment.BREAKOUT, is_render=True)

    # Create a network configuration for Atari A3C
    network_config = AtariA3CConfig(environment)

    # Create a shared network for A3C agent, loaded from the chosen checkpoint
    network = PolicyNetwork(network_config, load_model_path=load_model_path)

    # Create an A3C agent, use only one learner as we want to show a GUI
    agent = AgentFactory.create(A3CLearner, network, environment, num_of_epochs=1, steps_per_epoch=10000,
                                num_of_learners=1, log_dir='./test/breakout/a3c_checkpoints')

    # Evaluate it
    agent.evaluate()
예제 #11
0
파일: chapter_4.py 프로젝트: vhtu/Fruit-API
def composite_agents(main_model_path, auxiliary_model_path, alpha, epsilon):
    """Evaluate a ``DQA3CLearner`` (divide-and-conquer A3C) on Breakout.

    The environment is created from the game id alone; no
    ``loss_of_life_negative_reward`` flag is passed.

    Args:
        main_model_path: checkpoint loaded into the shared ``PolicyNetwork``.
        auxiliary_model_path: forwarded to the learner as ``auxiliary_model_path``.
        alpha: forwarded to the learner as ``alpha``.
        epsilon: forwarded to the learner as ``epsilon``.

    Returns:
        Whatever ``agent.evaluate()`` returns.
    """
    breakout = ALEEnvironment(ALEEnvironment.BREAKOUT)

    policy = PolicyNetwork(DQAtariA3CConfig(breakout), load_model_path=main_model_path)

    evaluator = AgentFactory.create(DQA3CLearner, policy, breakout,
                                    num_of_epochs=1, steps_per_epoch=10000,
                                    checkpoint_frequency=1e5, learner_report_frequency=1,
                                    auxiliary_model_path=auxiliary_model_path,
                                    alpha=alpha, epsilon=epsilon)
    return evaluator.evaluate()
예제 #12
0
파일: chapter_4.py 프로젝트: vhtu/Fruit-API
def train_breakout_with_a3c_remove_immutable_objects():
    """Train an A3C agent on Breakout with masked frames.

    Frames pass through ``AtariBlackenProcessor``, which per the original
    author's note blackens the top half of each state (confirm in the
    processor). ``loss_of_life_negative_reward=True`` is set on the
    environment. Training runs 50 epochs x 1e6 steps, checkpoints every 1e6
    steps and logs under ``./train/breakout/a3c_smc_1_checkpoints``.
    """
    blackener = AtariBlackenProcessor()
    masked_breakout = ALEEnvironment(ALEEnvironment.BREAKOUT,
                                     loss_of_life_negative_reward=True,
                                     state_processor=blackener)

    config = AtariA3CConfig(masked_breakout, initial_learning_rate=0.004, debug_mode=True)
    shared_policy = PolicyNetwork(config, max_num_of_checkpoints=50)

    a3c = AgentFactory.create(A3CLearner, shared_policy, masked_breakout,
                              num_of_epochs=50, steps_per_epoch=1e6,
                              checkpoint_frequency=1e6,
                              log_dir='./train/breakout/a3c_smc_1_checkpoints')
    a3c.train()
예제 #13
0
def train_tank_1_player_machine_with_map():
    """Train an ``A3CMapLearner`` on single-player Tank Battle with a map.

    The engine is created with ``using_map=True``, 5 enemies,
    ``multi_target=True`` and ``strategy=3``. The learner receives a reward
    selector that keeps only element 2 of the reward sequence it is given.
    """
    def select_reward(rewards):
        # Keep the component at index 2; its meaning is defined by the
        # engine / learner -- confirm before changing.
        return rewards[2]

    tank_engine = TankBattle(render=False, player1_human_control=False, player2_human_control=False,
                             two_players=False, speed=1000, frame_skip=5, debug=False,
                             using_map=True, num_of_enemies=5, multi_target=True, strategy=3)

    env = FruitEnvironment(tank_engine, max_episode_steps=10000, state_processor=AtariProcessor())

    policy = PolicyNetwork(A3CMapConfig(env, initial_learning_rate=0.004), max_num_of_checkpoints=20)

    map_agent = AgentFactory.create(A3CMapLearner, policy, env,
                                    num_of_epochs=10, steps_per_epoch=1e6,
                                    checkpoint_frequency=5e5,
                                    log_dir='./train/tank_battle/a3c_map_checkpoints',
                                    network_update_steps=4, update_reward_fnc=select_reward)
    map_agent.train()
예제 #14
0
def train_atari_sea_quest_with_map():
    """Train an ``A3CMapLearner`` on ALE Seaquest using ``SeaquestMapProcessor``.

    The learner receives a reward selector over the reward sequence:
    if element 3 equals 1 it returns element 2, otherwise element 0 plus
    element 1.
    """
    def pick_reward(rewards):
        # Element 3 looks like a low-oxygen flag (original name: ``oxy_low``)
        # -- confirm in SeaquestMapProcessor.
        return rewards[2] if rewards[3] == 1 else rewards[0] + rewards[1]

    sea_quest = ALEEnvironment(ALEEnvironment.SEAQUEST, state_processor=SeaquestMapProcessor(), frame_skip=8)

    policy = PolicyNetwork(A3CMapConfig(sea_quest, initial_learning_rate=0.004), max_num_of_checkpoints=40)

    map_agent = AgentFactory.create(A3CMapLearner, policy, sea_quest,
                                    num_of_epochs=40, steps_per_epoch=1e6,
                                    checkpoint_frequency=1e6,
                                    log_dir='./train/sea_quest/a3c_map_checkpoints',
                                    network_update_steps=12, update_reward_fnc=pick_reward)
    map_agent.train()
예제 #15
0
파일: dqn_test.py 프로젝트: vhtu/Fruit-API
def train_ale_environment():
    """Train a DQN agent on ALE Breakout.

    Uses ``AtariDQNConfig`` with ``debug_mode=True``, a ``PolicyNetwork`` with
    ``max_num_of_checkpoints=40``, and 20 epochs x 1e6 steps with a checkpoint
    every 5e5 steps, logged under ``./train/breakout/dqn_checkpoints``.
    """
    breakout = ALEEnvironment(ALEEnvironment.BREAKOUT)

    dqn_network = PolicyNetwork(AtariDQNConfig(breakout, debug_mode=True), max_num_of_checkpoints=40)

    run_settings = dict(num_of_epochs=20, steps_per_epoch=1e6, checkpoint_frequency=5e5,
                        log_dir='./train/breakout/dqn_checkpoints')
    dqn_agent = AgentFactory.create(DQNLearner, dqn_network, breakout, **run_settings)

    dqn_agent.train()
예제 #16
0
def train_multi_objective_dqn_agent(is_linear=True, extended_config=True):
    """Train a multi-objective DQN agent (``MODQNLearner``) on Deep Sea Treasure.

    Args:
        is_linear: if true, uses ``linear_thresholds=[1, 0]`` and no TLO
            thresholds. Otherwise uses ``linear_thresholds=[10, 1]`` and a
            single TLO threshold at the midpoint of ``treasures[3]`` and
            ``treasures[4]``, as returned by ``game.get_treasure()``.
        extended_config: if true, the game produces graphical states that are
            passed through ``AtariProcessor`` and configured with
            ``MOExDQNConfig`` (CNN, history length 4). Otherwise the game
            produces non-graphical states and ``MODQNConfig`` is used.
    """
    game = DeepSeaTreasure(graphical_state=True if extended_config else False, width=5, seed=100,
                           render=False, max_treasure=100, speed=1000)

    env_options = {'max_episode_steps': 60}
    if extended_config:
        env_options['state_processor'] = AtariProcessor()
    environment = FruitEnvironment(game, **env_options)

    treasures = game.get_treasure()
    if is_linear:
        tlo_thresholds, linear_thresholds = None, [1, 0]
    else:
        tlo_thresholds, linear_thresholds = [(treasures[4] + treasures[3]) / 2], [10, 1]

    threshold_options = dict(is_linear=is_linear, linear_thresholds=linear_thresholds,
                             tlo_thresholds=tlo_thresholds)
    if extended_config:
        config = MOExDQNConfig(environment, **threshold_options, using_cnn=True, history_length=4)
    else:
        config = MODQNConfig(environment, **threshold_options)

    shared_policy = PolicyNetwork(config, max_num_of_checkpoints=10)

    mo_agent = AgentFactory.create(MODQNLearner, shared_policy, environment,
                                   num_of_epochs=2, steps_per_epoch=100000,
                                   checkpoint_frequency=50000,
                                   log_dir='./train/deep_sea_treasure/mo_dqn_checkpoints')
    mo_agent.train()