Example #1
0
    def __init__(self):
        super().__init__()

        # Downsampling path: two stride-2 convolutions, 1 -> 32 -> 64 channels.
        encoder_layers = [
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            F.relu,
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
        ]
        self.encoder = ModuleCompose(*encoder_layers)

        # Upsampling path mirrors the encoder; the final step drops the
        # singleton channel axis from the reconstruction.
        decoder_layers = [
            ConvPixelShuffle(64, 32, upscale_factor=2),
            F.relu,
            ConvPixelShuffle(32, 1, upscale_factor=2),
            lambda reconstruction: reconstruction[:, 0],
        ]
        self.decoder = ModuleCompose(*decoder_layers)

        # Alternatives considered for the stochastic bottleneck:
        # - RelaxedBernoulli - maybe doesn't work?
        # - RelaxedOneHotCategorical
        # - RelaxedOneHotCategorical * Codebook

        def sample_binary_code(logits):
            # Reparameterized (rsample) so gradients flow through the
            # relaxed-Bernoulli bottleneck.
            return RelaxedBernoulli(temperature=0.5, logits=logits).rsample()

        # Full pipeline: encode -> stochastic binary-ish code -> decode.
        self.image = ModuleCompose(
            self.encoder,
            sample_binary_code,
            self.decoder,
        )
Example #2
0
    def __init__(self):
        super().__init__()

        def conv_stage(channels):
            # The BN -> Swish -> conv -> BN -> Swish -> conv -> SE motif used
            # between every resolution change, reproduced in the exact
            # original layer order (first conv bias-free, second biased).
            return [
                nn.BatchNorm2d(channels),
                Swish(),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                          bias=False),
                nn.BatchNorm2d(channels),
                Swish(),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                SqueezeExcitation(channels),
            ]

        # Encoder: 32x32 -> 16x16 -> 8x8 -> 4x4, then a linear head whose
        # output is split into two equal halves (e.g. mean / log-variance).
        self.encoder = ModuleCompose(
            nn.Conv2d(1, 16, 3, stride=2, padding=1, bias=False),  # 16x16
            *conv_stage(16),
            nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False),  # 8x8
            *conv_stage(32),
            nn.Conv2d(32, 1, 3, stride=2, padding=1, bias=False),  # 4x4
            Swish(),
            nn.Flatten(),
            nn.Linear(1 * 4 * 4, 2 * 1 * 4 * 4),  # should be global
            lambda stats: stats.chunk(2, dim=1),
        )

        # Decoder: tile the 16-dim code over a 4x4 grid, append Fourier
        # position features, then upsample 4x4 -> 8x8 -> 16x16 -> 32x32.
        # NOTE(review): BatchNorm2d(32) right after RandomFourier(16) implies
        # RandomFourier doubles the channel count — confirm against its impl.
        self.decoder = ModuleCompose(
            lambda code: code.view(-1, 16, 1, 1).expand(-1, 16, 4, 4),
            RandomFourier(16),
            *conv_stage(32),
            ConvPixelShuffle(32, 32, upscale_factor=2),
            *conv_stage(32),
            ConvPixelShuffle(32, 16, upscale_factor=2),
            *conv_stage(16),
            ConvPixelShuffle(16, 1, upscale_factor=2),
            torch.sigmoid,
            lambda image: image[:, 0],
        )
Example #3
0
    def __init__(self):
        super().__init__()

        def conv_stage(channels):
            # Fixed-resolution refinement motif, in the exact original order:
            # BN -> Swish -> bias-free conv -> BN -> Swish -> conv -> SE.
            return [
                nn.BatchNorm2d(channels),
                Swish(),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                          bias=False),
                nn.BatchNorm2d(channels),
                Swish(),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                SqueezeExcitation(channels),
            ]

        # Critic/discriminator head: downsample 32x32 -> 16x16 -> 8x8 -> 4x4,
        # then map the flattened 4x4 single-channel map to one scalar score.
        self.real = ModuleCompose(
            nn.Conv2d(1, 16, 3, stride=2, padding=1, bias=False),  # 16x16
            *conv_stage(16),
            nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False),  # 8x8
            *conv_stage(32),
            nn.Conv2d(32, 1, 3, stride=2, padding=1, bias=False),  # 4x4
            Swish(),
            nn.Flatten(),
            nn.Linear(1 * 4 * 4, 1),
        )
Example #4
0
    def __init__(self):
        super().__init__()

        width = 32

        # Generator: tile the 16-dim latent over a 4x4 grid, append Fourier
        # position features (16 extra channels), then upsample three times
        # (4x4 -> 8x8 -> 16x16 -> 32x32) while halving the width each step.
        self.image = ModuleCompose(
            lambda latent: latent.view(-1, 16, 1, 1).expand(-1, 16, 4, 4),
            RandomFourier(16),
            nn.Conv2d(16 + 16, width, kernel_size=1, bias=False),
            DecoderCell(width),
            ConvPixelShuffle(width, width, upscale_factor=2),
            DecoderCell(width),
            ConvPixelShuffle(width, width // 2, upscale_factor=2),
            DecoderCell(width // 2),
            ConvPixelShuffle(width // 2, width // 4, upscale_factor=2),
            nn.Conv2d(width // 4, width // 4, kernel_size=3, padding=1,
                      bias=False),
            nn.BatchNorm2d(width // 4),
            Swish(),
            # Single-channel output squashed to [-1, 1].
            nn.Conv2d(width // 4, 1, kernel_size=3, padding=1, bias=False),
            torch.tanh,
        )
Example #5
0
    def __init__(self):
        super().__init__()

        width = 32

        # Encoder: 32x32 -> 16x16 -> 8x8 -> 4x4, widening then collapsing to
        # a single channel, finished by a linear map on the flattened 4x4.
        self.latent = ModuleCompose(
            nn.Conv2d(1, width, 3, stride=2, padding=1, bias=False),  # 16x16
            EncoderCell(width),
            nn.Conv2d(width, 2 * width, 3, stride=2, padding=1,
                      bias=False),  # 8x8
            EncoderCell(2 * width),
            nn.Conv2d(2 * width, 1, 3, stride=2, padding=1, bias=False),  # 4x4
            Swish(),
            # NOTE(review): BatchNorm placed after the activation here,
            # unlike the BN -> activation ordering elsewhere in this file —
            # confirm this is intentional before "fixing" it.
            nn.BatchNorm2d(1),
            nn.Flatten(),
            nn.Linear(1 * 4 * 4, 1 * 4 * 4),
        )
Example #6
0
 def __init__(self):
     super().__init__()
     # Small CNN classifier producing log-probabilities over 10 classes.
     # The 12544 = 64 * 14 * 14 flatten size implies a 32x32 RGB input
     # (32 -> 30 -> 28 -> pooled 14) — TODO confirm against the caller.
     self.logits = ModuleCompose(
         nn.Conv2d(3, 32, 3, 1),
         F.relu,
         nn.Conv2d(32, 64, 3, 1),
         F.relu,
         partial(F.max_pool2d, kernel_size=2),
         # Channel dropout on the 4-D feature map — valid use of Dropout2d.
         nn.Dropout2d(0.25),
         partial(torch.flatten, start_dim=1),
         nn.Linear(12544, 128),
         F.relu,
         # Fix: the input here is a flattened 2-D (batch, features) tensor;
         # nn.Dropout2d expects 3-D/4-D input (PyTorch warns since 1.12 and
         # does not perform element dropout on 2-D input). Plain nn.Dropout
         # gives the intended per-element dropout.
         nn.Dropout(0.5),
         nn.Linear(128, 10),
         partial(F.log_softmax, dim=1),
     )
Example #7
0
    def __init__(self):
        super().__init__()

        # transformer?

        # TODO: 16 layers
        # GaussianNoise, GaussianDropout
        # LayerNormalization
        # https://github.com/ConorLazarou/AEGAN-keras/blob/master/code/generative_model.py

        hidden = 32
        slope = 0.02

        # Discriminator head over a flattened 4x4 input: input projection,
        # two hidden LeakyReLU layers, then a scalar output.
        layers = [nn.Linear(1 * 4 * 4, hidden)]
        for _ in range(2):
            layers += [nn.LeakyReLU(slope), nn.Linear(hidden, hidden)]
        layers += [nn.LeakyReLU(slope), nn.Linear(hidden, 1)]
        self.real = ModuleCompose(*layers)