def __init__(self, input_shape, num_outputs, noisy=False, sigma_init=0.5, body=SimpleBody, atoms=51):
    """Build a categorical dueling head on top of a pluggable feature body.

    Args:
        input_shape: observation shape, passed unchanged to ``body``.
        num_outputs: number of actions (stored as ``self.num_actions``); the
            advantage stream emits ``atoms`` values for each of them.
        noisy: when True every head layer is a ``NoisyLinear`` built with
            ``sigma_init``; otherwise a plain ``nn.Linear``.
        sigma_init: noise scale forwarded to ``NoisyLinear`` and to ``body``.
        body: feature-extractor class, instantiated as
            ``body(input_shape, num_outputs, noisy, sigma_init)``; it must
            provide ``feature_size()``.
        atoms: number of values per output distribution.
    """
    super(CategoricalDuelingDQN, self).__init__()

    def make_linear(in_features, out_features):
        # Single place that decides between a noisy and a plain layer.
        if self.noisy:
            return NoisyLinear(in_features, out_features, sigma_init)
        return nn.Linear(in_features, out_features)

    self.input_shape = input_shape
    self.num_actions = num_outputs
    self.noisy = noisy
    self.atoms = atoms

    self.body = body(input_shape, num_outputs, noisy, sigma_init)

    # Advantage stream: num_actions * atoms outputs.
    self.adv1 = make_linear(self.body.feature_size(), 512)
    self.adv2 = make_linear(512, self.num_actions * self.atoms)

    # Value stream: a single set of `atoms` outputs.
    self.val1 = make_linear(self.body.feature_size(), 512)
    self.val2 = make_linear(512, self.atoms)
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, body=SimpleBody):
    """Build a Q-network: feature body -> 512-wide hidden layer -> one output per action.

    Args:
        input_shape: observation shape, passed unchanged to ``body``.
        num_actions: width of the final layer.
        noisy: use ``NoisyLinear`` (with ``sigma_init``) for both head layers
            instead of ``nn.Linear``.
        sigma_init: noise scale forwarded to ``NoisyLinear`` and to ``body``.
        body: feature-extractor class, instantiated as
            ``body(input_shape, num_actions, noisy, sigma_init)``; it must
            provide ``feature_size()``.
    """
    super(DQN, self).__init__()
    self.input_shape = input_shape
    self.num_actions = num_actions
    self.noisy = noisy

    self.body = body(input_shape, num_actions, noisy, sigma_init)

    # Width of the body's output, used as the input width of fc1.
    features = self.body.feature_size()
    if self.noisy:
        self.fc1 = NoisyLinear(features, 512, sigma_init)
        self.fc2 = NoisyLinear(512, self.num_actions, sigma_init)
    else:
        self.fc1 = nn.Linear(features, 512)
        self.fc2 = nn.Linear(512, self.num_actions)
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5, gru_size=512, bidirectional=False, body=SimpleBody):
    """Build a recurrent Q-network: feature body -> single-layer GRU -> linear action head.

    Args:
        input_shape: observation shape, passed unchanged to ``body``.
        num_actions: width of the output layer.
        noisy: use ``NoisyLinear`` (with ``sigma_init``) for the output layer
            instead of ``nn.Linear``; also forwarded to ``body``.
        sigma_init: noise scale forwarded to ``NoisyLinear`` and to ``body``.
        gru_size: hidden size of the GRU.
        bidirectional: build a bidirectional GRU; ``self.num_directions`` is 2
            in that case and 1 otherwise.
        body: feature-extractor class, instantiated as
            ``body(input_shape, num_actions, noisy=..., sigma_init=...)``; it
            must provide ``feature_size()``.
    """
    super(DRQN, self).__init__()
    self.input_shape = input_shape
    self.num_actions = num_actions
    self.noisy = noisy
    self.gru_size = gru_size
    self.bidirectional = bidirectional
    self.num_directions = 2 if self.bidirectional else 1

    self.body = body(input_shape, num_actions, noisy=self.noisy, sigma_init=sigma_init)
    self.gru = nn.GRU(self.body.feature_size(), self.gru_size, num_layers=1, batch_first=True, bidirectional=bidirectional)

    # nn.GRU concatenates both directions, so each output step has
    # gru_size * num_directions features. The head was previously sized from
    # gru_size alone, which only matched the unidirectional case. The width is
    # unchanged when bidirectional=False.
    # NOTE(review): assumes forward() feeds the GRU output straight into fc2 — confirm.
    head_in = self.gru_size * self.num_directions
    self.fc2 = nn.Linear(head_in, self.num_actions) if not self.noisy else NoisyLinear(head_in, self.num_actions, sigma_init)
def __init__(self, input_shape, num_actions, noisy=False, sigma_init=0.5):
    """Create a single 128-wide fully connected feature layer.

    Args:
        input_shape: observation shape; only ``input_shape[0]`` is used, as
            the input width of ``fc1``.
        num_actions: stored on the instance; not used to size any layer here.
        noisy: build ``fc1`` as a ``NoisyLinear`` (with ``sigma_init``)
            instead of ``nn.Linear``.
        sigma_init: noise scale, only used when ``noisy`` is True.
    """
    super(SimpleBody, self).__init__()
    self.input_shape = input_shape
    self.num_actions = num_actions
    self.noisy = noisy

    in_features = input_shape[0]
    if noisy:
        self.fc1 = NoisyLinear(in_features, 128, sigma_init)
    else:
        self.fc1 = nn.Linear(in_features, 128)
def __init__(self, input_shape, num_actions, sigma_init=0.5, atoms=51):
    """Build a three-layer conv trunk followed by noisy advantage and value heads.

    Args:
        input_shape: input shape; ``input_shape[0]`` is the channel count of
            ``conv1`` and the full shape is what ``feature_size()`` probes.
        num_actions: number of actions; the advantage head emits
            ``num_actions * atoms`` values.
        sigma_init: noise scale for every ``NoisyLinear`` head layer.
        atoms: number of values per output distribution.
    """
    super(CategoricalDuelingDQN, self).__init__()
    self.input_shape = input_shape
    self.num_actions = num_actions
    self.atoms = atoms

    in_channels = self.input_shape[0]
    self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
    self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

    # Advantage stream; feature_size() needs conv1..conv3 to exist already.
    self.adv1 = NoisyLinear(self.feature_size(), 512, sigma_init)
    self.adv2 = NoisyLinear(512, num_actions * atoms, sigma_init)

    # Value stream: one set of `atoms` outputs shared by all actions.
    self.val1 = NoisyLinear(self.feature_size(), 512, sigma_init)
    self.val2 = NoisyLinear(512, atoms, sigma_init)
class CategoricalDuelingDQN(nn.Module):
    """Conv trunk + dueling heads producing a softmax over ``atoms`` for each action.

    All fully connected head layers are ``NoisyLinear``; ``sample_noise()``
    forwards to each head layer's own ``sample_noise()``.
    """

    def __init__(self, input_shape, num_actions, sigma_init=0.5, atoms=51):
        """Create the conv trunk and the advantage/value heads.

        Args:
            input_shape: input shape; ``input_shape[0]`` is the channel count
                of ``conv1`` and the full shape is probed by ``feature_size()``.
            num_actions: number of actions in the advantage head.
            sigma_init: noise scale for every ``NoisyLinear`` head layer.
            atoms: number of values per output distribution.
        """
        super(CategoricalDuelingDQN, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        self.atoms = atoms

        self.conv1 = nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Head input width depends on the conv layers above, so build them first.
        self.adv1 = NoisyLinear(self.feature_size(), 512, sigma_init)
        self.adv2 = NoisyLinear(512, self.num_actions * self.atoms, sigma_init)
        self.val1 = NoisyLinear(self.feature_size(), 512, sigma_init)
        self.val2 = NoisyLinear(512, self.atoms, sigma_init)

    def forward(self, x):
        """Return per-action probabilities of shape ``(batch, num_actions, atoms)``.

        The softmax is taken over the last (atoms) axis.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        flat = x.view(x.size(0), -1)

        adv_hidden = F.relu(self.adv1(flat))
        adv = self.adv2(adv_hidden).view(-1, self.num_actions, self.atoms)
        val_hidden = F.relu(self.val1(flat))
        val = self.val2(val_hidden).view(-1, 1, self.atoms)

        # Dueling combination: value plus advantage minus its mean over actions.
        adv_mean = adv.mean(dim=1, keepdim=True)
        logits = val + adv - adv_mean
        return F.softmax(logits, dim=2)

    def feature_size(self):
        """Return the flattened output width of conv3 for one zero input of ``input_shape``."""
        probe = torch.zeros(1, *self.input_shape)
        trunk_out = self.conv3(self.conv2(self.conv1(probe)))
        return trunk_out.view(1, -1).size(1)

    def sample_noise(self):
        """Call ``sample_noise()`` on each head layer, advantage stream first."""
        for layer in (self.adv1, self.adv2, self.val1, self.val2):
            layer.sample_noise()