Example #1
0
            def __init__(self):
                """Build the dequantization noise flow and its context network.

                Registers three submodules:

                * ``context_proc``: a ``DeepProcessor`` (defined below) that
                  maps a 6-channel tensor to 32 feature channels.
                * ``noise_flow``: ``StripeSplit``, four dequantization
                  couplings on ``(3, 64, 32)`` halves built with
                  ``attn_version=False``, the inverse split, then a
                  ``Sigmoid`` flow layer.
                * ``aux_split``: a ``StripeSplit`` for the conditioning input.

                NOTE(review): ``coupling``, ``dequant_blocks`` and ``pdrop``
                are closure variables from an enclosing scope that is not
                visible in this snippet.
                """
                super().__init__()

                class DeepProcessor(Module):
                    """3x3 conv stem plus ``dequant_blocks`` gated stages.

                    Each stage is a ``GatedConv`` (no auxiliary input)
                    followed by an ``ImgLayerNorm``; every layer works at
                    ``hdim2`` = 32 channels.
                    """

                    def __init__(self):
                        super().__init__()
                        hdim2 = 32
                        # The stem's output width must equal the GatedConv
                        # input width, so both use hdim2.
                        self.conv = Conv2d(in_channels=6,
                                           out_channels=hdim2,
                                           kernel_size=3,
                                           padding=1)
                        self.gatedconvs = ModuleList([])
                        self.norm1 = ModuleList([])
                        # One (gated conv, norm) pair per block; both lists
                        # always end up with length dequant_blocks.
                        for _ in range(dequant_blocks):
                            self.gatedconvs.append(
                                GatedConv(in_channels=hdim2,
                                          aux_channels=0,
                                          gate_nin=False,
                                          pdrop=pdrop))
                            self.norm1.append(ImgLayerNorm(hdim2))

                    def forward(self, x):
                        """Run the stem, then each gated conv and its norm."""
                        h = self.conv(x)
                        # The two ModuleLists are built in lockstep, so zip
                        # pairs each gated conv with its norm.
                        stages = zip(self.gatedconvs, self.norm1)
                        for gated_conv, norm in stages:
                            h = norm(gated_conv(h, aux=None))
                        return h

                self.context_proc = DeepProcessor()

                self.noise_flow = Compose([
                    # input: Gaussian noise
                    StripeSplit(),
                    # Four dequantization couplings, created in order.
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 64, 32),
                                            for_dequant=True,
                                            attn_version=False)),
                    Inverse(StripeSplit()),
                    Sigmoid(),
                ])
                self.aux_split = StripeSplit()
 def __init__(self):
     """Build the dequantization noise flow and its context network.

     ``context_proc`` is a ``torch.nn.Sequential`` of a 3x3 ``Conv2d``
     (6 -> 32 channels) followed by three 32-channel ``GatedConv`` layers.
     ``noise_flow`` stripe-splits its input, applies four dequantization
     couplings on ``(3, 32, 16)`` halves, merges the halves back and ends
     with a ``Sigmoid`` flow layer. ``aux_split`` is a ``StripeSplit`` for
     the conditioning input.

     NOTE(review): ``pdrop`` and ``coupling`` are free variables from an
     enclosing scope that is not visible in this snippet.
     """
     super().__init__()
     n_gated, n_couplings = 3, 4
     self.context_proc = torch.nn.Sequential(
         Conv2d(in_channels=6, out_channels=32, kernel_size=3, padding=1),
         # Unpacked in place so the Sequential indices stay 0..3.
         *[GatedConv(in_channels=32, aux_channels=0, pdrop=pdrop)
           for _ in range(n_gated)],
     )
     self.noise_flow = Compose([
         # input: Gaussian noise
         StripeSplit(),
         *[layer
           for _ in range(n_couplings)
           for layer in coupling((3, 32, 16), for_dequant=True)],
         Inverse(StripeSplit()),
         Sigmoid(),
     ])
     self.aux_split = StripeSplit()
Example #3
0
    def __init__(self,
                 *,
                 hdim=96,
                 blocks=16,
                 dequant_blocks=5,
                 mix_components=4,
                 attn_heads=4,
                 pdrop=0.,
                 force_float32_cond):
        """Build the 64x64 model: a main flow plus a dequantization flow.

        Args:
            hdim: ``hidden_channels`` of every coupling layer.
            blocks: ``blocks`` of each main-flow coupling.
            dequant_blocks: ``blocks`` of each dequantization coupling, and
                the number of GatedConv/ImgLayerNorm stages in the
                dequantizer's context network.
            mix_components: ``mix_components`` of every coupling.
            attn_heads: ``attn_heads`` of every coupling.
            pdrop: forwarded to every coupling and GatedConv (presumably a
                dropout probability).
            force_float32_cond: forwarded unchanged to every coupling.

        The base class receives ``x_shape=(3, 64, 64)`` and
        ``z_shape=(192, 8, 8)``.
        """
        def coupling(cf_shape_, for_dequant=False, attn_version=True):
            """Return the layers of one coupling step on ``cf_shape_``.

            The step is two ``Normalize`` layers wrapped in ``Parallel``
            (presumably one per half), the mixture-logistic coupling and a
            ``TupleFlip``. Dequantization steps take a 32-channel ``aux``
            input (the output width of the dequantizer's context network)
            and use ``dequant_blocks`` blocks. Main-flow steps take no
            ``aux`` channels and use ``blocks`` blocks.
            """
            return [
                Parallel([lambda: Normalize(cf_shape_)] * 2),
                MixLogisticConvAttnCoupling_Imagenet64(
                    cf_shape=cf_shape_,
                    hidden_channels=hdim,
                    aux_channels=32 if for_dequant else 0,
                    blocks=dequant_blocks if for_dequant else blocks,
                    mix_components=mix_components,
                    attn_heads=attn_heads,
                    pdrop=pdrop,
                    force_float32_cond=force_float32_cond,
                    attn_version=attn_version),
                TupleFlip(),
            ]

        class Dequant(BaseFlow):
            """Conditional flow that produces dequantization noise.

            ``forward`` and ``code`` push their input through
            ``noise_flow``, using features that ``context_proc`` computes
            from ``aux`` as the conditioning input.
            """

            def __init__(self):
                super().__init__()

                class DeepProcessor(Module):
                    """3x3 conv stem plus ``dequant_blocks`` gated stages.

                    Each stage is a ``GatedConv`` (no auxiliary input)
                    followed by an ``ImgLayerNorm``; every layer works at
                    ``hdim2`` = 32 channels, which matches the
                    ``aux_channels=32`` of the dequantization couplings.
                    """

                    def __init__(self):
                        super().__init__()
                        hdim2 = 32
                        # The stem's output width must equal the GatedConv
                        # input width, so both use hdim2.
                        self.conv = Conv2d(in_channels=6,
                                           out_channels=hdim2,
                                           kernel_size=3,
                                           padding=1)
                        self.gatedconvs = ModuleList([])
                        self.norm1 = ModuleList([])
                        # One (gated conv, norm) pair per block; both lists
                        # always end up with length dequant_blocks.
                        for _ in range(dequant_blocks):
                            self.gatedconvs.append(
                                GatedConv(in_channels=hdim2,
                                          aux_channels=0,
                                          gate_nin=False,
                                          pdrop=pdrop))
                            self.norm1.append(ImgLayerNorm(hdim2))

                    def forward(self, x):
                        """Run the stem, then each gated conv and its norm."""
                        h = self.conv(x)
                        # The two ModuleLists are built in lockstep, so zip
                        # pairs each gated conv with its norm.
                        stages = zip(self.gatedconvs, self.norm1)
                        for gated_conv, norm in stages:
                            h = norm(gated_conv(h, aux=None))
                        return h

                self.context_proc = DeepProcessor()

                self.noise_flow = Compose([
                    # input: Gaussian noise
                    StripeSplit(),
                    # Four dequantization couplings without attention,
                    # created in order.
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 64, 32),
                                            for_dequant=True,
                                            attn_version=False)),
                    Inverse(StripeSplit()),
                    Sigmoid(),
                ])
                self.aux_split = StripeSplit()

            def _process_context(self, aux):
                """Turn the conditioning input into ``noise_flow`` features.

                ``aux`` is rescaled with ``/ 256 - 0.5`` (presumably raw 8-bit
                pixel values; TODO confirm). It is then stripe-split, and the
                two halves are concatenated along dim 1, which gives the
                6 input channels ``context_proc`` expects for a 3-channel
                image.
                """
                a = aux / 256.0 - 0.5
                a = torch.cat(self.aux_split(a, inverse=False, aux=None)[0],
                              dim=1)
                return self.context_proc(a)

            def forward(self, eps, *, aux, inverse: bool):
                """Apply ``noise_flow`` to ``eps`` conditioned on ``aux``."""
                # base distribution noise -> dequantization noise
                return self.noise_flow(eps,
                                       aux=self._process_context(aux),
                                       inverse=inverse)

            def code(self, input_sym, *, aux, inverse: bool, stream):
                """Delegate to ``noise_flow.code`` with the processed context."""
                return self.noise_flow.code(input_sym,
                                            aux=self._process_context(aux),
                                            inverse=inverse,
                                            stream=stream)

        super().__init__(
            main_flow=Compose([
                # input image 3, 64, 64
                ImgProc(),
                Squeeze(),
                # 12, 32, 32
                StripeSplit(),
                *coupling((12, 32, 16)),
                *coupling((12, 32, 16)),
                *coupling((12, 32, 16)),
                *coupling((12, 32, 16)),
                Inverse(StripeSplit()),

                # 12, 32, 32
                Squeeze(),
                # 48, 16, 16
                ChannelSplit(),
                *coupling((24, 16, 16)),
                *coupling((24, 16, 16)),
                Inverse(ChannelSplit()),
                StripeSplit(),
                *coupling((48, 16, 8)),
                *coupling((48, 16, 8)),
                Inverse(StripeSplit()),

                # 48, 16, 16
                Squeeze(),  # 192, 8, 8
                ChannelSplit(),
                *coupling((96, 8, 8)),
                *coupling((96, 8, 8)),
                Inverse(ChannelSplit()),
                StripeSplit(),
                *coupling((192, 8, 4)),
                *coupling((192, 8, 4)),
                Inverse(StripeSplit()),
            ]),
            dequant_flow=Dequant(),
            x_shape=(3, 64, 64),
            z_shape=(192, 8, 8))
    def __init__(self,
                 *,
                 hdim=96,
                 blocks=10,
                 dequant_blocks=2,
                 mix_components=32,
                 attn_heads=4,
                 pdrop=0.2,
                 force_float32_cond):
        """Assemble the 32x32 model together with its dequantization flow.

        The keyword-only hyperparameters are shared by every coupling layer.
        ``blocks`` applies to main-flow couplings and ``dequant_blocks`` to
        dequantization couplings. ``pdrop`` is also passed to the
        dequantizer's three ``GatedConv`` context layers. The base class is
        given ``x_shape=(3, 32, 32)`` and ``z_shape=(12, 16, 16)``.
        """
        def coupling(cf_shape_, for_dequant=False):
            # One step: Normalize and Pointwise pairs (via Parallel), the
            # mixture-logistic coupling, then TupleFlip. Dequant steps take
            # a 32-channel aux input from the context network.
            cf_kwargs = dict(
                cf_shape=cf_shape_,
                hidden_channels=hdim,
                aux_channels=32 if for_dequant else 0,
                blocks=dequant_blocks if for_dequant else blocks,
                mix_components=mix_components,
                attn_heads=attn_heads,
                pdrop=pdrop,
                force_float32_cond=force_float32_cond,
            )
            return [
                Parallel([lambda: Normalize(cf_shape_)] * 2),
                Parallel([lambda: Pointwise(channels=cf_shape_[0])] * 2),
                MixLogisticConvAttnCoupling(**cf_kwargs),
                TupleFlip(),
            ]

        def repeat_coupling(count, cf_shape_, for_dequant=False):
            # `count` consecutive coupling steps on the same shape, built in
            # order and flattened into one list of layers.
            layers = []
            for _ in range(count):
                layers.extend(coupling(cf_shape_, for_dequant=for_dequant))
            return layers

        class Dequant(BaseFlow):
            def __init__(self):
                super().__init__()
                # 6 -> 32 channel stem, then three 32-channel gated convs.
                self.context_proc = torch.nn.Sequential(
                    Conv2d(in_channels=6, out_channels=32, kernel_size=3,
                           padding=1),
                    *[GatedConv(in_channels=32, aux_channels=0, pdrop=pdrop)
                      for _ in range(3)],
                )
                self.noise_flow = Compose([
                    # input: Gaussian noise
                    StripeSplit(),
                    *repeat_coupling(4, (3, 32, 16), for_dequant=True),
                    Inverse(StripeSplit()),
                    Sigmoid(),
                ])
                self.aux_split = StripeSplit()

            def _process_context(self, aux):
                # Rescale, stripe-split, and stack the halves on dim 1.
                shifted = aux / 256.0 - 0.5
                halves = self.aux_split(shifted, inverse=False, aux=None)[0]
                return self.context_proc(torch.cat(halves, dim=1))

            def forward(self, eps, *, aux, inverse: bool):
                # base distribution noise -> dequantization noise
                context = self._process_context(aux)
                return self.noise_flow(eps, aux=context, inverse=inverse)

            def code(self, input_sym, *, aux, inverse: bool, stream):
                context = self._process_context(aux)
                return self.noise_flow.code(input_sym, aux=context,
                                            inverse=inverse, stream=stream)

        main_layers = [
            # input image 3, 32, 32
            ImgProc(),
            StripeSplit(),
            *repeat_coupling(4, (3, 32, 16)),
            Inverse(StripeSplit()),
            Squeeze(),  # 12, 16, 16
            ChannelSplit(),
            *repeat_coupling(2, (6, 16, 16)),
            Inverse(ChannelSplit()),
            StripeSplit(),
            *repeat_coupling(3, (12, 16, 8)),
            Inverse(StripeSplit()),
        ]
        super().__init__(
            main_flow=Compose(main_layers),
            dequant_flow=Dequant(),
            x_shape=(3, 32, 32),
            z_shape=(12, 16, 16))
            def __init__(self):
                """Build the dequantization noise flow and its context network.

                Registers ``context_proc`` (a ``DeepProcessor``, defined
                below), ``noise_flow`` and ``aux_split`` (a ``StripeSplit``
                for the conditioning input). ``noise_flow`` runs two rounds
                of ``StripeSplit`` -> four dequantization couplings on
                ``(3, 32, 16)`` halves -> inverse split, followed by a
                ``Sigmoid`` flow layer.

                NOTE(review): ``coupling``, ``dequant_blocks``, ``attn_heads``
                and ``pdrop`` are closure variables from an enclosing scope
                that is not visible in this snippet.
                """
                super().__init__()

                class DeepProcessor(Module):
                    """Conv stem plus ``dequant_blocks`` conv/attention stages.

                    Each stage is GatedConv -> ImgLayerNorm -> GatedAttention
                    (fed a learned positional embedding) -> ImgLayerNorm,
                    all at ``hdim2`` = 32 channels.
                    """

                    def __init__(self):
                        super().__init__()

                        hdim2 = 32
                        height = width = 32
                        pos_emb_init = 0.01
                        # Learned positional embedding of shape
                        # (hdim2, height, width // 2). The width is presumably
                        # halved because the context is stripe-split before it
                        # reaches this network; confirm. torch.empty is
                        # enough because normal_ overwrites every element.
                        self.pos_emb = Parameter(
                            torch.empty(hdim2, height, width // 2))
                        torch.nn.init.normal_(self.pos_emb,
                                              mean=0.,
                                              std=pos_emb_init)
                        self.conv = WnConv2d(in_channels=6,
                                             out_channels=hdim2,
                                             kernel_size=3,
                                             padding=1)

                        self.gatedconvs = ModuleList([])
                        self.norm1 = ModuleList([])
                        self.gatedattns = ModuleList([])
                        self.norm2 = ModuleList([])
                        # Stage modules are created one stage at a time; all
                        # four lists end up with length dequant_blocks.
                        for _ in range(dequant_blocks):
                            self.gatedconvs.append(
                                GatedConv_Imagenet32(in_channels=hdim2,
                                                     aux_channels=0,
                                                     gate_nin=False,
                                                     pdrop=pdrop))
                            self.norm1.append(ImgLayerNorm(hdim2))
                            self.gatedattns.append(
                                GatedAttention_Imagenet32(in_channels=hdim2,
                                                          heads=attn_heads,
                                                          pdrop=pdrop))
                            self.norm2.append(ImgLayerNorm(hdim2))

                    def forward(self, x):
                        """Run the stem, then every conv/attention stage."""
                        h = self.conv(x)
                        stages = zip(self.gatedconvs, self.norm1,
                                     self.gatedattns, self.norm2)
                        for gated_conv, norm_a, gated_attn, norm_b in stages:
                            h = norm_a(gated_conv(h, aux=None))
                            h = norm_b(gated_attn(h, pos_emb=self.pos_emb))
                        return h

                self.context_proc = DeepProcessor()
                self.noise_flow = Compose([
                    # input: Gaussian noise
                    StripeSplit(),
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 32, 16), for_dequant=True)),
                    Inverse(StripeSplit()),
                    StripeSplit(),
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 32, 16), for_dequant=True)),
                    Inverse(StripeSplit()),
                    Sigmoid(),
                ])
                self.aux_split = StripeSplit()
    def __init__(self,
                 *,
                 hdim=128,
                 blocks=20,
                 dequant_blocks=8,
                 mix_components=32,
                 attn_heads=4,
                 pdrop=0.,
                 force_float32_cond):
        """Build the 32x32 model: a main flow plus a dequantization flow.

        Args:
            hdim: ``hidden_channels`` of every coupling layer.
            blocks: ``blocks`` of each main-flow coupling.
            dequant_blocks: ``blocks`` of each dequantization coupling, and
                the number of conv/attention stages in the dequantizer's
                context network.
            mix_components: ``mix_components`` of every coupling.
            attn_heads: ``attn_heads`` of every coupling, and ``heads`` of
                the context network's GatedAttention layers.
            pdrop: forwarded to every coupling, GatedConv and GatedAttention
                (presumably a dropout probability).
            force_float32_cond: forwarded unchanged to every coupling.

        The base class receives ``x_shape=(3, 32, 32)`` and
        ``z_shape=(12, 16, 16)``.
        """
        def coupling(cf_shape_, for_dequant=False):
            """Return the layers of one coupling step on ``cf_shape_``.

            The step is two ``Normalize`` layers wrapped in ``Parallel``
            (presumably one per half), the mixture-logistic coupling and a
            ``TupleFlip``. Dequantization steps take a 32-channel ``aux``
            input (the output width of the dequantizer's context network)
            and use ``dequant_blocks`` blocks.
            """
            return [
                Parallel([lambda: Normalize(cf_shape_)] * 2),
                MixLogisticConvAttnCoupling_Imagenet32(
                    cf_shape=cf_shape_,
                    hidden_channels=hdim,
                    aux_channels=32 if for_dequant else 0,
                    blocks=dequant_blocks if for_dequant else blocks,
                    mix_components=mix_components,
                    attn_heads=attn_heads,
                    pdrop=pdrop,
                    force_float32_cond=force_float32_cond),
                TupleFlip(),
            ]

        class Dequant(BaseFlow):
            """Conditional flow that produces dequantization noise.

            ``forward`` and ``code`` push their input through
            ``noise_flow``, using features that ``context_proc`` computes
            from ``aux`` as the conditioning input.
            """

            def __init__(self):
                super().__init__()

                class DeepProcessor(Module):
                    """Conv stem plus ``dequant_blocks`` conv/attention stages.

                    Each stage is GatedConv -> ImgLayerNorm -> GatedAttention
                    (fed a learned positional embedding) -> ImgLayerNorm,
                    all at ``hdim2`` = 32 channels, which matches the
                    ``aux_channels=32`` of the dequantization couplings.
                    """

                    def __init__(self):
                        super().__init__()

                        hdim2 = 32
                        height = width = 32
                        pos_emb_init = 0.01
                        # Learned positional embedding of shape
                        # (hdim2, height, width // 2). The width is halved
                        # because the context is stripe-split before it
                        # reaches this network (see _process_context).
                        # torch.empty is enough because normal_ overwrites
                        # every element.
                        self.pos_emb = Parameter(
                            torch.empty(hdim2, height, width // 2))
                        torch.nn.init.normal_(self.pos_emb,
                                              mean=0.,
                                              std=pos_emb_init)
                        self.conv = WnConv2d(in_channels=6,
                                             out_channels=hdim2,
                                             kernel_size=3,
                                             padding=1)

                        self.gatedconvs = ModuleList([])
                        self.norm1 = ModuleList([])
                        self.gatedattns = ModuleList([])
                        self.norm2 = ModuleList([])
                        # Stage modules are created one stage at a time; all
                        # four lists end up with length dequant_blocks.
                        for _ in range(dequant_blocks):
                            self.gatedconvs.append(
                                GatedConv_Imagenet32(in_channels=hdim2,
                                                     aux_channels=0,
                                                     gate_nin=False,
                                                     pdrop=pdrop))
                            self.norm1.append(ImgLayerNorm(hdim2))
                            self.gatedattns.append(
                                GatedAttention_Imagenet32(in_channels=hdim2,
                                                          heads=attn_heads,
                                                          pdrop=pdrop))
                            self.norm2.append(ImgLayerNorm(hdim2))

                    def forward(self, x):
                        """Run the stem, then every conv/attention stage."""
                        h = self.conv(x)
                        stages = zip(self.gatedconvs, self.norm1,
                                     self.gatedattns, self.norm2)
                        for gated_conv, norm_a, gated_attn, norm_b in stages:
                            h = norm_a(gated_conv(h, aux=None))
                            h = norm_b(gated_attn(h, pos_emb=self.pos_emb))
                        return h

                self.context_proc = DeepProcessor()
                self.noise_flow = Compose([
                    # input: Gaussian noise
                    StripeSplit(),
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 32, 16), for_dequant=True)),
                    Inverse(StripeSplit()),
                    StripeSplit(),
                    *(layer
                      for _ in range(4)
                      for layer in coupling((3, 32, 16), for_dequant=True)),
                    Inverse(StripeSplit()),
                    Sigmoid(),
                ])
                self.aux_split = StripeSplit()

            def _process_context(self, aux):
                """Turn the conditioning input into ``noise_flow`` features.

                ``aux`` is rescaled with ``/ 256 - 0.5`` (presumably raw 8-bit
                pixel values; TODO confirm). It is then stripe-split, and the
                two halves are concatenated along dim 1, which gives the
                6 input channels ``context_proc`` expects for a 3-channel
                image.
                """
                a = aux / 256.0 - 0.5
                a = torch.cat(self.aux_split(a, inverse=False, aux=None)[0],
                              dim=1)
                return self.context_proc(a)

            def forward(self, eps, *, aux, inverse: bool):
                """Apply ``noise_flow`` to ``eps`` conditioned on ``aux``."""
                # base distribution noise -> dequantization noise
                return self.noise_flow(eps,
                                       aux=self._process_context(aux),
                                       inverse=inverse)

            def code(self, input_sym, *, aux, inverse: bool, stream):
                """Delegate to ``noise_flow.code`` with the processed context."""
                return self.noise_flow.code(input_sym,
                                            aux=self._process_context(aux),
                                            inverse=inverse,
                                            stream=stream)

        super().__init__(
            main_flow=Compose([
                # input image 3, 32, 32
                ImgProc(),
                StripeSplit(),
                *coupling((3, 32, 16)),
                *coupling((3, 32, 16)),
                *coupling((3, 32, 16)),
                *coupling((3, 32, 16)),
                Inverse(StripeSplit()),
                StripeSplit(),
                *coupling((3, 32, 16)),
                *coupling((3, 32, 16)),
                *coupling((3, 32, 16)),
                Inverse(StripeSplit()),
                Squeeze(),  # 12, 16, 16
                ChannelSplit(),
                *coupling((6, 16, 16)),
                *coupling((6, 16, 16)),
                *coupling((6, 16, 16)),
                Inverse(ChannelSplit()),
                StripeSplit(),
                *coupling((12, 16, 8)),
                *coupling((12, 16, 8)),
                *coupling((12, 16, 8)),
                Inverse(StripeSplit()),
            ]),
            dequant_flow=Dequant(),
            x_shape=(3, 32, 32),
            z_shape=(12, 16, 16))