def __init__(self):
    """Autoencoder with a relaxed-Bernoulli stochastic bottleneck.

    ``self.image`` composes encoder -> differentiable binary sample ->
    decoder; ``rsample()`` keeps the sampling step reparameterized.
    """
    super().__init__()
    # Two stride-2 convs: 1 -> 32 -> 64 channels, 4x spatial downsample.
    self.encoder = ModuleCompose(
        nn.Conv2d(1, 32, 3, stride=2, padding=1),
        F.relu,
        nn.Conv2d(32, 64, 3, stride=2, padding=1),
    )
    # Mirror of the encoder: two pixel-shuffle upsamples back to 1 channel.
    self.decoder = ModuleCompose(
        ConvPixelShuffle(64, 32, upscale_factor=2),
        F.relu,
        ConvPixelShuffle(32, 1, upscale_factor=2),
        lambda image: image[:, 0],  # drop the singleton channel dimension
    )
    # Alternatives considered:
    # - RelaxedBernoulli - maybe doesn't work?
    # - RelaxedOneHotCategorical
    # - RelaxedOneHotCategorical * Codebook
    self.image = ModuleCompose(
        self.encoder,
        # Differentiable sample of a soft binary code from the logits.
        lambda logits: RelaxedBernoulli(
            temperature=0.5,
            logits=logits,
        ).rsample(),
        self.decoder,
    )
def __init__(self):
    """VAE-style model: encoder emits two latent chunks, decoder rebuilds.

    The encoder's final linear layer doubles the latent width and
    ``chunk(2, dim=1)`` splits it into a pair (presumably mean / log-var
    — NOTE(review): confirm against the caller).
    """
    super().__init__()
    self.encoder = ModuleCompose(
        # Stage 1: 1 -> 16 channels, 16x16.
        nn.Conv2d(1, 16, 3, stride=2, padding=1, bias=False),  # 16x16
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        SqueezeExcitation(16),
        # Stage 2: 16 -> 32 channels, 8x8.
        nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False),  # 8x8
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1),
        SqueezeExcitation(32),
        # Bottleneck: 32 -> 1 channel, 4x4, then a doubled-width latent.
        nn.Conv2d(32, 1, 3, stride=2, padding=1, bias=False),  # 4x4
        Swish(),
        nn.Flatten(),
        nn.Linear(1 * 4 * 4, 2 * 1 * 4 * 4),  # should be global
        lambda latent: latent.chunk(2, dim=1),
    )
    self.decoder = ModuleCompose(
        # Broadcast the 16-dim latent across a 4x4 grid of 16 channels.
        lambda latent: latent.view(-1, 16, 1, 1).expand(-1, 16, 4, 4),
        RandomFourier(16),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1),
        SqueezeExcitation(32),
        # Upsample 4x4 -> 8x8.
        ConvPixelShuffle(32, 32, upscale_factor=2),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1),
        SqueezeExcitation(32),
        # Upsample 8x8 -> 16x16.
        ConvPixelShuffle(32, 16, upscale_factor=2),
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        SqueezeExcitation(16),
        # Upsample 16x16 -> 32x32 and squash to [0, 1].
        ConvPixelShuffle(16, 1, upscale_factor=2),
        torch.sigmoid,
        lambda image: image[:, 0],  # drop the singleton channel dimension
    )
def __init__(self):
    """Convolutional critic: maps a 1-channel image to a single realness score."""
    super().__init__()
    self.real = ModuleCompose(
        # Stage 1: 1 -> 16 channels, 16x16.
        nn.Conv2d(1, 16, 3, stride=2, padding=1, bias=False),  # 16x16
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(16),
        Swish(),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        SqueezeExcitation(16),
        # Stage 2: 16 -> 32 channels, 8x8.
        nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False),  # 8x8
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        Swish(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1),
        SqueezeExcitation(32),
        # Head: collapse to a single channel at 4x4, then a scalar logit.
        nn.Conv2d(32, 1, 3, stride=2, padding=1, bias=False),  # 4x4
        Swish(),
        nn.Flatten(),
        nn.Linear(1 * 4 * 4, 1),
    )
def __init__(self):
    """Decoder: expands a 16-dim latent to a tanh-bounded 32x32 image."""
    super().__init__()
    size = 32
    self.image = ModuleCompose(
        # Broadcast the 16-dim latent across a 4x4 grid of 16 channels.
        lambda latent: latent.view(-1, 16, 1, 1).expand(-1, 16, 4, 4),
        RandomFourier(16),
        # Fourier features are concatenated: 16 + 16 input channels.
        nn.Conv2d(16 + 16, size, kernel_size=1, bias=False),
        DecoderCell(size),
        # 4x4 -> 8x8 -> 16x16 -> 32x32, halving channels along the way.
        ConvPixelShuffle(size, size, upscale_factor=2),
        DecoderCell(size),
        ConvPixelShuffle(size, size // 2, upscale_factor=2),
        DecoderCell(size // 2),
        ConvPixelShuffle(size // 2, size // 4, upscale_factor=2),
        # Alternative tail kept for reference:
        # DecoderCell(size // 4),
        # nn.BatchNorm2d(size // 4),
        # Swish(),
        nn.Conv2d(size // 4, size // 4, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(size // 4),
        Swish(),
        nn.Conv2d(size // 4, 1, kernel_size=3, padding=1, bias=False),
        torch.tanh,
    )
def __init__(self):
    """Encoder: maps a 1-channel image down to a flat 16-dim latent."""
    super().__init__()
    size = 32
    self.latent = ModuleCompose(
        # 1 -> 32 channels, 16x16.
        nn.Conv2d(1, size, 3, stride=2, padding=1, bias=False),  # 16x16
        EncoderCell(size),
        # 32 -> 64 channels, 8x8.
        nn.Conv2d(size, 2 * size, 3, stride=2, padding=1, bias=False),  # 8x8
        EncoderCell(2 * size),
        # Collapse to a single channel at 4x4.
        nn.Conv2d(2 * size, 1, 3, stride=2, padding=1, bias=False),  # 4x4
        # NOTE(review): activation before BatchNorm is unusual — confirm
        # this ordering is intentional.
        Swish(),
        nn.BatchNorm2d(1),
        nn.Flatten(),
        nn.Linear(1 * 4 * 4, 1 * 4 * 4),
    )
def __init__(self):
    """CNN classifier producing log-probabilities over 10 classes.

    The flattened size 12544 = 64 * 14 * 14 implies a 3x32x32 input:
    two valid 3x3 convs (32 -> 30 -> 28) then a 2x2 max-pool (-> 14).
    """
    super().__init__()
    self.logits = ModuleCompose(
        nn.Conv2d(3, 32, 3, 1),
        F.relu,
        nn.Conv2d(32, 64, 3, 1),
        F.relu,
        partial(F.max_pool2d, kernel_size=2),
        nn.Dropout2d(0.25),  # channel-wise dropout on the 4-D feature map
        partial(torch.flatten, start_dim=1),
        nn.Linear(12544, 128),
        F.relu,
        # Fix: the original used nn.Dropout2d here, but the tensor is
        # already flattened to 2-D — Dropout2d expects spatial input and
        # misbehaves on (N, C) tensors. Element-wise nn.Dropout is the
        # correct layer after a fully-connected layer. (Dropout has no
        # parameters, so checkpoints are unaffected.)
        nn.Dropout(0.5),
        nn.Linear(128, 10),
        partial(F.log_softmax, dim=1),
    )
def __init__(self):
    """MLP critic over a flat 16-dim input, emitting a scalar score."""
    super().__init__()
    # transformer?
    # TODO: 16 layers
    # GaussianNoise, GaussianDropout
    # LayerNormalization
    # https://github.com/ConorLazarou/AEGAN-keras/blob/master/code/generative_model.py
    # Input projection, two hidden Linear+LeakyReLU pairs, scalar head.
    layers = [nn.Linear(1 * 4 * 4, 32), nn.LeakyReLU(0.02)]
    for _ in range(2):
        layers.extend([nn.Linear(32, 32), nn.LeakyReLU(0.02)])
    layers.append(nn.Linear(32, 1))
    self.real = ModuleCompose(*layers)