def dqn_gridworld():
    hp = DictConfig({})
    hp.steps = 1000
    hp.batch_size = 600
    hp.env_record_freq = 100
    hp.env_record_duration = 25
    hp.max_steps = 50
    hp.grid_size = 4
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    model = (
        GenericConvModel(height=4, width=4, in_channels=4, channels=[50], out_size=4)
        .float()
        .to(device)
    )

    train_dqn(
        GridWorldEnvWrapper, model, hp, project_name="SimpleGridWorld", run_name="dqn"
    )
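# hp.epsilon_exploration above is the epsilon of epsilon-greedy exploration.
# A minimal sketch of that selection rule, under assumed names (q_model, state,
# n_actions); train_dqn's actual internals may differ:
import random

import torch


def epsilon_greedy_action(q_model, state, n_actions, epsilon=0.1):
    # With probability epsilon act uniformly at random, otherwise greedily
    # with respect to the current Q-network.
    if random.random() < epsilon:
        return random.randrange(n_actions)
    with torch.no_grad():
        return q_model(state.unsqueeze(0)).argmax(dim=-1).item()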
def train_dqn_connect4():
    hp = DictConfig({})
    hp.steps = 20
    hp.batch_size = 2
    hp.max_steps = 10
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    model = GenericLinearModel(2 * 6 * 7, [10], 7, flatten=True).float().to(device)

    train_dqn(ConnectXEnvWrapper, model, hp, name="Connect4")
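# GenericLinearModel(2 * 6 * 7, [10], 7, flatten=True) consumes a two-channel
# 6x7 Connect-4 board and emits one Q-value per column. A plausible minimal
# equivalent, assuming it is a plain MLP with optional input flattening; the
# real class in models.py may differ:
import torch.nn as nn


class LinearModelSketch(nn.Module):
    def __init__(self, in_size, units, out_size, flatten=False):
        super().__init__()
        layers = [nn.Flatten()] if flatten else []
        sizes = [in_size] + list(units)
        for a, b in zip(sizes, sizes[1:]):
            layers += [nn.Linear(a, b), nn.ReLU()]
        layers.append(nn.Linear(sizes[-1], out_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)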
def breakout_dqn():
    hp = DictConfig({})
    hp.steps = 2000
    hp.batch_size = 32
    hp.env_record_freq = 500
    hp.env_record_duration = 100
    hp.max_steps = 1000
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    # .float().to(device) for consistency with the other training scripts
    model = GenericLinearModel(42 * 42 * 3, [100, 100], 4, flatten=True).float().to(device)

    train_dqn(
        BreakoutEnvWrapper, model, hp, project_name="Breakout", run_name="vanilla_dqn"
    )
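# hp.gamma_discount enters through the standard DQN regression target
# y = r + gamma * max_a' Q(s', a'), with a zero bootstrap on terminal states.
# A minimal sketch over a sampled batch; the tensor names are illustrative,
# not train_dqn's actual variables:
import torch


def dqn_targets(q_model, rewards, next_states, dones, gamma=0.9):
    # Bootstrap from the greedy next-state value, masked out where done.
    with torch.no_grad():
        next_q = q_model(next_states).max(dim=1).values
    return rewards + gamma * next_q * (1.0 - dones.float())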
from typing import Iterable

import torch
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
from omegaconf import DictConfig

from models import GenericLinearModel
from settings import device

# (imports of the repo-local helpers train_dqn and to_onehot are elided in this excerpt)


class FrozenLakeEnvWrapper:  # base class elided in the source; name recovered from the train_dqn call below
    def __init__(self):
        super().__init__()
        self.env = FrozenLakeEnv(map_name="4x4", is_slippery=True)

    def get_legal_actions(self):
        return list(range(4))

    @staticmethod
    def get_state_batch(envs: Iterable) -> torch.Tensor:
        return to_onehot([env.state for env in envs], 16).float()


if __name__ == "__main__":
    hp = DictConfig({})
    hp.steps = 5000
    hp.batch_size = 500
    hp.max_steps = 200
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9
    hp.units = [10]

    # .float() matches the float32 one-hot batches produced by get_state_batch
    model = GenericLinearModel(16, hp.units, 4).float().to(device)

    train_dqn(FrozenLakeEnvWrapper, model, hp, name="FrozenLake")
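# to_onehot is a repo-local helper; a minimal stand-in with the same call
# shape, assuming it maps a list of integer states to a (batch, num_classes)
# one-hot batch (the repo's implementation may differ):
import torch.nn.functional as F


def to_onehot_sketch(states, num_classes):
    # list of ints -> int64 tensor -> one-hot rows, e.g. to_onehot_sketch([3], 16)
    return F.one_hot(torch.tensor(states), num_classes=num_classes)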
import gym
from griddly import gd
from omegaconf import DictConfig

from models import GenericLinearModel
from settings import device

# (imports of the repo-local helper train_dqn are elided in this excerpt)


class SokobanV2L0EnvWrapper:  # base class elided in the source; name recovered from the train_dqn call below
    max_steps = 500
    reward_range = (-10, 10)  # TODO: Fix this

    def __init__(self):
        super().__init__()
        self.env = gym.make(
            "GDY-Sokoban---2-v0",
            global_observer_type=gd.ObserverType.VECTOR,
            player_observer_type=gd.ObserverType.VECTOR,
            level=0,
        )


if __name__ == "__main__":
    hp = DictConfig({})
    hp.steps = 10000
    hp.batch_size = 1000
    hp.env_record_freq = 500
    hp.env_record_duration = 50
    hp.max_steps = 200
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    model = GenericLinearModel(5 * 7 * 8, [10], 5, flatten=True).float().to(device)

    train_dqn(SokobanV2L0EnvWrapper, model, hp, name="SokobanV2L0")
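# The 5 * 7 * 8 input size corresponds to Griddly's VECTOR observer, which
# reports the grid as one binary channel per object type. A quick startup
# check, assuming the wrapper exposes the raw env as .env (as in __init__
# above) and a standard Box observation space:
import math

env = SokobanV2L0EnvWrapper()
assert math.prod(env.env.observation_space.shape) == 5 * 7 * 8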
from typing import Iterable

import gym
import torch
from omegaconf import DictConfig

from models import GenericLinearModel
from settings import device

# (imports of the repo-local helpers train_dqn and to_onehot are elided in this excerpt)


class TaxiV3EnvWrapper:  # base class elided in the source; name recovered from the train_dqn call below
    def __init__(self):
        super().__init__()
        self.env = gym.make("Taxi-v3")

    def get_legal_actions(self):
        return list(range(6))

    @staticmethod
    def get_state_batch(envs: Iterable) -> torch.Tensor:
        return to_onehot([env.state for env in envs], 500).float()


if __name__ == "__main__":
    hp = DictConfig({})
    hp.steps = 10000
    hp.batch_size = 500
    hp.max_steps = 200
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9
    hp.units = [100]

    model = GenericLinearModel(in_size=500, units=hp.units, out_size=6).float().to(device)

    train_dqn(TaxiV3EnvWrapper, model, hp, name="TaxiV3")
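# The 500-wide one-hot encoding matches Taxi-v3's discrete state space:
# 25 taxi cells x 5 passenger locations (4 stops + in-taxi) x 4 destinations.
assert 25 * 5 * 4 == 500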
from omegaconf import DictConfig
from pettingzoo.classic import tictactoe_v3

from envs.env_wrapper import (
    PettingZooEnvWrapper,
    NumpyStateMixin,
    petting_zoo_random_player,
)
from models import GenericLinearModel
from settings import device


class TicTacToeEnvWrapper(PettingZooEnvWrapper, NumpyStateMixin):
    def __init__(self):
        super().__init__(
            env=tictactoe_v3.env(), opponent_policy=petting_zoo_random_player
        )


if __name__ == "__main__":
    hp = DictConfig({})
    hp.steps = 20
    hp.batch_size = 2
    hp.max_steps = 10
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    model = GenericLinearModel(18, [10], 9, flatten=True).float().to(device)

    train_dqn(TicTacToeEnvWrapper, model, hp, name="TicTacToe")
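# petting_zoo_random_player plays the opponent's moves; a plausible minimal
# version, assuming the standard PettingZoo classic-game observation dict with
# an "action_mask" entry (the repo's implementation may differ):
import numpy as np


def random_masked_player(observation):
    # Sample uniformly among the currently legal actions.
    legal = np.flatnonzero(observation["action_mask"])
    return int(np.random.choice(legal))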
import torch.nn as nn
from omegaconf import DictConfig

from settings import device

# (imports of the repo-local train_dqn and MarioEnvWrapper are elided in this excerpt)

model = (
    nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
        nn.ELU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
        nn.ELU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
        nn.ELU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
        nn.Flatten(),
        nn.Linear(288, 100),
        nn.ELU(),
        nn.Linear(100, 12),
    )
    .float()
    .to(device)
)

if __name__ == "__main__":
    hp = DictConfig({})
    hp.steps = 2000
    hp.batch_size = 2
    hp.env_record_freq = 500
    hp.env_record_duration = 100
    hp.max_steps = 1000
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    train_dqn(MarioEnvWrapper, model, hp, name="Mario")
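# Why nn.Linear(288, 100): with kernel 3, stride 2, padding 1, each conv maps a
# spatial size n to (n + 2 - 3) // 2 + 1. Assuming a 42x42 input (the size the
# Breakout script above uses), the four convs give 42 -> 21 -> 11 -> 6 -> 3,
# and 32 channels * 3 * 3 = 288 flattened features:
def conv_out(n, kernel=3, stride=2, padding=1):
    return (n + 2 * padding - kernel) // stride + 1


size = 42
for _ in range(4):
    size = conv_out(size)
assert 32 * size * size == 288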