class ActionNet(nn.Module):
    """Action head that encodes raw observations with ``Env`` before acting.

    The flat and entity inputs are encoded by ``self.envNet``; the encoding
    (optionally concatenated with ``val``) goes through one hidden layer with
    ReLU and is then handed to a ``ConstDiscrete`` head.

    NOTE(review): a second class named ``ActionNet`` appears later in this
    file; if both live in the same module the later one shadows this one.
    Confirm which definition is intended to be live.

    Args:
        config: Object exposing ``HIDDEN`` and ``VAL_FEATURE``; also passed on
            to ``Env`` and ``ConstDiscrete``.
        entDim: Stored as ``self.entDim``; not otherwise used by this class.
        outDim: Forwarded to ``ConstDiscrete``.
        n_quant: Forwarded to ``ConstDiscrete``.
        add_features: Extra input width of the hidden layer, on top of
            ``config.HIDDEN`` (room for the concatenated ``val`` features).
        noise: Build ``Env`` and the hidden layer with ``NoisyLinear`` instead
            of ``nn.Linear``.
    """

    def __init__(self, config, entDim=19, outDim=2, n_quant=1,
                 add_features=0, noise=False):
        super().__init__()
        self.config, self.h = config, config.HIDDEN
        self.entDim = entDim
        self.envNet = Env(config, noise=noise)
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(self.h + add_features, self.h)
        self.actionNet = ConstDiscrete(config, self.h, outDim, n_quant)

    def forward(self, flat, ents, eps=0, punish=None, val=None, device='cpu'):
        """Encode the observation and run the action head.

        Args:
            flat: Tensor passed to ``Env`` as its ``flat`` input.
            ents: Tensor passed to ``Env`` as its ``ents`` input.
            eps: Passed through to the ``ConstDiscrete`` head.
            punish: Passed through to the ``ConstDiscrete`` head.
            val: Optional tensor concatenated to the encoding along dim 1
                when ``config.VAL_FEATURE`` is truthy.
            device: Device every tensor is moved to before use.

        Returns:
            The ``(outs, idx)`` pair produced by the ``ConstDiscrete`` head.
        """
        stim = self.envNet(flat.to(device), ents.to(device), device=device)
        if self.config.VAL_FEATURE and val is not None:
            stim = torch.cat([stim.to(device), val.to(device)], dim=1)
        stim = stim.to(device)

        # Dispatch on the layer that was actually built rather than on
        # config.NOISE: nn.Linear.forward does not accept ``device`` and would
        # raise TypeError if the flag and the constructor argument disagree.
        if isinstance(self.fc, NoisyLinear):
            hidden = self.fc(stim, device=device)
        else:
            hidden = self.fc(stim)
        x = F.relu(hidden)

        outs, idx = self.actionNet(x.to(device), eps, punish)
        return outs, idx

    def reset_noise(self):
        """Resample noise in the encoder and hidden layer, if they are noisy.

        Does nothing when the network was built with plain ``nn.Linear``
        layers, which have no ``reset_noise`` method.
        """
        # envNet and fc are built from the same ``noise`` flag, so one check
        # covers both.
        if isinstance(self.fc, NoisyLinear):
            self.envNet.reset_noise()
            self.fc.reset_noise()
class ValNet(nn.Module):
    """Value head: one hidden layer with ReLU, then a linear output layer.

    Args:
        config: Object exposing ``HIDDEN``, the width of both the input and
            the hidden layer.
        n_quant: Number of values produced per input row.
        noise: Build the hidden layer as ``NoisyLinear`` instead of
            ``nn.Linear``.
    """

    def __init__(self, config, n_quant=1, noise=False):
        super().__init__()
        self.config = config
        self.h = config.HIDDEN
        self.n_quant = n_quant
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(self.h, self.h)
        self.valNet = nn.Linear(self.h, n_quant)

    def forward(self, s):
        """Return the value estimates reshaped to ``(-1, 1, n_quant)``.

        Args:
            s: Input tensor whose last dimension is ``config.HIDDEN`` wide.
        """
        hidden = F.relu(self.fc(s))
        return self.valNet(hidden).view(-1, 1, self.n_quant)

    def reset_noise(self):
        """Resample the hidden layer's noise, if it is a noisy layer.

        Does nothing when the layer is a plain ``nn.Linear``, which has no
        ``reset_noise`` method; calling it unguarded raised AttributeError.
        """
        if hasattr(self.fc, "reset_noise"):
            self.fc.reset_noise()
class Env(nn.Module):
    """Convolutional encoder: 3x3 conv, flatten, one hidden layer with ReLU.

    The convolution maps 3 input channels to 6 output channels. Its flattened
    output is fed to a layer whose input width is ``config.ENT_DIM``, so the
    input's spatial size must make the flattened conv output exactly that wide.

    NOTE(review): a second class named ``Env`` appears later in this file; if
    both live in the same module the later one shadows this one. Confirm
    which definition is intended to be live.

    Args:
        config: Object exposing ``HIDDEN`` (output width) and ``ENT_DIM``
            (flattened conv output width).
        noise: Build the hidden layer as ``NoisyLinear`` instead of
            ``nn.Linear``.
    """

    def __init__(self, config, noise=False):
        super().__init__()
        self.config = config
        hidden = config.HIDDEN
        self.conv = nn.Conv2d(3, 6, (3, 3))
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(self.config.ENT_DIM, hidden)

    def forward(self, env):
        """Encode a batch of 3-channel maps into ``config.HIDDEN`` features.

        Args:
            env: Tensor whose first dimension is the batch and that
                ``nn.Conv2d(3, 6, (3, 3))`` accepts.

        Returns:
            Tensor of shape ``(env.shape[0], config.HIDDEN)``.
        """
        batch = env.shape[0]
        features = F.relu(self.conv(env).view(batch, -1))
        return F.relu(self.fc(features))

    def reset_noise(self):
        """Resample the hidden layer's noise, if it is a noisy layer.

        Does nothing when the layer is a plain ``nn.Linear``, which has no
        ``reset_noise`` method; calling it unguarded raised AttributeError.
        """
        if hasattr(self.fc, "reset_noise"):
            self.fc.reset_noise()
class Env(nn.Module):
    """Observation encoder combining a flat input with an entity input.

    The flat input goes through a linear layer and the entity input through
    an ``Ent`` module; both produce ``config.HIDDEN`` features, which are
    concatenated along dim 1 and reduced back to ``config.HIDDEN`` by one
    hidden layer with ReLU.

    Args:
        config: Object exposing ``HIDDEN`` and ``ENT_DIM``.
        noise: Build the final hidden layer as ``NoisyLinear`` instead of
            ``nn.Linear``.
    """

    def __init__(self, config, noise=False):
        super().__init__()
        self.config = config
        hidden = config.HIDDEN
        ent_dim = config.ENT_DIM
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(2 * hidden, hidden)
        self.flat = nn.Linear(ent_dim, hidden)
        self.ents = Ent(ent_dim, hidden)

    def forward(self, flat, ents, device='cpu'):
        """Encode one batch of observations.

        Args:
            flat: Tensor fed to the ``nn.Linear(ENT_DIM, HIDDEN)`` layer.
            ents: Tensor fed to the ``Ent`` module.
            device: Device every tensor is moved to before use.

        Returns:
            The ReLU-activated output of the final hidden layer.
        """
        flat_feats = self.flat(flat.to(device))
        ent_feats = self.ents(ents.to(device), device=device)
        joined = torch.cat((flat_feats, ent_feats), dim=1).to(device)

        # Dispatch on the layer that was actually built rather than on
        # config.NOISE: nn.Linear.forward does not accept ``device`` and would
        # raise TypeError if the flag and the constructor argument disagree.
        if isinstance(self.fc, NoisyLinear):
            hidden = self.fc(joined, device=device)
        else:
            hidden = self.fc(joined)
        return F.relu(hidden)

    def reset_noise(self):
        """Resample the final layer's noise, if it is a noisy layer.

        Does nothing when the layer is a plain ``nn.Linear``, which has no
        ``reset_noise`` method; calling it unguarded raised AttributeError.
        """
        if hasattr(self.fc, "reset_noise"):
            self.fc.reset_noise()
class ActionNet(nn.Module):
    """Action head operating on an already-encoded state.

    The state (optionally concatenated with ``val``) goes through one hidden
    layer with ReLU and is then handed to a ``ConstDiscrete`` head.

    Args:
        config: Object exposing ``HIDDEN`` and ``VAL_FEATURE``; also passed on
            to ``ConstDiscrete``.
        args: Stored as ``self.args``; not otherwise used by this class.
        outDim: Forwarded to ``ConstDiscrete``.
        n_quant: Forwarded to ``ConstDiscrete``.
        add_features: Extra input width of the hidden layer, on top of
            ``config.HIDDEN`` (room for the concatenated ``val`` features).
        noise: Build the hidden layer as ``NoisyLinear`` instead of
            ``nn.Linear``.
    """

    def __init__(self, config, args, outDim=8, n_quant=1, add_features=0,
                 noise=False):
        super().__init__()
        self.config, self.args, self.h = config, args, config.HIDDEN
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(self.h + add_features, self.h)
        self.actionNet = ConstDiscrete(config, self.h, outDim, n_quant)

    def forward(self, s, eps=0, punish=None, val=None):
        """Run the action head on an encoded state.

        Args:
            s: Encoded state tensor.
            eps: Passed through to the ``ConstDiscrete`` head.
            punish: Passed through to the ``ConstDiscrete`` head.
            val: Optional tensor concatenated to ``s`` along dim 1 when
                ``config.VAL_FEATURE`` is truthy.

        Returns:
            The ``(outs, idx)`` pair produced by the ``ConstDiscrete`` head.
        """
        if self.config.VAL_FEATURE and val is not None:
            s = torch.cat([s, val], dim=1)
        hidden = F.relu(self.fc(s))
        outs, idx = self.actionNet(hidden, eps, punish)
        return outs, idx

    def reset_noise(self):
        """Resample the hidden layer's noise, if it is a noisy layer.

        Does nothing when the layer is a plain ``nn.Linear``, which has no
        ``reset_noise`` method; calling it unguarded raised AttributeError.
        """
        if hasattr(self.fc, "reset_noise"):
            self.fc.reset_noise()