Beispiel #1
0
def warehouse():
    """Train a vanilla DQN agent on the gymcolab ``Warehouse`` environment.

    Builds the environment, a ``VanillaDqnModel`` sized from the environment's
    observation space, an Adam optimizer and a ``Dqn`` wrapper, then passes
    everything to ``learn``. Uses CUDA when it is available and falls back to
    the CPU otherwise.

    Raises:
        AssertionError: If the observation map is not square.
    """
    from gymcolab.envs.warehouse import Warehouse
    from models import VanillaDqnModel

    # Hyperparameters; the whole dict is passed on to ``learn`` below.
    args = dict(
        buffersize=100000,
        lr=0.001,
        target_update_ratio=100,
        gamma=0.99,
        episodes=10000,
        update_begin=75,
        init_eps=0.5,
        terminal_eps=0.0,
        batchsize=32,
    )
    # Do not hard-code "cuda": that crashes on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env = Warehouse()
    n_nodes, height, width = env.observation_space.shape
    # NOTE(review): one action is dropped here, whereas warehouse_graph uses
    # the full ``env.action_space.n`` -- confirm which one is intended.
    n_act = env.action_space.n - 1
    # The model takes a single map size, so both spatial dims must agree.
    assert height == width, "Environment map must be square"

    model = VanillaDqnModel(n_nodes, height, n_act)
    optim = torch.optim.Adam(model.parameters(), lr=args["lr"])
    dqn = Dqn(model, optim, args["buffersize"])
    dqn.to(device)
    learn(args, env, dqn)
Beispiel #2
0
def warehouse_graph():
    """Train a graph-based DQN agent on the gymcolab ``Warehouse`` environment.

    Builds the environment, a fixed adjacency tensor, a ``GraphDqnModel``
    subclass that always feeds that adjacency to the parent ``forward``, an
    Adam optimizer and a ``Dqn`` wrapper, then passes everything to ``learn``.
    Uses CUDA when it is available and falls back to the CPU otherwise.

    Raises:
        AssertionError: If the observation map is not square.
    """
    from gymcolab.envs.warehouse import Warehouse
    from models import GraphDqnModel

    # Hyperparameters; the whole dict is passed on to ``learn`` below.
    args = dict(
        buffersize=100000,
        lr=0.0001,
        target_update_ratio=100,
        gamma=0.99,
        episodes=10000,
        update_begin=75,
        init_eps=0.9,
        terminal_eps=0.1,
        batchsize=32,
    )
    # Do not hard-code "cuda": that crashes on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env = Warehouse()
    n_nodes, height, width = env.observation_space.shape
    # NOTE(review): the full action space is used here, whereas warehouse()
    # uses ``env.action_space.n - 1`` -- confirm which one is intended.
    n_act = env.action_space.n
    # The model takes a single map size, so both spatial dims must agree.
    assert height == width, "Environment map must be square"

    # Three 4x4 binary matrices, one per edge type (matches n_edges below).
    # NOTE(review): presumably the 4x4 size must equal n_nodes -- confirm
    # against the Warehouse observation space.
    adj = torch.tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 0, 0],
                         [0, 0, 0, 0]],
                        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1],
                         [0, 0, 0, 0]],
                        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
                         [0, 1, 0, 0]]]).float().to(device)
    n_edges = 3

    class GraphDqnModelAdj(GraphDqnModel):
        """``GraphDqnModel`` with the adjacency tensor bound at construction.

        Lets callers invoke ``forward(objmap)`` with a single argument.
        """

        def __init__(self, n_edges, n_nodes, mapsize, n_act, adj):
            super().__init__(n_edges, n_nodes, mapsize, n_act)
            # Stored as a plain attribute; it is already on ``device``.
            self.adj = adj

        def forward(self, objmap):
            """Run the parent forward pass with the stored adjacency."""
            return super().forward(self.adj, objmap)

    model = GraphDqnModelAdj(n_edges, n_nodes, height, n_act, adj)
    optim = torch.optim.Adam(model.parameters(), lr=args["lr"])
    dqn = Dqn(model, optim, args["buffersize"])
    dqn.to(device)
    learn(args, env, dqn)