def eval_ale_environment(model_path, render, num_of_epochs, steps_per_epoch, initial_epsilon, log_dir):
    """Evaluate a saved prioritized-DQN policy on the ALE game River Raid.

    Args:
        model_path: Saved model to load; forwarded to ``PolicyNetwork`` as
            ``load_model_path``.
        render: Forwarded to ``ALEEnvironment`` as ``is_render``.
        num_of_epochs: Forwarded to ``DQNAgent`` as ``num_of_epochs``.
        steps_per_epoch: Forwarded to ``DQNAgent`` as ``steps_per_epoch``.
        initial_epsilon: Forwarded to ``DQNAgent`` as ``initial_epsilon``.
        log_dir: Forwarded to ``DQNAgent`` as ``log_dir``.

    Returns:
        Whatever ``DQNAgent.evaluate()`` returns.
    """
    # Create an ALE environment for River Raid (not Breakout).
    # max_episode_steps=5000 presumably truncates overly long episodes.
    environment = ALEEnvironment(ALEEnvironment.RIVERRAID,
                                 is_render=render,
                                 max_episode_steps=5000)

    # Network configuration for Atari DQN with prioritized replay; this must
    # agree with the prioritized=True flag passed to the agent below.
    network_config = PrioritizedAtariDQNConfig(environment)

    # Policy network restored from the saved model.
    network = PolicyNetwork(network_config, load_model_path=model_path)

    # DQN agent configured for evaluation: report after every unit
    # (report_frequency=1) and run with 8 worker threads.
    agent = DQNAgent(network, environment,
                     initial_epsilon=initial_epsilon,
                     report_frequency=1,
                     num_of_threads=8,
                     num_of_epochs=num_of_epochs,
                     steps_per_epoch=steps_per_epoch,
                     log_dir=log_dir,
                     prioritized=True)

    # Evaluate the loaded policy and hand back the agent's result.
    return agent.evaluate()
def train_ale_environment():
    """Train a DQN agent with prioritized replay on the ALE game Breakout.

    Builds a Breakout environment, a ``PrioritizedAtariDQNConfig`` network
    configuration and a ``PolicyNetwork``, then trains a ``DQNAgent`` for 50
    epochs of 1,000,000 steps each. Checkpoints and logs are written to
    ``./train/breakout/pdqn_check_point``.

    Note:
        To use Gym instead of the native ALE wrapper, the environment can be
        built as ``GymEnvironment("Breakout-v0",
        state_processor=AtariProcessor())``.
    """
    # Create an ALE environment for Breakout.
    environment = ALEEnvironment(ALEEnvironment.BREAKOUT)

    # Network configuration for Atari DQN. initial_beta together with the
    # agent's prioritized_alpha / importance_sampling presumably parameterize
    # prioritized experience replay (alpha=0.6, beta starting at 0.4).
    network_config = PrioritizedAtariDQNConfig(environment,
                                               initial_beta=0.4,
                                               initial_learning_rate=0.00025,
                                               debug_mode=True)

    # Policy network keeping at most 100 checkpoints.
    # NOTE(review): the dueling-DQN Pong example uses `num_of_checkpoints`,
    # while this one uses `max_num_of_checkpoints` -- confirm which keyword
    # PolicyNetwork actually accepts.
    network = PolicyNetwork(network_config, max_num_of_checkpoints=100)

    # DQN agent. Step counts are passed as integers: the float literals 5e5 and
    # 1e6 would break any counter or range()-based use inside the agent.
    agent = DQNAgent(network, environment,
                     save_frequency=int(5e5),
                     steps_per_epoch=int(1e6),
                     num_of_epochs=50,
                     exp_replay_size=2 ** 19,
                     importance_sampling=True,
                     log_dir="./train/breakout/pdqn_check_point",
                     prioritized_alpha=0.6,
                     prioritized=True)

    # Train it.
    agent.train()
def train_ale_environment():
    """Train a vanilla DQN agent on the ALE game Breakout.

    The environment, the ``AtariDQNConfig`` network configuration, the policy
    network and the agent are chained together, and training then runs.
    Output goes to ``./train/breakout/dqn_breakout``.
    """
    # Breakout, served through the Arcade Learning Environment wrapper.
    env = ALEEnvironment(ALEEnvironment.BREAKOUT)

    # Standard Atari DQN configuration wrapped in a policy network.
    config = AtariDQNConfig(env)
    policy = PolicyNetwork(config)

    # The agent pairs the policy with the environment and logs to log_dir.
    learner = DQNAgent(policy, env, log_dir="./train/breakout/dqn_breakout")
    learner.train()
def train_ale_environment():
    """Train a dueling-network DQN agent on the ALE game Pong.

    Checkpointing is time based (``save_time_based=30``). According to the
    original author this means every 30 minutes. Training stops at checkpoint
    40 (``checkpoint_stop=40``). Output goes to
    ``./train/Pong/dueldqn_pong_time_based_30_40``.

    Note:
        To use Gym instead of the native ALE wrapper, the environment can be
        built as ``GymEnvironment("Pong-v0",
        state_processor=AtariProcessor())``.
    """
    # Create an ALE environment for Pong (not Breakout).
    environment = ALEEnvironment(ALEEnvironment.PONG)

    # Network configuration for Atari DQN using a dueling network.
    network_config = AtariDuelDQNConfig(environment)

    # Policy network; num_of_checkpoints=40 matches checkpoint_stop below.
    # NOTE(review): the prioritized Breakout example uses
    # `max_num_of_checkpoints` -- confirm which keyword PolicyNetwork accepts.
    network = PolicyNetwork(network_config, num_of_checkpoints=40)

    # DQN agent: checkpoint every 30 (minutes, per the original comment) and
    # stop training at the 40th checkpoint.
    agent = DQNAgent(network, environment,
                     save_time_based=30,
                     checkpoint_stop=40,
                     log_dir="./train/Pong/dueldqn_pong_time_based_30_40")

    # Train it.
    agent.train()