    # generator-style constructor: builds a stack of GeneratorBlocks from a 4x4
    # base up to image_size, with optional linear-attention layers per resolution
    def __init__(self,
                 image_size,
                 latent_dim,
                 network_capacity=16,
                 transparent=False,
                 attn_layers=[],
                 no_const=False,
                 fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(image_size) - 1)  # e.g. 7 blocks for a 256x256 output

        # per-block channel counts, widest first, later capped at fmap_max
        filters = [
            network_capacity * (2**(i + 1)) for i in range(self.num_layers)
        ][::-1]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]  # prepend so the first pair is (init_channels, init_channels)

        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const

        if no_const:
            # learn the initial 4x4 block from the latent instead of a constant
            self.to_initial_block = nn.ConvTranspose2d(latent_dim,
                                                       init_channels,
                                                       4,
                                                       1,
                                                       0,
                                                       bias=False)
        else:
            # learned constant 4x4 input, shared across samples
            self.initial_block = nn.Parameter(
                torch.randn((1, init_channels, 4, 4)))

        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind

            # optional pair of residual linear-attention layers at this resolution
            attn_fn = nn.Sequential(*[
                Residual(Rezero(ImageLinearAttention(in_chan)))
                for _ in range(2)
            ]) if num_layer in attn_layers else None

            self.attns.append(attn_fn)

            block = GeneratorBlock(latent_dim,
                                   in_chan,
                                   out_chan,
                                   upsample=not_first,
                                   upsample_rgb=not_last,
                                   rgba=transparent)
            self.blocks.append(block)
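
# A minimal standalone sketch (not part of the original) of the channel
# arithmetic in the constructor above, assuming image_size=256,
# network_capacity=16 and fmap_max=512. It reproduces only the bookkeeping,
# not the GeneratorBlock stack itself.

from math import log2
from functools import partial

image_size, network_capacity, fmap_max = 256, 16, 512
num_layers = int(log2(image_size) - 1)                         # 7 blocks
filters = [network_capacity * (2 ** (i + 1)) for i in range(num_layers)][::-1]
filters = list(map(partial(min, fmap_max), filters))           # cap widths at fmap_max
filters = [filters[0], *filters]                               # prepend init_channels
print(list(zip(filters[:-1], filters[1:])))
# [(512, 512), (512, 512), (512, 512), (512, 256), (256, 128), (128, 64), (64, 32)]
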
    # discriminator-style constructor: stacks DiscriminatorBlocks that halve the
    # resolution, with optional attention and vector-quantization layers
    def __init__(self,
                 image_size,
                 network_capacity=16,
                 fq_layers=[],
                 fq_dict_size=256,
                 attn_layers=[],
                 transparent=False,
                 fmap_max=512):
        super().__init__()
        num_layers = int(log2(image_size) - 1)
        num_init_filters = 3 if not transparent else 4  # RGB, or RGBA when transparent

        filters = [num_init_filters] + [network_capacity * (2**i)
                                        for i in range(num_layers + 1)]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        chan_in_out = list(zip(filters[:-1], filters[1:]))

        blocks = []
        quantize_blocks = []
        attn_blocks = []

        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)  # last block keeps its spatial size

            block = DiscriminatorBlock(in_chan,
                                       out_chan,
                                       downsample=is_not_last)
            blocks.append(block)

            attn_fn = nn.Sequential(*[
                Residual(Rezero(ImageLinearAttention(out_chan)))
                for _ in range(2)
            ]) if num_layer in attn_layers else None

            attn_blocks.append(attn_fn)

            # optional vector quantization, applied channels-last via PermuteToFrom
            quantize_fn = PermuteToFrom(VectorQuantize(
                out_chan, fq_dict_size)) if num_layer in fq_layers else None
            quantize_blocks.append(quantize_fn)

        self.blocks = nn.ModuleList(blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)
        self.quantize_blocks = nn.ModuleList(quantize_blocks)

        latent_dim = 2 * 2 * filters[-1]  # the final feature map is 2x2

        self.flatten = Flatten()
        self.to_logit = nn.Linear(latent_dim, 1)
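
# A minimal standalone sketch (not part of the original) of the discriminator's
# channel progression above, assuming image_size=256, network_capacity=16,
# fmap_max=512 and transparent=False.

from math import log2
from functools import partial

image_size, network_capacity, fmap_max = 256, 16, 512
num_layers = int(log2(image_size) - 1)
filters = [3] + [network_capacity * (2 ** i) for i in range(num_layers + 1)]
filters = list(map(partial(min, fmap_max), filters))
print(list(zip(filters[:-1], filters[1:])))
# [(3, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, 512), (512, 512), (512, 512)]
print(2 * 2 * filters[-1])   # 2048 features into the final nn.Linear
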
Example #3
    # older generator-style constructor: same block structure, but with a fixed
    # 4 * network_capacity learned constant and no fmap_max cap on channel widths
    def __init__(self,
                 image_size,
                 latent_dim,
                 network_capacity=16,
                 transparent=False,
                 attn_layers=[]):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(image_size) - 1)

        init_channels = 4 * network_capacity
        self.initial_block = nn.Parameter(torch.randn((init_channels, 4, 4)))
        filters = [init_channels] + [
            network_capacity * (2**(i + 1)) for i in range(self.num_layers)
        ][::-1]
        in_out_pairs = zip(filters[0:-1], filters[1:])

        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind

            attn_fn = (nn.Sequential(*[
                Residual(Rezero(ImageLinearAttention(in_chan)))
                for _ in range(2)
            ]) if num_layer in attn_layers else None)

            self.attns.append(attn_fn)

            block = GeneratorBlock(latent_dim,
                                   in_chan,
                                   out_chan,
                                   upsample=not_first,
                                   upsample_rgb=not_last,
                                   rgba=transparent)
            self.blocks.append(block)
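
# A minimal sketch (not part of the original) of this older variant's channel
# arithmetic, assuming image_size=256 and network_capacity=16; note there is no
# fmap_max cap, so the first block jumps from a narrow constant to a very wide map.

from math import log2

image_size, network_capacity = 256, 16
num_layers = int(log2(image_size) - 1)
init_channels = 4 * network_capacity      # 64
filters = [init_channels] + [network_capacity * (2 ** (i + 1))
                             for i in range(num_layers)][::-1]
print(list(zip(filters[:-1], filters[1:])))
# [(64, 2048), (2048, 1024), (1024, 512), (512, 256), (256, 128), (128, 64), (64, 32)]
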
class PermuteToFrom(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # NCHW -> NHWC for the wrapped channels-last module, then back again
        x = x.permute(0, 2, 3, 1)
        out, loss = self.fn(x)
        out = out.permute(0, 3, 1, 2)
        return out, loss
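
# A minimal usage sketch (not part of the original). ChannelsLastIdentity is a
# hypothetical stand-in for a channels-last module such as VectorQuantize that
# returns an (output, auxiliary loss) pair.

import torch
import torch.nn as nn

class ChannelsLastIdentity(nn.Module):
    def forward(self, x):                     # x arrives as (B, H, W, C)
        return x, torch.tensor(0.)

wrapper = PermuteToFrom(ChannelsLastIdentity())
feat = torch.randn(2, 64, 8, 8)               # NCHW feature map
out, aux_loss = wrapper(feat)
print(out.shape)                              # torch.Size([2, 64, 8, 8])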


# one layer of self-attention and feedforward, for images

attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan))),
    Residual(
        Rezero(
            nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(),
                          nn.Conv2d(chan * 2, chan, 1))))
])

# helpers


def default(value, d):
    return d if value is None else value


def cycle(iterable):
    # endlessly re-iterate a finite iterable (e.g. a DataLoader)
    while True:
        for i in iterable:
            yield i
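
# A minimal usage sketch (not part of the original) for the two helpers above:
# default() supplies a fallback for None, and cycle() turns a finite iterable
# into an endless stream, e.g. for looping over a DataLoader indefinitely.

print(default(None, 16))                  # 16
print(default(8, 16))                     # 8
loader = cycle([1, 2, 3])
print([next(loader) for _ in range(5)])   # [1, 2, 3, 1, 2]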

class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g
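
# A minimal sketch (not part of the original) of the Residual(Rezero(...))
# pattern used by attn_and_ff. Residual here is an assumed minimal wrapper
# (fn(x) + x), and nn.LeakyReLU(0.2) stands in for the project's leaky_relu()
# helper; ImageLinearAttention is not shown in this excerpt, so only the conv
# feedforward branch is exercised.

import torch
import torch.nn as nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

chan = 32
ff = Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1),
                                   nn.LeakyReLU(0.2),
                                   nn.Conv2d(chan * 2, chan, 1))))
x = torch.randn(1, chan, 16, 16)
with torch.no_grad():
    print(torch.allclose(ff(x), x))   # True: Rezero's gate g starts at zero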


# one layer of self-attention and feedforward, for images

attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries=True))),
    Residual(
        Rezero(
            nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(),
                          nn.Conv2d(chan * 2, chan, 1))))
])
