def __init__(self,
             num_layers=4,
             network_capacity=16,
             fq_layers=[],
             fq_dict_size=256,
             attn_layers=[],
             transparent=False,
             fmap_max=512):
    """Build a discriminator with a fixed downsampling depth.

    NOTE(review): the ``num_layers`` and ``network_capacity`` arguments are
    accepted but immediately overwritten below (depth pinned to
    ``n_downsample + 1``, capacity to ``16 * 2``); they are kept in the
    signature only for interface compatibility — confirm this is intentional.

    Args:
        num_layers: ignored (see note above).
        network_capacity: ignored (see note above).
        fq_layers: 1-based layer indices that get a vector-quantize step.
        fq_dict_size: codebook size for each ``VectorQuantize``.
        attn_layers: 1-based layer indices that get an attention block.
        transparent: if True use 8 input channels instead of 6
            (presumably an extra alpha channel per frame — TODO confirm).
        fmap_max: cap applied to every per-layer channel count.
    """
    super().__init__()
    n_downsample = 4
    num_layers = n_downsample + 1   # depth is fixed; parameter is ignored
    network_capacity = 16 * 2       # capacity is fixed; parameter is ignored

    num_init_filters = 6 if not transparent else 8

    # Channel schedule: doubles each layer, clipped at fmap_max.
    filters = [num_init_filters] + [
        (network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)
    ]
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))
    chan_in_out = list(zip(filters[:-1], filters[1:]))

    blocks = []
    attn_blocks = []
    quantize_blocks = []

    for ind, (in_chan, out_chan) in enumerate(chan_in_out):
        num_layer = ind + 1
        is_not_last = ind != (len(chan_in_out) - 1)

        # Last block keeps spatial resolution; all others downsample.
        blocks.append(
            DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last))

        attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
        attn_blocks.append(attn_fn)

        quantize_fn = PermuteToFrom(
            VectorQuantize(out_chan, fq_dict_size)
        ) if num_layer in fq_layers else None
        quantize_blocks.append(quantize_fn)

    self.blocks = nn.ModuleList(blocks)
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.quantize_blocks = nn.ModuleList(quantize_blocks)

    chan_last = filters[-1]
    self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
    self.sigmoid = nn.Sigmoid()
def __init__(self,
             image_size,
             network_capacity=16,
             fq_layers=[],
             fq_dict_size=256,
             attn_layers=[],
             transparent=False,
             fmap_max=512):
    """Build a discriminator whose depth is derived from the image size.

    Args:
        image_size: input resolution; depth is ``int(log2(image_size) - 1)``.
        network_capacity: base channel multiplier for the layer schedule.
        fq_layers: 1-based layer indices that get a vector-quantize step.
        fq_dict_size: codebook size for each ``VectorQuantize``.
        attn_layers: 1-based layer indices that get linear attention.
        transparent: if True expect RGBA input (4 channels) instead of RGB.
        fmap_max: cap applied to every per-layer channel count.
    """
    super().__init__()
    num_layers = int(log2(image_size) - 1)
    num_init_filters = 3 if not transparent else 4

    # Channel schedule: doubles each layer, clipped at fmap_max.
    filters = [num_init_filters] + [
        (network_capacity) * (2 ** i) for i in range(num_layers + 1)
    ]
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))
    chan_in_out = list(zip(filters[:-1], filters[1:]))

    blocks = []
    attn_blocks = []
    quantize_blocks = []

    for ind, (in_chan, out_chan) in enumerate(chan_in_out):
        num_layer = ind + 1
        is_not_last = ind != (len(chan_in_out) - 1)

        # Last block keeps spatial resolution; all others downsample.
        blocks.append(
            DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last))

        # Two residual re-zeroed linear-attention layers when requested.
        attn_fn = nn.Sequential(*[
            Residual(Rezero(ImageLinearAttention(out_chan)))
            for _ in range(2)
        ]) if num_layer in attn_layers else None
        attn_blocks.append(attn_fn)

        quantize_fn = PermuteToFrom(
            VectorQuantize(out_chan, fq_dict_size)
        ) if num_layer in fq_layers else None
        quantize_blocks.append(quantize_fn)

    self.blocks = nn.ModuleList(blocks)
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.quantize_blocks = nn.ModuleList(quantize_blocks)

    # Final feature map is assumed to be 2x2 spatially — TODO confirm
    # against the forward pass.
    latent_dim = 2 * 2 * filters[-1]
    self.flatten = Flatten()
    self.to_logit = nn.Linear(latent_dim, 1)
def __init__(self,
             image_size,
             network_capacity=16,
             fq_layers=[],
             fq_dict_size=256,
             attn_layers=[],
             transparent=False,
             fmap_max=512,
             input_filters=3,
             quantize=False,
             do_checkpointing=False,
             mlp=False,
             transfer_mode=False):
    """Build a transfer-learning-aware discriminator.

    Args:
        image_size: input resolution; depth is ``int(log2(image_size) - 1)``.
        network_capacity: unused here; the base width is hard-coded to 64.
        fq_layers: 1-based layer indices that get a vector-quantize step
            (only honored when ``quantize`` is True).
        fq_dict_size: codebook size for each ``VectorQuantize``.
        attn_layers: 1-based layer indices that get an attention block.
        transparent: unused here; channel count comes from ``input_filters``.
        fmap_max: cap applied to every per-layer channel count.
        input_filters: number of input image channels.
        quantize: enable the vector-quantize branches.
        do_checkpointing: stored for use by the forward pass.
        mlp: use a small two-layer MLP head instead of a single linear.
        transfer_mode: freeze all parameters not tagged for transfer
            learning by setting ``DO_NOT_TRAIN`` on them.
    """
    super().__init__()
    num_layers = int(log2(image_size) - 1)

    # Channel schedule with a fixed base width of 64, clipped at fmap_max.
    filters = [input_filters] + [
        (64) * (2 ** i) for i in range(num_layers + 1)
    ]
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))
    chan_in_out = list(zip(filters[:-1], filters[1:]))

    blocks = []
    attn_blocks = []
    quantize_blocks = []

    for ind, (in_chan, out_chan) in enumerate(chan_in_out):
        num_layer = ind + 1
        is_not_last = ind != (len(chan_in_out) - 1)

        # Last block keeps spatial resolution; all others downsample.
        blocks.append(
            DiscriminatorBlock(in_chan,
                               out_chan,
                               downsample=is_not_last,
                               transfer_mode=transfer_mode))

        attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
        attn_blocks.append(attn_fn)

        if quantize:
            quantize_fn = PermuteToFrom(
                VectorQuantize(out_chan, fq_dict_size)
            ) if num_layer in fq_layers else None
            quantize_blocks.append(quantize_fn)
        else:
            quantize_blocks.append(None)

    self.blocks = nn.ModuleList(blocks)
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.quantize_blocks = nn.ModuleList(quantize_blocks)
    self.do_checkpointing = do_checkpointing

    chan_last = filters[-1]
    # Final feature map is assumed to be 2x2 spatially — TODO confirm
    # against the forward pass.
    latent_dim = 2 * 2 * chan_last

    self.final_conv = TransferConv2d(chan_last,
                                     chan_last,
                                     3,
                                     padding=1,
                                     transfer_mode=transfer_mode)
    self.flatten = nn.Flatten()

    if mlp:
        self.to_logit = nn.Sequential(
            TransferLinear(latent_dim, 100, transfer_mode=transfer_mode),
            leaky_relu(),
            TransferLinear(100, 1, transfer_mode=transfer_mode))
    else:
        self.to_logit = TransferLinear(latent_dim,
                                       1,
                                       transfer_mode=transfer_mode)

    self._init_weights()

    self.transfer_mode = transfer_mode
    if transfer_mode:
        # Freeze everything that was not explicitly tagged for transfer
        # learning; the trainer is expected to honor DO_NOT_TRAIN.
        for p in self.parameters():
            if not hasattr(p, 'FOR_TRANSFER_LEARNING'):
                p.DO_NOT_TRAIN = True
def __init__(self,
             image_size=256,
             num_tokens=512,
             codebook_dim=512,
             num_layers=3,
             num_resnet_blocks=0,
             hidden_dim=64,
             channels=3,
             smooth_l1_loss=False,
             vq_decay=0.8,
             commitment_weight=1.):
    """Build a discrete VAE with a vector-quantized bottleneck.

    Args:
        image_size: input resolution; must be a power of 2.
        num_tokens: codebook size of the vector quantizer.
        codebook_dim: embedding dimension of each codebook entry.
        num_layers: number of stride-2 conv down/up-sampling stages (>= 1).
        num_resnet_blocks: residual blocks appended to the encoder and
            prepended to the decoder (0 disables them).
        hidden_dim: channel width of every intermediate conv layer.
        channels: number of image channels.
        smooth_l1_loss: use smooth L1 instead of MSE reconstruction loss.
        vq_decay: EMA decay for the quantizer codebook.
        commitment_weight: commitment loss weight for the quantizer.

    Raises:
        AssertionError: if ``image_size`` is not a power of 2 or
            ``num_layers`` < 1.
    """
    super().__init__()
    assert log2(image_size).is_integer(), 'image size must be a power of 2'
    assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
    has_resblocks = num_resnet_blocks > 0

    self.image_size = image_size
    self.num_tokens = num_tokens
    self.num_layers = num_layers
    self.vq = VectorQuantize(dim=codebook_dim,
                             n_embed=num_tokens,
                             decay=vq_decay,
                             commitment=commitment_weight)

    enc_chans = [hidden_dim] * num_layers
    dec_chans = list(reversed(enc_chans))

    enc_chans = [channels, *enc_chans]

    # Without resblocks the decoder consumes codebook vectors directly;
    # with them, a 1x1 conv (added below) projects codebook_dim first.
    dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
    dec_chans = [dec_init_chan, *dec_chans]

    enc_chans_io, dec_chans_io = map(
        lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

    enc_layers = []
    dec_layers = []

    for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
        enc_layers.append(
            nn.Sequential(
                nn.Conv2d(enc_in, enc_out, 4, stride=2, padding=1),
                nn.ReLU()))
        dec_layers.append(
            nn.Sequential(
                nn.ConvTranspose2d(dec_in, dec_out, 4, stride=2, padding=1),
                nn.ReLU()))

    # Resblocks sit at the bottleneck: end of encoder, start of decoder.
    for _ in range(num_resnet_blocks):
        dec_layers.insert(0, ResBlock(dec_chans[1]))
        enc_layers.append(ResBlock(enc_chans[-1]))

    if num_resnet_blocks > 0:
        dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

    enc_layers.append(nn.Conv2d(enc_chans[-1], codebook_dim, 1))
    dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))

    self.encoder = nn.Sequential(*enc_layers)
    self.decoder = nn.Sequential(*dec_layers)

    self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
def __init__(self,
             image_size,
             network_capacity=16,
             fq_layers=[],
             fq_dict_size=256,
             attn_layers=[],
             transparent=False,
             fmap_max=512,
             input_filters=3,
             quantize=False,
             do_checkpointing=False):
    """Build a discriminator with optional quantization and checkpointing.

    Args:
        image_size: input resolution; depth is ``int(log2(image_size) - 1)``.
        network_capacity: unused here; the base width is hard-coded to 64.
        fq_layers: 1-based layer indices that get a vector-quantize step
            (only honored when ``quantize`` is True).
        fq_dict_size: codebook size for each ``VectorQuantize``.
        attn_layers: 1-based layer indices that get an attention block.
        transparent: unused here; channel count comes from ``input_filters``.
        fmap_max: cap applied to every per-layer channel count.
        input_filters: number of input image channels.
        quantize: enable the vector-quantize branches.
        do_checkpointing: stored for use by the forward pass.
    """
    super().__init__()
    num_layers = int(log2(image_size) - 1)

    # Channel schedule with a fixed base width of 64, clipped at fmap_max.
    filters = [input_filters] + [
        (64) * (2 ** i) for i in range(num_layers + 1)
    ]
    set_fmap_max = partial(min, fmap_max)
    filters = list(map(set_fmap_max, filters))
    chan_in_out = list(zip(filters[:-1], filters[1:]))

    blocks = []
    attn_blocks = []
    quantize_blocks = []

    for ind, (in_chan, out_chan) in enumerate(chan_in_out):
        num_layer = ind + 1
        is_not_last = ind != (len(chan_in_out) - 1)

        # Last block keeps spatial resolution; all others downsample.
        blocks.append(
            DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last))

        attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
        attn_blocks.append(attn_fn)

        if quantize:
            quantize_fn = PermuteToFrom(
                VectorQuantize(out_chan, fq_dict_size)
            ) if num_layer in fq_layers else None
            quantize_blocks.append(quantize_fn)
        else:
            quantize_blocks.append(None)

    self.blocks = nn.ModuleList(blocks)
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.quantize_blocks = nn.ModuleList(quantize_blocks)
    self.do_checkpointing = do_checkpointing

    chan_last = filters[-1]
    # Final feature map is assumed to be 2x2 spatially — TODO confirm
    # against the forward pass.
    latent_dim = 2 * 2 * chan_last

    self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1)
    self.flatten = Flatten()
    self.to_logit = nn.Linear(latent_dim, 1)

    self._init_weights()