# Generator constructor (excerpt). Relies on module-level names from the
# surrounding file: log2 (math), partial (functools), torch, nn, Residual,
# Rezero, ImageLinearAttention and GeneratorBlock.
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[], no_const=False, fmap_max=512):
    super().__init__()
    self.image_size = image_size
    self.latent_dim = latent_dim
    self.num_layers = int(log2(image_size) - 1)

    # widest feature maps first, halving toward the output resolution
    filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]

    # cap every stage at fmap_max channels
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))

    init_channels = filters[0]
    filters = [init_channels, *filters]
    in_out_pairs = zip(filters[:-1], filters[1:])

    self.no_const = no_const
    if no_const:
        # learn the initial 4x4 block from the latent instead of a constant
        self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
    else:
        self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))

    self.blocks = nn.ModuleList([])
    self.attns = nn.ModuleList([])

    for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
        not_first = ind != 0
        not_last = ind != (self.num_layers - 1)
        num_layer = self.num_layers - ind

        # optional linear self-attention at the requested layers
        attn_fn = nn.Sequential(*[
            Residual(Rezero(ImageLinearAttention(in_chan))) for _ in range(2)
        ]) if num_layer in attn_layers else None
        self.attns.append(attn_fn)

        block = GeneratorBlock(
            latent_dim,
            in_chan,
            out_chan,
            upsample=not_first,
            upsample_rgb=not_last,
            rgba=transparent
        )
        self.blocks.append(block)
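# A minimal sketch (not from the source) that recomputes the channel schedule the
# constructor above builds, for an assumed configuration of image_size=256,
# network_capacity=16, fmap_max=512.
from math import log2
from functools import partial

image_size, network_capacity, fmap_max = 256, 16, 512
num_layers = int(log2(image_size) - 1)                     # 7 generator blocks
filters = [network_capacity * (2 ** (i + 1)) for i in range(num_layers)][::-1]
filters = list(map(partial(min, fmap_max), filters))       # cap at fmap_max
filters = [filters[0], *filters]
print(list(zip(filters[:-1], filters[1:])))
# [(512, 512), (512, 512), (512, 512), (512, 256), (256, 128), (128, 64), (64, 32)]
# Block 0 starts from the learned 4x4 constant and does not upsample; the other
# 6 blocks each upsample 2x, so 4 * 2**6 = 256 matches image_size.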
# Discriminator constructor (excerpt). Relies on DiscriminatorBlock, Residual,
# Rezero, ImageLinearAttention, PermuteToFrom, VectorQuantize and Flatten from
# the surrounding file.
def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[], transparent=False, fmap_max=512):
    super().__init__()
    num_layers = int(log2(image_size) - 1)
    num_init_filters = 3 if not transparent else 4

    # channels grow from the RGB(A) input, capped at fmap_max
    filters = [num_init_filters] + [network_capacity * (2 ** i) for i in range(num_layers + 1)]
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))
    chan_in_out = list(zip(filters[:-1], filters[1:]))

    blocks = []
    quantize_blocks = []
    attn_blocks = []

    for ind, (in_chan, out_chan) in enumerate(chan_in_out):
        num_layer = ind + 1
        is_not_last = ind != (len(chan_in_out) - 1)

        block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
        blocks.append(block)

        # optional linear self-attention at the requested layers
        attn_fn = nn.Sequential(*[
            Residual(Rezero(ImageLinearAttention(out_chan))) for _ in range(2)
        ]) if num_layer in attn_layers else None
        attn_blocks.append(attn_fn)

        # optional feature quantization at the requested layers
        quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None
        quantize_blocks.append(quantize_fn)

    self.blocks = nn.ModuleList(blocks)
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.quantize_blocks = nn.ModuleList(quantize_blocks)

    # all but the last block downsample 2x, leaving a 2x2 feature map
    latent_dim = 2 * 2 * filters[-1]

    self.flatten = Flatten()
    self.to_logit = nn.Linear(latent_dim, 1)
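# A minimal sketch (not from the source) checking the discriminator geometry for
# an assumed configuration of image_size=256, network_capacity=16,
# transparent=False, fmap_max=512, recomputed from the constructor above.
from math import log2
from functools import partial

image_size, network_capacity, fmap_max = 256, 16, 512
num_layers = int(log2(image_size) - 1)                      # 7
filters = [3] + [network_capacity * (2 ** i) for i in range(num_layers + 1)]
filters = list(map(partial(min, fmap_max), filters))
chan_in_out = list(zip(filters[:-1], filters[1:]))
print(len(chan_in_out))                                     # 8 blocks
print(image_size // 2 ** (len(chan_in_out) - 1))            # 2 -> final 2x2 feature map
print(2 * 2 * filters[-1])                                  # 2048 features into nn.Linear(latent_dim, 1)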
# Generator constructor variant without the fmap_max cap or the no_const option;
# here the learned initial 4x4 block has no leading batch dimension.
def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, attn_layers=[]):
    super().__init__()
    self.image_size = image_size
    self.latent_dim = latent_dim
    self.num_layers = int(log2(image_size) - 1)

    init_channels = 4 * network_capacity
    self.initial_block = nn.Parameter(torch.randn((init_channels, 4, 4)))

    # widest feature maps first, halving toward the output resolution
    filters = [init_channels] + [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]
    in_out_pairs = zip(filters[:-1], filters[1:])

    self.blocks = nn.ModuleList([])
    self.attns = nn.ModuleList([])

    for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
        not_first = ind != 0
        not_last = ind != (self.num_layers - 1)
        num_layer = self.num_layers - ind

        attn_fn = nn.Sequential(*[
            Residual(Rezero(ImageLinearAttention(in_chan))) for _ in range(2)
        ]) if num_layer in attn_layers else None
        self.attns.append(attn_fn)

        block = GeneratorBlock(
            latent_dim,
            in_chan,
            out_chan,
            upsample=not_first,
            upsample_rgb=not_last,
            rgba=transparent
        )
        self.blocks.append(block)
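# A minimal sketch (not from the source) of the schedule this variant produces for
# an assumed image_size=128, network_capacity=16: with no fmap_max cap, the first
# block jumps from 4 * network_capacity channels straight to the widest map.
from math import log2

image_size, network_capacity = 128, 16
num_layers = int(log2(image_size) - 1)                      # 6 blocks
filters = [4 * network_capacity] + [network_capacity * (2 ** (i + 1)) for i in range(num_layers)][::-1]
print(list(zip(filters[:-1], filters[1:])))
# [(64, 1024), (1024, 512), (512, 256), (256, 128), (128, 64), (64, 32)]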
class PermuteToFrom(nn.Module):
    """Run a channels-last module (e.g. VectorQuantize) on a channels-first feature map."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)      # (B, C, H, W) -> (B, H, W, C)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)  # back to (B, C, H, W)
        return out, loss

# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan))),
    Residual(Rezero(nn.Sequential(
        nn.Conv2d(chan, chan * 2, 1),
        leaky_relu(),
        nn.Conv2d(chan * 2, chan, 1)
    )))
])

# helpers

def default(value, d):
    return d if value is None else value

def cycle(iterable):
    # yield items from the iterable indefinitely (e.g. an endless data loader)
    while True:
        for i in iterable:
            yield i
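# A minimal usage sketch (not from the source); it assumes Residual, Rezero,
# ImageLinearAttention and leaky_relu are importable from the surrounding module,
# and that ImageLinearAttention preserves the (B, C, H, W) shape, which the
# Residual wrapper requires anyway.
import torch

attn_ff = attn_and_ff(64)            # attention + 1x1-conv feedforward, both rezero-gated
x = torch.randn(2, 64, 32, 32)
assert attn_ff(x).shape == x.shape   # residual branches keep the feature-map shape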
class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        # learned scalar gate, initialized to zero ("rezero")
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g

# one layer of self-attention and feedforward, for images
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(Rezero(nn.Sequential(
        nn.Conv2d(chan, chan * 2, 1),
        leaky_relu(),
        nn.Conv2d(chan * 2, chan, 1)
    )))
])
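# A minimal sketch (not from the source) of the rezero behaviour: the gate g starts
# at zero, so a rezero-wrapped branch contributes nothing at initialization and a
# Residual(Rezero(fn)) stack (Residual defined elsewhere in the module) starts out
# as the identity.
import torch
from torch import nn

rez = Rezero(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(1, 8, 16, 16)
assert (rez(x) == 0).all()   # g == 0 at init, so the branch output is all zeros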