def __init__(self,
             *,
             image_size,
             fmap_max=512,
             fmap_inverse_coef=12,
             transparent=False,
             greyscale=False,
             disc_output_size=5,
             attn_res_layers=[]):
    """Assemble the discriminator's layers for a power-of-two image size.

    Args:
        image_size: Image side length; asserted to be a power of two.
        fmap_max: Upper bound applied to every computed feature width.
        fmap_inverse_coef: A stage at log2-resolution ``res`` gets
            ``2 ** (fmap_inverse_coef - res)`` features before the
            ``fmap_max`` cap is applied.
        transparent: Use 4 input channels; takes precedence over
            ``greyscale``.
        greyscale: Use 1 input channel (3 when neither flag is set).
        disc_output_size: Must be 1 or 5 (asserted); selects which
            ``to_logits`` head is built.
        attn_res_layers: Collection of widths. A residual stage whose
            ``2 ** res`` is a member gets a Rezero-wrapped GSA module.
            Only read here, never mutated.

    Raises:
        AssertionError: If ``image_size`` is not a power of two or
            ``disc_output_size`` is not 1 or 5.
    """
    super().__init__()
    resolution = log2(image_size)
    assert is_power_of_two(image_size), 'image size must be a power of 2'
    assert disc_output_size in {
        1, 5
    }, 'discriminator output dimensions can only be 5x5 or 1x1'
    resolution = int(resolution)

    if transparent:
        init_channel = 4
    elif greyscale:
        init_channel = 1
    else:
        init_channel = 3

    # One extra stride-2 stem stage for every power of two above 2**8.
    num_non_residual_layers = max(0, resolution - 8)

    # Log2 resolutions handled by the residual stages, from
    # min(8, resolution) down to 3.
    non_residual_resolutions = range(min(8, resolution), 2, -1)
    features = [(res, min(2**(fmap_inverse_coef - res), fmap_max))
                for res in non_residual_resolutions]

    if num_non_residual_layers == 0:
        # No stem stages are built, so the first residual stage's input
        # width is set to the image channel count.
        res, _ = features[0]
        features[0] = (res, init_channel)

    chan_in_out = list(zip(features[:-1], features[1:]))

    self.non_residual_layers = nn.ModuleList([])
    for ind in range(num_non_residual_layers):
        last_layer = ind == (num_non_residual_layers - 1)
        # Only the final stem stage widens to the first residual width;
        # earlier ones keep the image channel count.
        chan_out = features[0][-1] if last_layer else init_channel
        self.non_residual_layers.append(
            nn.Sequential(
                Blur(),
                nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1)))

    self.residual_layers = nn.ModuleList([])
    for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions,
                                                    chan_in_out):
        # Key attention on this stage's own resolution. Using the full
        # image resolution here would make the membership test give the
        # same answer for every stage.
        image_width = 2**res

        attn = None
        if image_width in attn_res_layers:
            attn = Rezero(
                GSA(dim=chan_in, batch_norm=False, norm_queries=True))

        self.residual_layers.append(
            nn.ModuleList([
                SumBranches([
                    nn.Sequential(
                        Blur(),
                        nn.Conv2d(chan_in,
                                  chan_out,
                                  4,
                                  stride=2,
                                  padding=1),
                        nn.LeakyReLU(0.1),
                        nn.Conv2d(chan_out, chan_out, 3, padding=1),
                        nn.LeakyReLU(0.1)),
                    nn.Sequential(
                        Blur(),
                        nn.AvgPool2d(2),
                        nn.Conv2d(chan_in, chan_out, 1),
                        nn.LeakyReLU(0.1),
                    )
                ]), attn
            ]))

    last_chan = features[-1][-1]
    if disc_output_size == 5:
        self.to_logits = nn.Sequential(
            nn.Conv2d(last_chan, last_chan, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(last_chan, 1, 4))
    elif disc_output_size == 1:
        self.to_logits = nn.Sequential(
            Blur(),
            nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(last_chan, 1, 4))

    # Separate stack built directly on the image channels; it ends in a
    # single-channel 4x4 convolution after adaptive pooling to 4x4.
    self.to_shape_disc_out = nn.Sequential(
        nn.Conv2d(init_channel, 64, 3, padding=1),
        Residual(Rezero(GSA(dim=64, norm_queries=True, batch_norm=False))),
        SumBranches([
            nn.Sequential(
                Blur(),
                nn.Conv2d(64, 32, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(32, 32, 3, padding=1),
                nn.LeakyReLU(0.1)),
            nn.Sequential(
                Blur(),
                nn.AvgPool2d(2),
                nn.Conv2d(64, 32, 1),
                nn.LeakyReLU(0.1),
            )
        ]),
        Residual(Rezero(GSA(dim=32, norm_queries=True, batch_norm=False))),
        nn.AdaptiveAvgPool2d((4, 4)),
        nn.Conv2d(32, 1, 4))

    self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
    # The second decoder reads the penultimate feature width and is only
    # built for images of 512px (2**9) and larger.
    self.decoder2 = SimpleDecoder(
        chan_in=features[-2][-1],
        chan_out=init_channel) if resolution >= 9 else None
def __init__(
    self,
    *,
    image_size,
    fmap_max=512,
    fmap_inverse_coef=12,
    num_chans=3,
    disc_output_size=5,
    attn_res_layers=[],
    num_classes=0,
    bn4decoder=True,
):
    """Assemble the discriminator's layers for a power-of-two image size.

    Args:
        image_size: Image side length; asserted to be a power of two.
        fmap_max: Upper bound applied to every computed feature width.
        fmap_inverse_coef: A stage at log2-resolution ``res`` gets
            ``2 ** (fmap_inverse_coef - res)`` features before the
            ``fmap_max`` cap is applied.
        num_chans: Number of image channels.
        disc_output_size: Must be 1 or 5 (asserted). Only 1 is
            implemented; note that the default value 5 raises.
        attn_res_layers: Collection of widths. A residual stage whose
            ``2 ** res`` is a member gets a Rezero-wrapped GSA module.
            Only read here, never mutated.
        num_classes: When positive, a spectral-normalised
            ``nn.Embedding(num_classes, last_chan)`` is stored as
            ``self.l_y``.
        bn4decoder: Forwarded to both decoders as ``end_glu``; also
            selects ``nn.BatchNorm2d(num_chans)`` versus ``nn.Identity()``
            for ``self.bn4decoder``.

    Raises:
        AssertionError: If ``image_size`` is not a power of two or
            ``disc_output_size`` is not 1 or 5.
        NotImplementedError: If ``disc_output_size`` is 5.
    """
    super().__init__()
    resolution = log2(image_size)
    assert is_power_of_two(image_size), 'image size must be a power of 2'
    assert disc_output_size in {
        1, 5
    }, 'discriminator output dimensions can only be 5x5 or 1x1'
    resolution = int(resolution)

    init_channel = num_chans

    # One extra stride-2 stem stage for every power of two above 2**8.
    num_non_residual_layers = max(0, resolution - 8)

    # Log2 resolutions handled by the residual stages, from
    # min(8, resolution) down to 3.
    non_residual_resolutions = range(min(8, resolution), 2, -1)
    features = [(res, min(2**(fmap_inverse_coef - res), fmap_max))
                for res in non_residual_resolutions]

    if num_non_residual_layers == 0:
        # No stem stages are built, so the first residual stage's input
        # width is set to the image channel count.
        res, _ = features[0]
        features[0] = (res, init_channel)

    chan_in_out = list(zip(features[:-1], features[1:]))

    self.non_residual_layers = nn.ModuleList([])
    for ind in range(num_non_residual_layers):
        last_layer = ind == (num_non_residual_layers - 1)
        # Only the final stem stage widens to the first residual width;
        # earlier ones keep the image channel count.
        chan_out = features[0][-1] if last_layer else init_channel
        self.non_residual_layers.append(
            nn.Sequential(
                Blur(),
                nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1)))

    self.residual_layers = nn.ModuleList([])
    for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions,
                                                    chan_in_out):
        # Per-stage width (res), not the full image resolution.
        image_width = 2**res

        attn = None
        if image_width in attn_res_layers:
            attn = Rezero(
                GSA(dim=chan_in, batch_norm=False, norm_queries=True))

        self.residual_layers.append(
            nn.ModuleList([
                SumBranches([
                    nn.Sequential(
                        Blur(),
                        nn.Conv2d(chan_in,
                                  chan_out,
                                  4,
                                  stride=2,
                                  padding=1),
                        nn.LeakyReLU(0.1),
                        nn.Conv2d(chan_out, chan_out, 3, padding=1),
                        nn.LeakyReLU(0.1)),
                    nn.Sequential(
                        Blur(),
                        nn.AvgPool2d(2),
                        nn.Conv2d(chan_in, chan_out, 1),
                        nn.LeakyReLU(0.1),
                    )
                ]), attn
            ]))

    last_chan = features[-1][-1]
    if disc_output_size == 5:
        # The 5x5 head is not available in this variant; the previously
        # unreachable layer construction after the raise was removed.
        raise NotImplementedError(
            'disc_output_size=5 is not implemented; use disc_output_size=1')
    elif disc_output_size == 1:
        # Feature head ending in last_chan channels, followed by a
        # linear layer producing a single output.
        self.to_pre_logits = nn.Sequential(
            Blur(),
            nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(last_chan, last_chan, 4),
            nn.LeakyReLU(0.1))
        self.out = nn.Linear(last_chan, 1)

    # Separate stack built directly on the image channels; it ends in a
    # single-channel 4x4 convolution after adaptive pooling to 4x4.
    self.to_shape_disc_out = nn.Sequential(
        nn.Conv2d(init_channel, 64, 3, padding=1),
        Residual(Rezero(GSA(dim=64, norm_queries=True, batch_norm=False))),
        SumBranches([
            nn.Sequential(
                Blur(),
                nn.Conv2d(64, 32, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(32, 32, 3, padding=1),
                nn.LeakyReLU(0.1)),
            nn.Sequential(
                Blur(),
                nn.AvgPool2d(2),
                nn.Conv2d(64, 32, 1),
                nn.LeakyReLU(0.1),
            )
        ]),
        Residual(Rezero(GSA(dim=32, norm_queries=True, batch_norm=False))),
        nn.AdaptiveAvgPool2d((4, 4)),
        nn.Conv2d(32, 1, 4))

    self.decoder1 = SimpleDecoder(chan_in=last_chan,
                                  chan_out=init_channel,
                                  end_glu=bn4decoder)
    # The second decoder reads the penultimate feature width and is only
    # built for images of 512px (2**9) and larger.
    self.decoder2 = SimpleDecoder(
        chan_in=features[-2][-1],
        chan_out=init_channel,
        end_glu=bn4decoder) if resolution >= 9 else None

    if num_classes > 0:
        self.l_y = nn.utils.spectral_norm(
            nn.Embedding(num_classes, last_chan))

    # NOTE(review): _initialize is defined outside this block; it runs
    # before self.bn4decoder exists, so it cannot touch that module.
    self._initialize()

    self.bn4decoder = nn.BatchNorm2d(
        num_chans) if bn4decoder else nn.Identity(
        )  # GLU enforces not affine
def __init__(self,
             *,
             image_size,
             latent_dim=256,
             fmap_max=512,
             fmap_inverse_coef=12,
             transparent=False,
             greyscale=False,
             attn_res_layers=[]):
    """Assemble the generator's layers for a power-of-two image size.

    Args:
        image_size: Image side length; asserted to be a power of two.
        latent_dim: Channel count fed to the initial transposed
            convolution, and the fallback passed to ``default`` for
            ``fmap_max``.
        fmap_max: Upper bound applied to every computed feature width.
        fmap_inverse_coef: A layer at log2-resolution ``res`` gets
            ``2 ** (fmap_inverse_coef - res)`` features before the cap.
        transparent: Emit 4 output channels; takes precedence over
            ``greyscale``.
        greyscale: Emit 1 output channel (3 when neither flag is set).
        attn_res_layers: Collection of widths. A layer whose ``2 ** res``
            is a member gets a Rezero-wrapped GSA module. Only read here.

    Raises:
        AssertionError: If ``image_size`` is not a power of two.
    """
    super().__init__()
    resolution = log2(image_size)
    assert is_power_of_two(image_size), 'image size must be a power of 2'

    init_channel = 4 if transparent else (1 if greyscale else 3)

    fmap_max = default(fmap_max, latent_dim)

    # The doubled channel count is halved again by GLU over dim 1.
    self.initial_conv = nn.Sequential(
        nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
        norm_class(latent_dim * 2),
        nn.GLU(dim=1),
    )

    num_layers = int(resolution) - 2
    layer_resolutions = range(2, num_layers + 2)

    # Channel width entering each layer, starting from the latent width.
    widths = [latent_dim]
    for exponent in layer_resolutions:
        capped = min(2**(fmap_inverse_coef - exponent), fmap_max)
        # Layers at log2-resolution 8 and above are pinned to 3 features.
        widths.append(3 if exponent >= 8 else capped)
    in_out_features = list(zip(widths[:-1], widths[1:]))

    self.res_layers = layer_resolutions
    self.layers = nn.ModuleList([])
    self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

    # Source -> target log2-resolution pairs, kept only when both ends
    # fit within this image's resolution.
    self.sle_map = {
        source: target
        for source, target in ((3, 7), (4, 8), (5, 9), (6, 10))
        if source <= resolution and target <= resolution
    }

    self.num_layers_spatial_res = 1

    for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
        # Construction order (attn, sle, conv stack) is kept so parameter
        # initialisation draws random numbers in the same sequence.
        attn = (Rezero(GSA(dim=chan_in, norm_queries=True))
                if 2**res in attn_res_layers else None)

        sle = None
        if res in self.sle_map:
            target_res = self.sle_map[res]
            # Output width matches the input width of the layer stored
            # under key ``target_res - 1``.
            target_chan = self.res_to_feature_map[target_res - 1][-1]
            sle = GlobalContext(chan_in=chan_out, chan_out=target_chan)

        conv_stack = nn.Sequential(
            upsample(),
            Blur(),
            nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
            norm_class(chan_out * 2),
            nn.GLU(dim=1),
        )
        self.layers.append(nn.ModuleList([conv_stack, sle, attn]))

    self.out_conv = nn.Conv2d(widths[-1], init_channel, 3, padding=1)
def __init__(
    self,
    *,
    image_size,
    latent_dim=256,
    fmap_max=512,
    fmap_inverse_coef=12,
    num_chans=3,
    attn_res_layers=[],
    freq_chan_attn=False,
    num_classes=0,
    cat_res_layers=[],
    embedding_dim=16,
):
    """Assemble the generator's layers for a power-of-two image size.

    Args:
        image_size: Image side length; asserted to be a power of two.
        latent_dim: Passed to ``InitConv`` and used as the input width of
            the first layer; also the fallback passed to ``default`` for
            ``fmap_max``.
        fmap_max: Upper bound applied to every computed feature width.
        fmap_inverse_coef: A layer at log2-resolution ``res`` gets
            ``2 ** (fmap_inverse_coef - res)`` features before the cap.
        num_chans: Number of output image channels.
        attn_res_layers: Collection of widths. A layer whose ``2 ** res``
            is a member gets a Rezero-wrapped GSA module. Only read here.
        freq_chan_attn: Build ``FCANet`` instead of ``GlobalContext`` for
            the layers listed in ``sle_map``.
        num_classes: Stored as ``self.num_classes`` and passed to every
            ``Catter``.
        cat_res_layers: Collection of widths. A layer whose ``2 ** res``
            is a member gets a ``Catter`` module. Must be empty unless
            ``num_classes > 0``; any empty collection is accepted.
        embedding_dim: Passed to every ``Catter``.

    Raises:
        AssertionError: If ``image_size`` is not a power of two, or if
            ``cat_res_layers`` is non-empty while ``num_classes`` is not
            positive.
    """
    super().__init__()
    # Truthiness instead of ``== []`` so an empty tuple or set passes too.
    assert num_classes > 0 or not cat_res_layers, \
        'cat_res_layers requires num_classes > 0'
    resolution = log2(image_size)
    assert is_power_of_two(image_size), 'image size must be a power of 2'

    init_channel = num_chans

    fmap_max = default(fmap_max, latent_dim)

    # NOTE(review): the meaning of the literal 0 is not visible from
    # here (the original carried a bare "HERE" marker) - confirm against
    # InitConv's signature.
    self.init_conv = InitConv(latent_dim, 0)
    self.num_classes = num_classes

    num_layers = int(resolution) - 2
    layer_resolutions = range(2, num_layers + 2)

    # Capped per-resolution widths. Layers at log2-resolution 8 and above
    # are pinned to 3 features.
    # TODO: should it be num chans?
    features = [
        3 if res >= 8 else min(2**(fmap_inverse_coef - res), fmap_max)
        for res in layer_resolutions
    ]
    features = [latent_dim, *features]

    in_out_features = list(zip(features[:-1], features[1:]))

    self.res_layers = layer_resolutions
    self.layers = nn.ModuleList([])
    self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

    # Source -> target log2-resolution pairs, kept only when both ends
    # fit within this image's resolution.
    self.sle_map = {
        source: target
        for source, target in ((3, 7), (4, 8), (5, 9), (6, 10))
        if source <= resolution and target <= resolution
    }

    self.num_layers_spatial_res = 1

    for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
        image_width = 2**res

        # Construction order (cat, attn, sle, conv stack) is kept so
        # parameter initialisation draws random numbers in the same
        # sequence as before.
        cat = Catter(chan_in, embedding_dim,
                     num_classes) if image_width in cat_res_layers else None

        attn = None
        if image_width in attn_res_layers:
            attn = Rezero(GSA(dim=chan_in, norm_queries=True))

        sle = None
        if res in self.sle_map:
            residual_layer = self.sle_map[res]
            # Output width matches the input width of the layer stored
            # under key ``residual_layer - 1``.
            sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

            if freq_chan_attn:
                sle = FCANet(chan_in=chan_out,
                             chan_out=sle_chan_out,
                             width=2**(res + 1))
            else:
                sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

        layer = nn.ModuleList([
            cat,
            # The doubled channel count is halved again by GLU over dim 1.
            nn.Sequential(
                upsample(),
                Blur(),
                nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
                norm_class(chan_out * 2),
                nn.GLU(dim=1),
            ),
            sle,
            attn,
        ])
        self.layers.append(layer)

    self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)