Beispiel #1
0
class ActionNet(nn.Module):
    """Action head that encodes raw observations with ``Env`` and scores actions.

    Pipeline: ``Env`` -> optional concatenation of ``val`` -> hidden linear
    layer + ReLU -> ``ConstDiscrete`` action head.
    """

    def __init__(self,
                 config,
                 entDim=19,
                 outDim=2,
                 n_quant=1,
                 add_features=0,
                 noise=False):
        """Build the encoder, hidden layer and action head.

        Args:
            config: configuration object; ``HIDDEN`` is read here and
                ``VAL_FEATURE`` is read in ``forward``.
            entDim: stored as ``self.entDim``; not otherwise used by this class.
            outDim: output count passed to ``ConstDiscrete``.
            n_quant: quantile count passed to ``ConstDiscrete``.
            add_features: extra input width accepted by the hidden layer
                (presumably the width of ``val`` when value features are on).
            noise: if True, the hidden layer is a ``NoisyLinear``; also
                forwarded to ``Env``.
        """
        super().__init__()
        self.config, self.h = config, config.HIDDEN
        self.entDim = entDim
        self.envNet = Env(config, noise=noise)
        in_features = self.h + add_features
        self.fc = (NoisyLinear(in_features, self.h) if noise
                   else nn.Linear(in_features, self.h))
        self.actionNet = ConstDiscrete(config, self.h, outDim, n_quant)

    def forward(self, flat, ents, eps=0, punish=None, val=None, device='cpu'):
        """Compute action outputs for a batch of observations.

        Args:
            flat, ents: observation tensors forwarded to ``Env``.
            eps, punish: forwarded unchanged to ``ConstDiscrete``.
            val: optional tensor concatenated along dim 1 to the ``Env``
                output when ``config.VAL_FEATURE`` is truthy.
            device: device all inputs and intermediates are moved to.

        Returns:
            The ``(outs, idx)`` pair produced by ``ConstDiscrete``.
        """
        stim = self.envNet(flat.to(device), ents.to(device), device=device)
        stim = stim.to(device)
        if self.config.VAL_FEATURE and val is not None:
            stim = torch.cat([stim, val.to(device)], dim=1)
        # Dispatch on the layer actually built in __init__, not on
        # config.NOISE: only NoisyLinear accepts the ``device`` keyword, so a
        # mismatch between the ``noise`` argument and config.NOISE would
        # otherwise raise a TypeError inside nn.Linear.
        if isinstance(self.fc, NoisyLinear):
            x = F.relu(self.fc(stim, device=device))
        else:
            x = F.relu(self.fc(stim))
        outs, idx = self.actionNet(x.to(device), eps, punish)
        return outs, idx

    def reset_noise(self):
        """Resample the noise of the encoder and the hidden layer.

        Only meaningful for a network built with ``noise=True``. The
        non-noisy layers are plain ``nn.Linear``, which has no
        ``reset_noise`` method.
        """
        self.envNet.reset_noise()
        self.fc.reset_noise()
Beispiel #2
0
    def __init__(self, config, noise=False):
        """Create the convolutional encoder and its projection layer.

        Args:
            config: configuration object providing ``HIDDEN`` (projection
                output width) and ``ENT_DIM`` (projection input width).
            noise: if True, the projection is a ``NoisyLinear`` rather than
                an ``nn.Linear``.
        """
        super().__init__()
        self.config = config
        hidden = config.HIDDEN

        # 3 input channels -> 6 feature maps with a 3x3 kernel.
        self.conv = nn.Conv2d(3, 6, (3, 3))
        in_dim = self.config.ENT_DIM
        if noise:
            self.fc = NoisyLinear(in_dim, hidden)
        else:
            self.fc = nn.Linear(in_dim, hidden)
Beispiel #3
0
 def __init__(self, config, n_quant=1, noise=False):
     """Build a value head: a square hidden layer, then a projection.

     Args:
         config: configuration object; ``HIDDEN`` sets the hidden width.
         n_quant: number of outputs of the final ``nn.Linear``.
         noise: if True, the hidden layer is a ``NoisyLinear``.
     """
     super().__init__()
     self.config = config
     self.h = width = config.HIDDEN
     hidden_cls = NoisyLinear if noise else nn.Linear
     self.fc = hidden_cls(width, width)
     self.valNet = nn.Linear(width, n_quant)
     self.n_quant = n_quant
Beispiel #4
0
    def __init__(self, config, noise=False):
        """Create the flat-feature, entity and fusion layers.

        Args:
            config: configuration object providing ``HIDDEN`` and ``ENT_DIM``.
            noise: if True, the fusion layer is a ``NoisyLinear``.
        """
        super().__init__()
        self.config = config
        hidden, ent_dim = config.HIDDEN, config.ENT_DIM

        # Fusion layer input is 2 * HIDDEN, presumably the concatenation of
        # the flat and entity embeddings (each HIDDEN wide) — TODO confirm.
        fused_in = 2 * hidden
        if noise:
            self.fc = NoisyLinear(fused_in, hidden)
        else:
            self.fc = nn.Linear(fused_in, hidden)
        self.flat = nn.Linear(ent_dim, hidden)
        self.ents = Ent(ent_dim, hidden)
Beispiel #5
0
 def __init__(self,
              config,
              args,
              outDim=8,
              n_quant=1,
              add_features=0,
              noise=False):
     """Build the hidden layer and the discrete action head.

     Args:
         config: configuration object; ``HIDDEN`` sets the hidden width.
         args: stored as ``self.args``; not used by the constructor.
         outDim: output count passed to ``ConstDiscrete``.
         n_quant: quantile count passed to ``ConstDiscrete``.
         add_features: extra input width accepted by the hidden layer.
         noise: if True, the hidden layer is a ``NoisyLinear``.
     """
     super().__init__()
     self.config = config
     self.args = args
     self.h = config.HIDDEN
     fc_in = self.h + add_features
     if noise:
         self.fc = NoisyLinear(fc_in, self.h)
     else:
         self.fc = nn.Linear(fc_in, self.h)
     self.actionNet = ConstDiscrete(config, self.h, outDim, n_quant)
Beispiel #6
0
 def __init__(self,
              config,
              entDim=19,
              outDim=2,
              n_quant=1,
              add_features=0,
              noise=False):
     """Build the observation encoder, hidden layer and action head.

     Args:
         config: configuration object; ``HIDDEN`` sets the hidden width.
         entDim: stored as ``self.entDim``; not used by the constructor.
         outDim: output count passed to ``ConstDiscrete``.
         n_quant: quantile count passed to ``ConstDiscrete``.
         add_features: extra input width accepted by the hidden layer.
         noise: if True, the hidden layer is a ``NoisyLinear``; also
             forwarded to ``Env``.
     """
     super().__init__()
     self.config = config
     self.h = hidden = config.HIDDEN
     self.entDim = entDim
     self.envNet = Env(config, noise=noise)
     fc_cls = NoisyLinear if noise else nn.Linear
     self.fc = fc_cls(hidden + add_features, hidden)
     self.actionNet = ConstDiscrete(config, hidden, outDim, n_quant)
Beispiel #7
0
class Env(nn.Module):
    """Convolutional encoder: 3x3 conv + ReLU, flatten, projection + ReLU."""

    def __init__(self, config, noise=False):
        """Create the conv layer and the projection.

        Args:
            config: configuration object providing ``HIDDEN`` (projection
                output width) and ``ENT_DIM`` (projection input width, which
                must equal the flattened conv output size).
            noise: if True, the projection is a ``NoisyLinear``.
        """
        super().__init__()
        self.config = config
        hidden = config.HIDDEN

        self.conv = nn.Conv2d(3, 6, (3, 3))
        layer_cls = NoisyLinear if noise else nn.Linear
        self.fc = layer_cls(self.config.ENT_DIM, hidden)

    def forward(self, env):
        """Encode ``env``, a batched 3-channel input (batch on dim 0)."""
        batch = env.shape[0]
        flat_maps = self.conv(env).view(batch, -1)
        projected = self.fc(F.relu(flat_maps))
        return F.relu(projected)

    def reset_noise(self):
        """Resample the projection's noise (requires ``noise=True``)."""
        self.fc.reset_noise()
Beispiel #8
0
class ValNet(nn.Module):
    """Value head: hidden layer + ReLU, then a linear map to ``n_quant`` outputs."""

    def __init__(self, config, n_quant=1, noise=False):
        """Build the value head.

        Args:
            config: configuration object; ``HIDDEN`` sets the hidden width.
            n_quant: number of values produced per sample.
            noise: if True, the hidden layer is a ``NoisyLinear``.
        """
        super().__init__()
        self.config = config
        self.h = width = config.HIDDEN
        hidden_cls = NoisyLinear if noise else nn.Linear
        self.fc = hidden_cls(width, width)
        self.valNet = nn.Linear(width, n_quant)
        self.n_quant = n_quant

    def forward(self, s):
        """Map ``s`` to values reshaped as ``(-1, 1, n_quant)``."""
        hidden = F.relu(self.fc(s))
        values = self.valNet(hidden)
        return values.view(-1, 1, self.n_quant)

    def reset_noise(self):
        """Resample the hidden layer's noise (requires ``noise=True``)."""
        self.fc.reset_noise()
Beispiel #9
0
class Env(nn.Module):
    """Observation encoder fusing flat features with entity features.

    ``flat`` goes through ``nn.Linear`` and ``ents`` through ``Ent``. The two
    results are concatenated along dim 1 and projected back to ``HIDDEN``
    with a ReLU.
    """

    def __init__(self, config, noise=False):
        """Create the flat, entity and fusion layers.

        Args:
            config: configuration object providing ``HIDDEN`` and ``ENT_DIM``.
            noise: if True, the fusion layer is a ``NoisyLinear``.
        """
        super().__init__()
        self.config = config
        h = config.HIDDEN
        entDim = config.ENT_DIM

        # Fusion input is 2 * h: presumably the flat and entity embeddings
        # are each h wide — TODO confirm against Ent's output width.
        self.fc = NoisyLinear(2 * h, h) if noise else nn.Linear(2 * h, h)
        self.flat = nn.Linear(entDim, h)
        self.ents = Ent(entDim, h)

    def forward(self, flat, ents, device='cpu'):
        """Encode one batch of observations.

        Args:
            flat: flat feature tensor fed to ``self.flat``.
            ents: entity tensor fed to ``Ent``.
            device: device the inputs and intermediates are moved to.

        Returns:
            The ReLU-activated output of the fusion layer.
        """
        flat = self.flat(flat.to(device))
        ents = self.ents(ents.to(device), device=device)
        x = torch.cat((flat, ents), dim=1).to(device)
        # Dispatch on the layer actually built rather than config.NOISE: only
        # NoisyLinear accepts ``device``, so a mismatch between the ``noise``
        # argument and config.NOISE would otherwise raise in nn.Linear.
        if isinstance(self.fc, NoisyLinear):
            x = F.relu(self.fc(x, device=device))
        else:
            x = F.relu(self.fc(x))
        return x

    def reset_noise(self):
        """Resample the fusion layer's noise (requires ``noise=True``)."""
        self.fc.reset_noise()
Beispiel #10
0
class ActionNet(nn.Module):
    """Action head operating on a precomputed state embedding ``s``."""

    def __init__(self,
                 config,
                 args,
                 outDim=8,
                 n_quant=1,
                 add_features=0,
                 noise=False):
        """Build the hidden layer and the discrete action head.

        Args:
            config: configuration object; ``HIDDEN`` is read here and
                ``VAL_FEATURE`` in ``forward``.
            args: stored as ``self.args``; not otherwise used by this class.
            outDim: output count passed to ``ConstDiscrete``.
            n_quant: quantile count passed to ``ConstDiscrete``.
            add_features: extra input width accepted by the hidden layer.
            noise: if True, the hidden layer is a ``NoisyLinear``.
        """
        super().__init__()
        self.config = config
        self.args = args
        self.h = hidden = config.HIDDEN
        fc_cls = NoisyLinear if noise else nn.Linear
        self.fc = fc_cls(hidden + add_features, hidden)
        self.actionNet = ConstDiscrete(config, hidden, outDim, n_quant)

    def forward(self, s, eps=0, punish=None, val=None):
        """Score actions for state ``s``.

        ``val`` is concatenated to ``s`` along dim 1 when
        ``config.VAL_FEATURE`` is truthy and ``val`` is given. ``eps`` and
        ``punish`` are forwarded unchanged to the action head.

        Returns:
            The ``(outs, idx)`` pair produced by the action head.
        """
        features = s
        if self.config.VAL_FEATURE and val is not None:
            features = torch.cat([features, val], dim=1)
        outs, idx = self.actionNet(F.relu(self.fc(features)), eps, punish)
        return outs, idx

    def reset_noise(self):
        """Resample the hidden layer's noise (requires ``noise=True``)."""
        self.fc.reset_noise()