Ejemplo n.º 1
0
def nature_ddqn(env, frames=4):
    return nn.Sequential(
        nn.Scale(1 / 255), nn.Conv2d(frames, 32, 8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
        nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(), nn.Flatten(),
        nn.Dueling(
            nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear0(512, 1)),
            nn.Sequential(nn.Linear(3136, 512), nn.ReLU(),
                          nn.Linear0(512, env.action_space.n)),
        ))
 def setUp(self):
     torch.manual_seed(2)
     self.model = nn.Sequential(nn.Linear0(STATE_DIM, ACTION_DIM))
     self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01)
     self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1]))
     self.policy = DeterministicPolicy(self.model, self.optimizer,
                                       self.space, 0.5)
Ejemplo n.º 3
0
def fc_relu_dist_q(env, hidden=64, atoms=51):
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(env.state_space.shape[0], hidden),
        nn.ReLU(),
        nn.Linear0(hidden, env.action_space.n * atoms),
    )
Ejemplo n.º 4
0
def fc_v(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Linear(env.state_space.shape[0] + 1, hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, 1),
    )
Ejemplo n.º 5
0
def fc_soft_policy(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Linear(env.state_space.shape[0] + 1, hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, env.action_space.shape[0] * 2),
    )
Ejemplo n.º 6
0
def fc_policy(env):
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(env.state_space.shape[0], 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear0(64, env.action_space.shape[0] * 2),
    )
Ejemplo n.º 7
0
def fc_v(env, hidden1=516, hidden2=516):
    print("Custom V loaded")
    return nn.Sequential(
        nn.Linear(env.state_space.shape[0] + 1, hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, 1),
    )
Ejemplo n.º 8
0
def nature_c51(env, frames=4, atoms=51):
    return nn.Sequential(
        nn.Scale(1/255),
        nn.Conv2d(frames, 32, 8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, 4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, 3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3136, 512),
        nn.ReLU(),
        nn.Linear0(512, env.action_space.n * atoms)
    )
Ejemplo n.º 9
0
 def __init__(self, env, frames=4):
     super().__init__()
     n_agents = len(env.agents)
     n_actions = env.action_spaces['first_0'].n
     self.conv = nn.Sequential(
         nn.Scale(1/255),
         nn.Conv2d(frames, 32, 8, stride=4),
         nn.ReLU(),
         nn.Conv2d(32, 64, 4, stride=2),
         nn.ReLU(),
         nn.Conv2d(64, 64, 3, stride=1),
         nn.ReLU(),
         nn.Flatten()
     )
     self.hidden = nn.Linear(3136 + n_agents, 512)
     self.output = nn.Linear0(512 + n_agents, n_actions)
Ejemplo n.º 10
0
def fc_policy_head(env, hidden=64):
    return nn.Linear0(hidden, env.action_space.n)
Ejemplo n.º 11
0
def fc_value_head(hidden=64):
    return nn.Linear0(hidden, 1)
Ejemplo n.º 12
0
def nature_policy_head(env):
    return nn.Linear0(512, env.action_space.n)
Ejemplo n.º 13
0
def simple_nature_policy_head(env):
    return nn.Linear0(16, env.action_space.n)
Ejemplo n.º 14
0
def value_head():
    return nn.Linear0(512, 1)
Ejemplo n.º 15
0
def reward_head(env):
    return nn.Linear0(512, env.action_space.n)
Ejemplo n.º 16
0
def fc_value(env):
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(env.state_space.shape[0] + env.action_space.shape[0], 64),
        nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear0(64, 1))
Ejemplo n.º 17
0
 def test_linear0(self):
     model = nn.Linear0(3, 3)
     result = model(torch.tensor([[3.0, -2.0, 10]]))
     tt.assert_equal(result, torch.tensor([[0.0, 0.0, 0.0]]))