def dqn_per_gridworld():
    """Configure and launch ``train_dqn_per`` on the 4x4 grid world.

    Builds a ``DictConfig`` of hyperparameters and a small ``GenericConvModel``
    (4x4 input, 4 input channels, one conv stage of 50 channels, 4 outputs).
    Both are handed to ``train_dqn_per`` together with ``GridWorldEnvWrapper``,
    under project "SimpleGridWorld" and run name "dqn_per".
    """
    hp = DictConfig({})

    batch_size = 500
    replay_size = 1000

    # Insertion order matches the order in which the keys are set on ``hp``.
    settings = {
        "steps": 1000,
        "batch_size": batch_size,
        "replay_batch": 100,
        "replay_size": replay_size,
        # Author's note: intended as "every 100 steps".
        "delete_freq": 100 * (batch_size + replay_size),
        "env_record_freq": 100,
        "env_record_duration": 25,
        "max_steps": 50,
        "grid_size": 4,
        "lr": 1e-3,
        # Fixed exploration rate (no decay schedule in this config).
        "epsilon_exploration": 0.1,
        "gamma_discount": 0.9,
    }
    for key, value in settings.items():
        setattr(hp, key, value)

    network = GenericConvModel(
        height=4,
        width=4,
        in_channels=4,
        channels=[50],
        out_size=4,
    )
    model = network.float().to(device)

    train_dqn_per(
        GridWorldEnvWrapper,
        model,
        hp,
        project_name="SimpleGridWorld",
        run_name="dqn_per",
    )
def breakout_double_dqn():
    """Configure and launch ``train_dqn_double`` on Breakout.

    Builds a ``DictConfig`` of hyperparameters and a ``GenericConvModel``.
    Both are handed to ``train_dqn_double`` together with ``BreakoutEnvWrapper``,
    under project "Breakout" and run name "double_dqn".

    Exploration follows an epsilon schedule (start/end/decay function/flatten
    step) rather than a fixed ``epsilon_exploration`` value.
    """
    hp = DictConfig({})
    hp.steps = 2000
    hp.batch_size = 50
    hp.replay_batch = 50
    hp.replay_size = 1000
    # Every 50 steps, following the same N * (batch_size + replay_size) ==
    # "every N steps" convention as the grid-world configs. (The previous
    # comment said "every 100 steps", which contradicted the factor of 50.)
    hp.delete_freq = 50 * (hp.batch_size + hp.replay_size)
    hp.delete_percentage = 0.2
    hp.env_record_freq = 100
    hp.env_record_duration = 50
    hp.lr = 1e-3
    hp.gamma_discount = 0.9
    # Epsilon is scheduled instead of fixed: it goes from epsilon_start to
    # epsilon_end via a LINEAR decay. It presumably reaches epsilon_end at
    # epsilon_flatten_step (confirm against the trainer).
    hp.epsilon_flatten_step = 1500
    hp.epsilon_start = 1
    hp.epsilon_end = 0.1
    hp.epsilon_decay_function = decay_functions.LINEAR
    hp.target_model_sync_freq = 50

    # Positional args presumably: height=42, width=42, in_channels=3,
    # conv channels [50, 50, 50], a hidden-layer spec [100], out_size=4
    # (confirm against GenericConvModel's signature).
    # Cast/move the model exactly as the grid-world configs do before passing
    # it to train_dqn_double, so it lives on the same dtype/device as in those
    # runs. On CPU with float32 parameters this is a no-op.
    model = GenericConvModel(42, 42, 3, [50, 50, 50], [100], 4).float().to(device)

    train_dqn_double(
        BreakoutEnvWrapper, model, hp, project_name="Breakout", run_name="double_dqn"
    )
def dqn_double():
    """Configure and launch ``train_dqn_double`` on the 4x4 grid world.

    Builds a ``DictConfig`` of hyperparameters and a small ``GenericConvModel``
    (4x4 input, 4 input channels, one conv stage of 50 channels, 4 outputs).
    Both are handed to ``train_dqn_double`` together with ``GridWorldEnvWrapper``,
    under project "SimpleGridWorld" and run name "dqn_target".

    Exploration follows a LINEAR epsilon schedule from 1 down to 0.001 rather
    than a fixed ``epsilon_exploration`` value.
    """
    hp = DictConfig({})

    batch_size = 500
    replay_size = 1000

    # Insertion order matches the order in which the keys are set on ``hp``.
    settings = {
        "steps": 1000,
        "batch_size": batch_size,
        "replay_batch": 100,
        "replay_size": replay_size,
        # Author's note: intended as "every 100 steps".
        "delete_freq": 100 * (batch_size + replay_size),
        "env_record_freq": 100,
        "env_record_duration": 25,
        "max_steps": 50,
        "grid_size": 4,
        "lr": 1e-3,
        "gamma_discount": 0.9,
        # Epsilon schedule (replaces a fixed epsilon_exploration).
        "epsilon_flatten_step": 700,
        "epsilon_start": 1,
        "epsilon_end": 0.001,
        "epsilon_decay_function": decay_functions.LINEAR,
        "target_model_sync_freq": 50,
    }
    for key, value in settings.items():
        setattr(hp, key, value)

    network = GenericConvModel(
        height=4,
        width=4,
        in_channels=4,
        channels=[50],
        out_size=4,
    )
    model = network.float().to(device)

    train_dqn_double(
        GridWorldEnvWrapper,
        model,
        hp,
        project_name="SimpleGridWorld",
        run_name="dqn_target",
    )