def __init__(self, dim_d, **kw):
     """Assemble the residual critic as an ordered ``ModuleList``.

     Layer order: four ``ResBlock``s (the first two with ``resample='down'``,
     the last two with ``resample=None``, all with ``batnorm=False``), global
     average pooling over the last two axes, a ``Linear(dim_d, 1)`` scoring
     head, and a squeeze of axis 1.

     Args:
         dim_d: channel width of every ResBlock output and of the linear
             head's input.
         **kw: forwarded unchanged to the parent class constructor.

     NOTE(review): the first ResBlock takes 3 input channels, so inputs are
     presumably (batch, 3, H, W) images. The pooling lambda reduces axes 3 and
     2, which implies a 4-D input. Confirm both against callers.
     """
     super(Discriminator, self).__init__(**kw)
     # One entry per ResBlock: (input channels, resample mode), built in order.
     resample_plan = ('down', 'down', None, None)
     in_channels = [3] + [dim_d] * (len(resample_plan) - 1)
     residual_stack = [
         ResBlock(c_in, dim_d, 3, resample=mode, batnorm=False)
         for c_in, mode in zip(in_channels, resample_plan)
     ]
     scoring_head = [
         # Global average pooling over the two spatial axes -> (batch, dim_d).
         q.Lambda(lambda feats: feats.mean(3).mean(2)),
         torch.nn.Linear(dim_d, 1),
         # Drop the singleton output axis so each sample yields one scalar.
         q.Lambda(lambda scores: scores.squeeze(1)),
     ]
     self.layers = torch.nn.ModuleList(residual_stack + scoring_head)
Ejemplo n.º 2
0
 def __init__(self, dim_d, use_bn=False, **kw):
     """Assemble the older residual critic as an ordered ``ModuleList``.

     Layer order: ``OptimizedResBlockDisc1(dim_d)``, then three
     ``ResidualBlock(dim_d, dim_d, 3)`` layers (``resample="down"``, ``None``,
     ``None``), then a ReLU, average pooling over the last two axes,
     ``Linear(dim_d, 1)``, and a squeeze of axis 1.

     Args:
         dim_d: channel width shared by all residual layers and the head.
         use_bn: forwarded as ``use_bn`` to each ``ResidualBlock``.
         **kw: forwarded unchanged to the parent class constructor.

     NOTE(review): the pooling lambda reduces axes 3 and 2, so the feature
     map entering it is presumably (batch, dim_d, H, W). Confirm with callers.
     """
     super(OldDiscriminator, self).__init__(**kw)
     stem = OptimizedResBlockDisc1(dim_d)
     trunk = [
         ResidualBlock(dim_d, dim_d, 3, resample=mode, use_bn=use_bn)
         for mode in ("down", None, None)
     ]
     head = [
         torch.nn.ReLU(),
         # Average over the two trailing (spatial) axes -> (batch, dim_d).
         q.Lambda(lambda feats: feats.mean(3).mean(2)),
         torch.nn.Linear(dim_d, 1),
         # (batch, 1) -> (batch,): one score per sample.
         q.Lambda(lambda scores: scores.squeeze(1)),
     ]
     self.layers = torch.nn.ModuleList([stem] + trunk + head)
Ejemplo n.º 3
0
 def __init__(self, z_dim, dim_g, use_bn=True, **kw):
     """Assemble the older residual generator as an ordered ``ModuleList``.

     Layer order: ``Linear(z_dim, 4*4*dim_g)`` reshaped to
     ``(batch, dim_g, 4, 4)``, three ``ResidualBlock``s with
     ``resample="up"``, then ``Normalize(dim_g)``, ReLU, a 3x3 ``Conv2d``
     down to 3 channels (padding 1), and ``Tanh`` (outputs in [-1, 1]).

     Args:
         z_dim: size of the latent vector fed to the first linear layer.
         dim_g: channel width of the 4x4 seed map and every residual layer.
         use_bn: forwarded as ``use_bn`` to each ``ResidualBlock``.
         **kw: forwarded unchanged to the parent class constructor.

     NOTE(review): ``Normalize`` is defined elsewhere. It is presumably a
     per-channel normalization layer; confirm.
     """
     super(OldGenerator, self).__init__(**kw)
     projection = [
         torch.nn.Linear(z_dim, 4 * 4 * dim_g),
         # Flat projection -> (batch, dim_g, 4, 4) seed feature map.
         q.Lambda(lambda flat: flat.view(flat.size(0), dim_g, 4, 4)),
     ]
     upsampling = [
         ResidualBlock(dim_g, dim_g, 3, resample="up", use_bn=use_bn)
         for _ in range(3)
     ]
     to_image = [
         Normalize(dim_g),
         torch.nn.ReLU(),
         torch.nn.Conv2d(dim_g, 3, kernel_size=3, padding=1),
         torch.nn.Tanh(),
     ]
     self.layers = torch.nn.ModuleList(projection + upsampling + to_image)
Ejemplo n.º 4
0
 def __init__(self, z_dim, dim_g, **kw):
     """Assemble the residual generator as an ordered ``ModuleList``.

     Layer order: reshape the latent to a 1x1 map, a 4x4
     ``ConvTranspose2d(z_dim, dim_g)`` with BatchNorm and ReLU, three
     ``ResBlock``s with ``resample='up'``, a 3x3 ``Conv2d`` to 3 channels
     (padding 1), and ``Tanh`` (outputs in [-1, 1]).

     Args:
         z_dim: number of latent channels fed to the transposed convolution.
         dim_g: channel width of the stem and every ResBlock.
         **kw: forwarded unchanged to the parent class constructor.

     NOTE(review): the latent is presumably (batch, z_dim). The reshape adds
     two trailing axes, giving (batch, z_dim, 1, 1). Confirm with callers.
     """
     super(Generator, self).__init__(**kw)
     # (batch, z_dim) -> (batch, z_dim, 1, 1) so the latent can enter the conv stack.
     to_map = q.Lambda(lambda z: z.unsqueeze(2).unsqueeze(3))
     stem = [
         # Stride 1, no padding, kernel 4: a 1x1 input becomes a 4x4 map.
         torch.nn.ConvTranspose2d(z_dim, dim_g, 4),
         torch.nn.BatchNorm2d(dim_g),
         torch.nn.ReLU(),
     ]
     upsampling = [ResBlock(dim_g, dim_g, 3, resample='up') for _ in range(3)]
     to_image = [
         torch.nn.Conv2d(dim_g, 3, 3, padding=1),
         torch.nn.Tanh(),
     ]
     self.layers = torch.nn.ModuleList([to_map] + stem + upsampling + to_image)