Ejemplo n.º 1
0
    def __init__(self,
                 *,
                 image_size,
                 fmap_max=512,
                 fmap_inverse_coef=12,
                 transparent=False,
                 greyscale=False,
                 disc_output_size=5,
                 attn_res_layers=[]):
        """Build the discriminator's layers.

        Args:
            image_size: side length of the input image; must be a power of 2.
            fmap_max: upper bound on the channel count of any residual stage.
            fmap_inverse_coef: a residual stage at log2-resolution ``r`` gets
                ``min(2 ** (fmap_inverse_coef - r), fmap_max)`` channels.
            transparent: use 4 input channels (checked before ``greyscale``).
            greyscale: use 1 input channel; otherwise 3 are used.
            disc_output_size: 5 or 1, selecting the 5x5 or 1x1 logit head.
            attn_res_layers: input widths (``2 ** res``) of the residual
                stages that get a ``Rezero(GSA(...))`` attention module.
                The list is only read, never mutated.

        Raises:
            AssertionError: if ``image_size`` is not a power of 2 or
                ``disc_output_size`` is not 1 or 5.
        """
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        assert disc_output_size in {
            1, 5
        }, 'discriminator output dimensions can only be 5x5 or 1x1'

        resolution = int(resolution)

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        # Inputs larger than 2**8 are first reduced by plain stride-2 convs,
        # one per extra octave, so the residual stages start at res <= 8.
        num_non_residual_layers = max(0, resolution - 8)

        # (log2-resolution, channels) for each residual stage, from
        # min(8, resolution) down to 3.
        non_residual_resolutions = range(min(8, resolution), 2, -1)
        features = list(
            map(lambda n: (n, 2**(fmap_inverse_coef - n)),
                non_residual_resolutions))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))

        # Without a non-residual prefix, the first residual stage consumes
        # the raw image, so its input channel count is the image's.
        if num_non_residual_layers == 0:
            res, _ = features[0]
            features[0] = (res, init_channel)

        chan_in_out = list(zip(features[:-1], features[1:]))

        self.non_residual_layers = nn.ModuleList([])
        for ind in range(num_non_residual_layers):
            last_layer = ind == (num_non_residual_layers - 1)
            # Stay at init_channel until the last conv, which widens to the
            # first residual stage's input channels.
            chan_out = features[0][-1] if last_layer else init_channel

            self.non_residual_layers.append(
                nn.Sequential(
                    Blur(),
                    nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.1)))

        self.residual_layers = nn.ModuleList([])

        for (res, ((_, chan_in), (_,
                                  chan_out))) in zip(non_residual_resolutions,
                                                     chan_in_out):
            # Spatial width of this stage's input. This must use the
            # per-stage `res`: the image-level `resolution` is the same for
            # every stage and would attach attention everywhere or nowhere.
            image_width = 2**res

            attn = None
            if image_width in attn_res_layers:
                attn = Rezero(
                    GSA(dim=chan_in, batch_norm=False, norm_queries=True))

            # Each stage sums a conv branch and a pooled 1x1-conv branch
            # (via SumBranches), both mapping chan_in -> chan_out.
            self.residual_layers.append(
                nn.ModuleList([
                    SumBranches([
                        nn.Sequential(
                            Blur(),
                            nn.Conv2d(chan_in,
                                      chan_out,
                                      4,
                                      stride=2,
                                      padding=1), nn.LeakyReLU(0.1),
                            nn.Conv2d(chan_out, chan_out, 3, padding=1),
                            nn.LeakyReLU(0.1)),
                        nn.Sequential(
                            Blur(),
                            nn.AvgPool2d(2),
                            nn.Conv2d(chan_in, chan_out, 1),
                            nn.LeakyReLU(0.1),
                        )
                    ]), attn
                ]))

        last_chan = features[-1][-1]
        if disc_output_size == 5:
            self.to_logits = nn.Sequential(nn.Conv2d(last_chan, last_chan, 1),
                                           nn.LeakyReLU(0.1),
                                           nn.Conv2d(last_chan, 1, 4))
        elif disc_output_size == 1:
            # Extra stride-2 downsample before the final 4x4 conv.
            self.to_logits = nn.Sequential(
                Blur(), nn.Conv2d(last_chan, last_chan, 3, stride=2,
                                  padding=1), nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, 1, 4))

        # Separate small head on an init_channel-channel input, pooled to 4x4
        # and reduced to a single channel. NOTE(review): presumably fed a
        # downscaled copy of the image — confirm against forward().
        self.to_shape_disc_out = nn.Sequential(
            nn.Conv2d(init_channel, 64, 3, padding=1),
            Residual(Rezero(GSA(dim=64, norm_queries=True, batch_norm=False))),
            SumBranches([
                nn.Sequential(Blur(), nn.Conv2d(64, 32, 4,
                                                stride=2, padding=1),
                              nn.LeakyReLU(0.1), nn.Conv2d(32,
                                                           32,
                                                           3,
                                                           padding=1),
                              nn.LeakyReLU(0.1)),
                nn.Sequential(
                    Blur(),
                    nn.AvgPool2d(2),
                    nn.Conv2d(64, 32, 1),
                    nn.LeakyReLU(0.1),
                )
            ]),
            Residual(Rezero(GSA(dim=32, norm_queries=True, batch_norm=False))),
            nn.AdaptiveAvgPool2d((4, 4)), nn.Conv2d(32, 1, 4))

        # Decoders from the last (and, for res >= 9, penultimate) stage's
        # channels back to init_channel channels.
        self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
        self.decoder2 = SimpleDecoder(
            chan_in=features[-2][-1],
            chan_out=init_channel) if resolution >= 9 else None
Ejemplo n.º 2
0
    def __init__(
        self,
        *,
        image_size,
        fmap_max=512,
        fmap_inverse_coef=12,
        num_chans=3,
        disc_output_size=5,
        attn_res_layers=[],
        num_classes=0,
        bn4decoder=True,
    ):
        """Build the (optionally class-conditional) discriminator's layers.

        Args:
            image_size: side length of the input image; must be a power of 2.
            fmap_max: upper bound on the channel count of any residual stage.
            fmap_inverse_coef: a residual stage at log2-resolution ``r`` gets
                ``min(2 ** (fmap_inverse_coef - r), fmap_max)`` channels.
            num_chans: number of input image channels.
            disc_output_size: only 1 is implemented; 5 passes the assertion
                but raises ``NotImplementedError``.
            attn_res_layers: input widths (``2 ** res``) of the residual
                stages that get a ``Rezero(GSA(...))`` attention module.
                The list is only read, never mutated.
            num_classes: when > 0, a spectral-normalised
                ``nn.Embedding(num_classes, last_chan)`` is stored as
                ``self.l_y``.
            bn4decoder: passed to both decoders as ``end_glu`` and selects
                ``BatchNorm2d(num_chans)`` vs ``Identity`` for
                ``self.bn4decoder``.

        Raises:
            AssertionError: if ``image_size`` is not a power of 2 or
                ``disc_output_size`` is not 1 or 5.
            NotImplementedError: if ``disc_output_size == 5``.
        """
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        assert disc_output_size in {
            1, 5
        }, 'discriminator output dimensions can only be 5x5 or 1x1'

        resolution = int(resolution)

        init_channel = num_chans

        # Inputs larger than 2**8 are first reduced by plain stride-2 convs,
        # one per extra octave, so the residual stages start at res <= 8.
        num_non_residual_layers = max(0, resolution - 8)

        # (log2-resolution, channels) for each residual stage, from
        # min(8, resolution) down to 3.
        non_residual_resolutions = range(min(8, resolution), 2, -1)
        features = list(
            map(lambda n: (n, 2**(fmap_inverse_coef - n)),
                non_residual_resolutions))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))

        # Without a non-residual prefix, the first residual stage consumes
        # the raw image, so its input channel count is the image's.
        if num_non_residual_layers == 0:
            res, _ = features[0]
            features[0] = (res, init_channel)

        chan_in_out = list(zip(features[:-1], features[1:]))

        self.non_residual_layers = nn.ModuleList([])
        for ind in range(num_non_residual_layers):
            last_layer = ind == (num_non_residual_layers - 1)
            # Stay at init_channel until the last conv, which widens to the
            # first residual stage's input channels.
            chan_out = features[0][-1] if last_layer else init_channel

            self.non_residual_layers.append(
                nn.Sequential(
                    Blur(),
                    nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.1)))

        self.residual_layers = nn.ModuleList([])

        for (res, ((_, chan_in), (_,
                                  chan_out))) in zip(non_residual_resolutions,
                                                     chan_in_out):
            # Spatial width of this stage's input (per-stage `res`, not the
            # image-level `resolution`).
            image_width = 2**res

            attn = None
            if image_width in attn_res_layers:
                attn = Rezero(
                    GSA(dim=chan_in, batch_norm=False, norm_queries=True))

            # Each stage sums a conv branch and a pooled 1x1-conv branch
            # (via SumBranches), both mapping chan_in -> chan_out.
            self.residual_layers.append(
                nn.ModuleList([
                    SumBranches([
                        nn.Sequential(
                            Blur(),
                            nn.Conv2d(chan_in,
                                      chan_out,
                                      4,
                                      stride=2,
                                      padding=1), nn.LeakyReLU(0.1),
                            nn.Conv2d(chan_out, chan_out, 3, padding=1),
                            nn.LeakyReLU(0.1)),
                        nn.Sequential(
                            Blur(),
                            nn.AvgPool2d(2),
                            nn.Conv2d(chan_in, chan_out, 1),
                            nn.LeakyReLU(0.1),
                        )
                    ]), attn
                ]))

        last_chan = features[-1][-1]
        if disc_output_size == 5:
            # Only the 1x1 head is implemented; the 5x5 variant was left
            # unfinished (original author's note mentioned the projection).
            raise NotImplementedError(
                'disc_output_size=5 is not implemented for this '
                'discriminator; use disc_output_size=1')
        elif disc_output_size == 1:
            # Conv trunk keeping last_chan channels; self.out maps those
            # features to one score, and self.l_y (if any) has the same
            # width last_chan.
            self.to_pre_logits = nn.Sequential(
                Blur(), nn.Conv2d(last_chan, last_chan, 3, stride=2,
                                  padding=1), nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, last_chan, 4), nn.LeakyReLU(0.1))
            self.out = nn.Linear(last_chan, 1)

        # Separate small head on an init_channel-channel input, pooled to 4x4
        # and reduced to a single channel. NOTE(review): presumably fed a
        # downscaled copy of the image — confirm against forward().
        self.to_shape_disc_out = nn.Sequential(
            nn.Conv2d(init_channel, 64, 3, padding=1),
            Residual(Rezero(GSA(dim=64, norm_queries=True, batch_norm=False))),
            SumBranches([
                nn.Sequential(Blur(), nn.Conv2d(64, 32, 4,
                                                stride=2, padding=1),
                              nn.LeakyReLU(0.1), nn.Conv2d(32,
                                                           32,
                                                           3,
                                                           padding=1),
                              nn.LeakyReLU(0.1)),
                nn.Sequential(
                    Blur(),
                    nn.AvgPool2d(2),
                    nn.Conv2d(64, 32, 1),
                    nn.LeakyReLU(0.1),
                )
            ]),
            Residual(Rezero(GSA(dim=32, norm_queries=True, batch_norm=False))),
            nn.AdaptiveAvgPool2d((4, 4)), nn.Conv2d(32, 1, 4))

        # Decoders from the last (and, for res >= 9, penultimate) stage's
        # channels back to init_channel channels.
        self.decoder1 = SimpleDecoder(chan_in=last_chan,
                                      chan_out=init_channel,
                                      end_glu=bn4decoder)
        self.decoder2 = SimpleDecoder(
            chan_in=features[-2][-1],
            chan_out=init_channel,
            end_glu=bn4decoder) if resolution >= 9 else None

        if num_classes > 0:
            self.l_y = nn.utils.spectral_norm(
                nn.Embedding(num_classes, last_chan))
        self._initialize()

        # NOTE(review): created after self._initialize(), so any
        # re-initialisation that method performs does not reach this
        # module — confirm that is intended.
        self.bn4decoder = nn.BatchNorm2d(
            num_chans) if bn4decoder else nn.Identity(
            )  # GLU enforces not affine
Ejemplo n.º 3
0
    def __init__(self,
                 *,
                 image_size,
                 latent_dim=256,
                 fmap_max=512,
                 fmap_inverse_coef=12,
                 transparent=False,
                 greyscale=False,
                 attn_res_layers=[]):
        """Build the generator: an initial transposed conv plus one
        upsampling stage per log2-resolution from 2 to log2(image_size) - 1.

        Args:
            image_size: side length of the output image; must be a power of 2.
            latent_dim: channel count of the latent input; also the fallback
                for ``fmap_max`` when that is None.
            fmap_max: upper bound on a stage's output channels.
            fmap_inverse_coef: stage ``r`` outputs
                ``min(2 ** (fmap_inverse_coef - r), fmap_max)`` channels,
                except stages with ``r >= 8``, which output 3.
            transparent: emit 4 channels (checked before ``greyscale``).
            greyscale: emit 1 channel; otherwise 3 are emitted.
            attn_res_layers: input widths (``2 ** res``) of the stages that
                get a ``Rezero(GSA(...))`` module. Only read, never mutated.

        Raises:
            AssertionError: if ``image_size`` is not a power of 2.
        """
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        init_channel = 4 if transparent else (1 if greyscale else 3)

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = int(resolution) - 2
        stage_resolutions = range(2, num_layers + 2)

        # Channel list: latent_dim followed by each stage's output channels.
        features = [latent_dim]
        for r in stage_resolutions:
            capped = min(2**(fmap_inverse_coef - r), fmap_max)
            features.append(3 if r >= 8 else capped)

        in_out_features = list(zip(features, features[1:]))

        self.res_layers = stage_resolutions
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # Stage `src` builds a GlobalContext whose output width matches the
        # input channels of stage `dst`; pairs beyond the image's resolution
        # are dropped.
        self.sle_map = {
            src: dst
            for src, dst in ((3, 7), (4, 8), (5, 9), (6, 10))
            if src <= resolution and dst <= resolution
        }

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            # Modules are created in the order attn -> sle -> upsample block.
            attn = (Rezero(GSA(dim=chan_in, norm_queries=True))
                    if 2**res in attn_res_layers else None)

            sle = None
            if res in self.sle_map:
                target_chan = self.res_to_feature_map[self.sle_map[res] - 1][-1]
                sle = GlobalContext(chan_in=chan_out, chan_out=target_chan)

            upsample_block = nn.Sequential(
                upsample(),
                Blur(),
                nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
                norm_class(chan_out * 2),
                nn.GLU(dim=1),
            )
            self.layers.append(nn.ModuleList([upsample_block, sle, attn]))

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
Ejemplo n.º 4
0
    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        num_chans=3,
        attn_res_layers=[],
        freq_chan_attn=False,
        num_classes=0,
        cat_res_layers=[],
        embedding_dim=16,
    ):
        """Build the (optionally class-conditional) generator.

        Args:
            image_size: side length of the output image; must be a power of 2.
            latent_dim: channel count of the latent input; also the fallback
                for ``fmap_max`` when that is None.
            fmap_max: upper bound on a stage's output channels.
            fmap_inverse_coef: stage ``r`` outputs
                ``min(2 ** (fmap_inverse_coef - r), fmap_max)`` channels,
                except stages with ``r >= 8``, which output 3.
            num_chans: number of output image channels.
            attn_res_layers: input widths (``2 ** res``) of the stages that
                get a ``Rezero(GSA(...))`` module.
            freq_chan_attn: build ``FCANet`` instead of ``GlobalContext`` for
                the skip-excitation modules.
            num_classes: number of classes, forwarded to ``Catter``.
            cat_res_layers: input widths of the stages that get a ``Catter``;
                must be empty unless ``num_classes > 0``.
            embedding_dim: embedding size forwarded to ``Catter``.

        Raises:
            AssertionError: if ``cat_res_layers`` is given without classes,
                or ``image_size`` is not a power of 2.
        """
        super().__init__()
        assert num_classes > 0 or cat_res_layers == []
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        init_channel = num_chans

        fmap_max = default(fmap_max, latent_dim)

        self.init_conv = InitConv(latent_dim, 0)
        self.num_classes = num_classes

        num_layers = int(resolution) - 2
        stage_resolutions = range(2, num_layers + 2)

        # Channel list: latent_dim followed by each stage's output channels.
        # TODO(review): should the fixed 3 for r >= 8 be num_chans?
        features = [latent_dim]
        for r in stage_resolutions:
            capped = min(2**(fmap_inverse_coef - r), fmap_max)
            features.append(3 if r >= 8 else capped)

        in_out_features = list(zip(features, features[1:]))

        self.res_layers = stage_resolutions
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # Stage `src` builds a skip-excitation module whose output width
        # matches the input channels of stage `dst`; pairs beyond the image's
        # resolution are dropped.
        self.sle_map = {
            src: dst
            for src, dst in ((3, 7), (4, 8), (5, 9), (6, 10))
            if src <= resolution and dst <= resolution
        }

        self.num_layers_spatial_res = 1
        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            width = 2**res

            # Modules are created in the order cat -> attn -> sle -> block.
            cat = (Catter(chan_in, embedding_dim, num_classes)
                   if width in cat_res_layers else None)
            attn = (Rezero(GSA(dim=chan_in, norm_queries=True))
                    if width in attn_res_layers else None)

            sle = None
            if res in self.sle_map:
                target_chan = self.res_to_feature_map[self.sle_map[res] - 1][-1]
                if not freq_chan_attn:
                    sle = GlobalContext(chan_in=chan_out, chan_out=target_chan)
                else:
                    # NOTE(review): 2 ** (res + 1) is presumably this stage's
                    # output width after upsample() — confirm.
                    sle = FCANet(chan_in=chan_out,
                                 chan_out=target_chan,
                                 width=2**(res + 1))

            upsample_block = nn.Sequential(
                upsample(),
                Blur(),
                nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
                norm_class(chan_out * 2),
                nn.GLU(dim=1),
            )
            self.layers.append(nn.ModuleList([cat, upsample_block, sle, attn]))

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)