Exemple #1
0
    def __init__(self,
                 num_layers=4,
                 network_capacity=16,
                 fq_layers=[],
                 fq_dict_size=256,
                 attn_layers=[],
                 transparent=False,
                 fmap_max=512):
        """Build a stack of ``DiscriminatorBlock`` modules plus a conv head.

        NOTE(review): ``num_layers`` and ``network_capacity`` are accepted
        but immediately replaced by fixed values (5 and 32), so whatever the
        caller passes for them is ignored. Confirm this is intentional.

        Args:
            num_layers: Ignored; the depth is always ``4 + 1``.
            network_capacity: Ignored; the width base is always ``16 * 2``.
            fq_layers: 1-based block numbers that get a ``VectorQuantize``
                stage; every other block gets ``None``. Only read here.
            fq_dict_size: Second positional argument of ``VectorQuantize``.
            attn_layers: 1-based block numbers that get an ``attn_and_ff``
                stage; every other block gets ``None``. Only read here.
            transparent: When truthy the first block takes 8 input channels
                instead of 6.
            fmap_max: Cap applied to every entry of the channel list.
        """
        super().__init__()

        # Hard-wired configuration; the matching constructor arguments are
        # discarded on purpose to keep the original behaviour.
        depth = 4 + 1
        base_width = 16 * 2

        first_in = 8 if transparent else 6

        # Channel plan: input width, then a width that doubles per block,
        # with every entry capped at fmap_max.
        planned = [first_in]
        planned.extend((base_width * 4) * (2**i) for i in range(depth + 1))
        widths = [min(fmap_max, w) for w in planned]
        channel_pairs = list(zip(widths, widths[1:]))

        final_idx = len(channel_pairs) - 1
        stages, attention, quantizers = [], [], []
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            layer_no = idx + 1

            # Only the last block is built with downsample=False.
            stages.append(
                DiscriminatorBlock(c_in, c_out, downsample=idx != final_idx))

            if layer_no in attn_layers:
                attention.append(attn_and_ff(c_out))
            else:
                attention.append(None)

            if layer_no in fq_layers:
                quantizers.append(
                    PermuteToFrom(VectorQuantize(c_out, fq_dict_size)))
            else:
                quantizers.append(None)

        # The three lists are index-aligned: entry i belongs to block i.
        self.blocks = nn.ModuleList(stages)
        self.attn_blocks = nn.ModuleList(attention)
        self.quantize_blocks = nn.ModuleList(quantizers)

        out_width = widths[-1]
        self.final_conv = nn.Conv2d(out_width, out_width, 3, padding=1)
        self.sigmoid = nn.Sigmoid()
    def __init__(self,
                 image_size,
                 network_capacity=16,
                 fq_layers=[],
                 fq_dict_size=256,
                 attn_layers=[],
                 transparent=False,
                 fmap_max=512):
        """Build a ``DiscriminatorBlock`` stack with a flatten + linear head.

        Args:
            image_size: Drives the depth: ``int(log2(image_size) - 1)``.
            network_capacity: Base width; block ``i`` is planned with
                ``network_capacity * 2**i`` output channels before capping.
            fq_layers: 1-based block numbers that get a ``VectorQuantize``
                stage; every other block gets ``None``. Only read here.
            fq_dict_size: Second positional argument of ``VectorQuantize``.
            attn_layers: 1-based block numbers that get two stacked
                ``Residual(Rezero(ImageLinearAttention(...)))`` modules;
                every other block gets ``None``. Only read here.
            transparent: When truthy the first block takes 4 input channels
                instead of 3.
            fmap_max: Cap applied to every entry of the channel list.
        """
        super().__init__()

        depth = int(log2(image_size) - 1)
        first_in = 4 if transparent else 3

        # Channel plan: input width, then doubling widths, all capped.
        planned = [first_in]
        planned.extend(network_capacity * (2**i) for i in range(depth + 1))
        widths = [min(fmap_max, w) for w in planned]
        channel_pairs = list(zip(widths, widths[1:]))

        final_idx = len(channel_pairs) - 1
        stages, attention, quantizers = [], [], []
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            layer_no = idx + 1

            # Only the last block is built with downsample=False.
            stages.append(
                DiscriminatorBlock(c_in, c_out, downsample=idx != final_idx))

            if layer_no in attn_layers:
                attention.append(
                    nn.Sequential(*(
                        Residual(Rezero(ImageLinearAttention(c_out)))
                        for _ in range(2))))
            else:
                attention.append(None)

            if layer_no in fq_layers:
                quantizers.append(
                    PermuteToFrom(VectorQuantize(c_out, fq_dict_size)))
            else:
                quantizers.append(None)

        # The three lists are index-aligned: entry i belongs to block i.
        self.blocks = nn.ModuleList(stages)
        self.attn_blocks = nn.ModuleList(attention)
        self.quantize_blocks = nn.ModuleList(quantizers)

        # NOTE(review): the 2 * 2 factor presumably reflects a 2x2 final
        # feature map -- confirm against forward().
        flat_features = 2 * 2 * widths[-1]

        self.flatten = Flatten()
        self.to_logit = nn.Linear(flat_features, 1)
    def __init__(self,
                 image_size,
                 network_capacity=16,
                 fq_layers=[],
                 fq_dict_size=256,
                 attn_layers=[],
                 transparent=False,
                 fmap_max=512,
                 input_filters=3,
                 quantize=False,
                 do_checkpointing=False,
                 mlp=False,
                 transfer_mode=False):
        """Build a transfer-learning aware ``DiscriminatorBlock`` stack.

        Args:
            image_size: Drives the depth: ``int(log2(image_size) - 1)``.
            network_capacity: Not referenced by this constructor (the width
                base is the constant 64).
            fq_layers: 1-based block numbers eligible for a
                ``VectorQuantize`` stage; only consulted when ``quantize``
                is truthy. Only read here.
            fq_dict_size: Second positional argument of ``VectorQuantize``.
            attn_layers: 1-based block numbers that get an ``attn_and_ff``
                stage; every other block gets ``None``. Only read here.
            transparent: Not referenced by this constructor.
            fmap_max: Cap applied to every entry of the channel list.
            input_filters: Input channel count of the first block.
            quantize: Master switch; when falsy every quantize slot is
                ``None`` regardless of ``fq_layers``.
            do_checkpointing: Stored as ``self.do_checkpointing``.
            mlp: When truthy the logit head is a two-layer
                ``TransferLinear`` stack (hidden size 100); otherwise a
                single ``TransferLinear``.
            transfer_mode: Forwarded to every ``Transfer*`` layer and to
                ``DiscriminatorBlock``; when truthy, parameters lacking a
                ``FOR_TRANSFER_LEARNING`` attribute are tagged with
                ``DO_NOT_TRAIN = True``.
        """
        super().__init__()

        depth = int(log2(image_size) - 1)

        # Channel plan: input width, then 64, 128, ... doubling per block,
        # with every entry capped at fmap_max.
        planned = [input_filters]
        planned.extend(64 * (2**i) for i in range(depth + 1))
        widths = [min(fmap_max, w) for w in planned]
        channel_pairs = list(zip(widths, widths[1:]))

        final_idx = len(channel_pairs) - 1
        stages, attention, quantizers = [], [], []
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            layer_no = idx + 1

            # Only the last block is built with downsample=False.
            stages.append(
                DiscriminatorBlock(c_in,
                                   c_out,
                                   downsample=idx != final_idx,
                                   transfer_mode=transfer_mode))

            if layer_no in attn_layers:
                attention.append(attn_and_ff(c_out))
            else:
                attention.append(None)

            # A quantizer needs both the global switch and a per-layer opt-in.
            if quantize and layer_no in fq_layers:
                quantizers.append(
                    PermuteToFrom(VectorQuantize(c_out, fq_dict_size)))
            else:
                quantizers.append(None)

        # The three lists are index-aligned: entry i belongs to block i.
        self.blocks = nn.ModuleList(stages)
        self.attn_blocks = nn.ModuleList(attention)
        self.quantize_blocks = nn.ModuleList(quantizers)
        self.do_checkpointing = do_checkpointing

        out_width = widths[-1]
        # NOTE(review): the 2 * 2 factor presumably reflects a 2x2 final
        # feature map -- confirm against forward().
        flat_features = 2 * 2 * out_width

        self.final_conv = TransferConv2d(out_width,
                                         out_width,
                                         3,
                                         padding=1,
                                         transfer_mode=transfer_mode)
        self.flatten = nn.Flatten()
        if not mlp:
            self.to_logit = TransferLinear(flat_features,
                                           1,
                                           transfer_mode=transfer_mode)
        else:
            self.to_logit = nn.Sequential(
                TransferLinear(flat_features, 100,
                               transfer_mode=transfer_mode),
                leaky_relu(),
                TransferLinear(100, 1, transfer_mode=transfer_mode))

        self._init_weights()

        self.transfer_mode = transfer_mode
        if transfer_mode:
            # Tag every parameter that was not marked for transfer learning.
            for param in self.parameters():
                if hasattr(param, 'FOR_TRANSFER_LEARNING'):
                    continue
                param.DO_NOT_TRAIN = True
Exemple #4
0
    def __init__(self,
                 image_size=256,
                 num_tokens=512,
                 codebook_dim=512,
                 num_layers=3,
                 num_resnet_blocks=0,
                 hidden_dim=64,
                 channels=3,
                 smooth_l1_loss=False,
                 vq_decay=0.8,
                 commitment_weight=1.):
        """Build encoder/decoder conv stacks around a ``VectorQuantize``.

        Args:
            image_size: Must be a power of two (asserted); stored on the
                instance.
            num_tokens: Passed to ``VectorQuantize`` as ``n_embed``; stored
                on the instance.
            codebook_dim: Passed to ``VectorQuantize`` as ``dim``; also the
                output width of the encoder's final 1x1 conv.
            num_layers: Number of stride-2 conv stages in the encoder and
                of stride-2 transposed-conv stages in the decoder; must be
                at least 1 (asserted). Stored on the instance.
            num_resnet_blocks: Number of ``ResBlock`` modules appended to
                the encoder and prepended to the decoder.
            hidden_dim: Channel width of every intermediate stage.
            channels: Encoder input channels and decoder output channels.
            smooth_l1_loss: Selects ``F.smooth_l1_loss`` when truthy,
                ``F.mse_loss`` otherwise.
            vq_decay: Passed to ``VectorQuantize`` as ``decay``.
            commitment_weight: Passed to ``VectorQuantize`` as
                ``commitment``.
        """
        super().__init__()
        assert log2(image_size).is_integer(), 'image size must be a power of 2'
        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
        use_resblocks = num_resnet_blocks > 0

        self.image_size = image_size
        self.num_tokens = num_tokens
        self.num_layers = num_layers

        self.vq = VectorQuantize(dim=codebook_dim,
                                 n_embed=num_tokens,
                                 decay=vq_decay,
                                 commitment=commitment_weight)

        # Per-stage channel widths. The decoder mirrors the encoder, and its
        # first stage reads codebook_dim directly unless ResBlocks (plus a
        # 1x1 projection, added below) sit in front of it.
        hidden_widths = [hidden_dim] * num_layers
        mirrored_widths = hidden_widths[::-1]

        enc_chans = [channels] + hidden_widths
        if use_resblocks:
            dec_first = mirrored_widths[0]
        else:
            dec_first = codebook_dim
        dec_chans = [dec_first] + mirrored_widths

        enc_pairs = list(zip(enc_chans, enc_chans[1:]))
        dec_pairs = list(zip(dec_chans, dec_chans[1:]))

        encoder_stack = []
        decoder_stack = []

        # One stride-2 conv per encoder stage, one stride-2 transposed conv
        # per decoder stage, each followed by a ReLU.
        for (e_in, e_out), (d_in, d_out) in zip(enc_pairs, dec_pairs):
            down = nn.Sequential(
                nn.Conv2d(e_in, e_out, 4, stride=2, padding=1),
                nn.ReLU())
            encoder_stack.append(down)

            up = nn.Sequential(
                nn.ConvTranspose2d(d_in, d_out, 4, stride=2, padding=1),
                nn.ReLU())
            decoder_stack.append(up)

        # ResBlocks go at the tail of the encoder and the head of the decoder.
        for _ in range(num_resnet_blocks):
            decoder_stack.insert(0, ResBlock(dec_chans[1]))
            encoder_stack.append(ResBlock(enc_chans[-1]))

        if use_resblocks:
            # Project codebook vectors to the width the ResBlocks expect.
            decoder_stack.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))

        # 1x1 convs map to codebook_dim (encoder) and back to image channels
        # (decoder).
        encoder_stack.append(nn.Conv2d(enc_chans[-1], codebook_dim, 1))
        decoder_stack.append(nn.Conv2d(dec_chans[-1], channels, 1))

        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)

        if smooth_l1_loss:
            self.loss_fn = F.smooth_l1_loss
        else:
            self.loss_fn = F.mse_loss
    def __init__(self,
                 image_size,
                 network_capacity=16,
                 fq_layers=[],
                 fq_dict_size=256,
                 attn_layers=[],
                 transparent=False,
                 fmap_max=512,
                 input_filters=3,
                 quantize=False,
                 do_checkpointing=False):
        """Build a ``DiscriminatorBlock`` stack with a conv + linear head.

        Args:
            image_size: Drives the depth: ``int(log2(image_size) - 1)``.
            network_capacity: Not referenced by this constructor (the width
                base is the constant 64).
            fq_layers: 1-based block numbers eligible for a
                ``VectorQuantize`` stage; only consulted when ``quantize``
                is truthy. Only read here.
            fq_dict_size: Second positional argument of ``VectorQuantize``.
            attn_layers: 1-based block numbers that get an ``attn_and_ff``
                stage; every other block gets ``None``. Only read here.
            transparent: Not referenced by this constructor.
            fmap_max: Cap applied to every entry of the channel list.
            input_filters: Input channel count of the first block.
            quantize: Master switch; when falsy every quantize slot is
                ``None`` regardless of ``fq_layers``.
            do_checkpointing: Stored as ``self.do_checkpointing``.
        """
        super().__init__()

        depth = int(log2(image_size) - 1)

        # Channel plan: input width, then 64, 128, ... doubling per block,
        # with every entry capped at fmap_max.
        planned = [input_filters]
        planned.extend(64 * (2**i) for i in range(depth + 1))
        widths = [min(fmap_max, w) for w in planned]
        channel_pairs = list(zip(widths, widths[1:]))

        final_idx = len(channel_pairs) - 1
        stages, attention, quantizers = [], [], []
        for idx, (c_in, c_out) in enumerate(channel_pairs):
            layer_no = idx + 1

            # Only the last block is built with downsample=False.
            stages.append(
                DiscriminatorBlock(c_in, c_out, downsample=idx != final_idx))

            if layer_no in attn_layers:
                attention.append(attn_and_ff(c_out))
            else:
                attention.append(None)

            # A quantizer needs both the global switch and a per-layer opt-in.
            if quantize and layer_no in fq_layers:
                quantizers.append(
                    PermuteToFrom(VectorQuantize(c_out, fq_dict_size)))
            else:
                quantizers.append(None)

        # The three lists are index-aligned: entry i belongs to block i.
        self.blocks = nn.ModuleList(stages)
        self.attn_blocks = nn.ModuleList(attention)
        self.quantize_blocks = nn.ModuleList(quantizers)
        self.do_checkpointing = do_checkpointing

        out_width = widths[-1]
        # NOTE(review): the 2 * 2 factor presumably reflects a 2x2 final
        # feature map -- confirm against forward().
        flat_features = 2 * 2 * out_width

        self.final_conv = nn.Conv2d(out_width, out_width, 3, padding=1)
        self.flatten = Flatten()
        self.to_logit = nn.Linear(flat_features, 1)

        self._init_weights()