Beispiel #1
0
    def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8, init_zeros=False,
            st_type='resnet', use_batch_norm=True, img_shape=(1, 28, 28), skip=True, latent_dim=100):
        """Build a multi-scale RealNVP flow.

        Every scale starts with a ``_threecouplinglayers`` group using
        ``MaskCheckerboard``. The final scale then ends with a single
        reverse-checkerboard ``CouplingLayer``. Every other scale squeezes
        (halving the tracked image width), applies a ``MaskChannelwise`` group
        on ``4 * in_channels`` channels, keeps ``2 * in_channels`` channels and
        doubles ``in_channels`` for the next scale.

        ``init_zeros``, ``st_type``, ``use_batch_norm``, ``skip`` and
        ``latent_dim`` are passed positionally, together with the current
        width, to every coupling group / layer. Unless ``st_type`` is
        ``'autoencoder'``, ``mid_channels`` doubles at every squeeze.
        ``img_shape`` must unpack into exactly three values; only the last one
        (the width) is used.
        """
        super(RealNVP, self).__init__()

        layers = [addZslot(), passThrough(iLogits())]
        self.output_shapes = []
        _chans, _height, width = img_shape
        last_scale = num_scales - 1

        for scale in range(num_scales):
            layers.append(passThrough(*self._threecouplinglayers(
                in_channels, mid_channels, num_blocks, MaskCheckerboard,
                init_zeros, st_type, use_batch_norm, width, skip, latent_dim)))

            if scale == last_scale:
                closing = CouplingLayer(
                    in_channels, mid_channels, num_blocks, MaskCheckerboard(reverse_mask=True),
                    init_zeros, st_type, use_batch_norm, width, skip, latent_dim)
                layers.append(passThrough(closing))
                continue

            layers.append(passThrough(SqueezeLayer(2)))
            width //= 2
            # Original author's note: for the autoencoder variant the
            # bottleneck size should probably stay fixed, so don't widen it.
            if st_type != 'autoencoder':
                mid_channels *= 2
            layers.append(passThrough(*self._threecouplinglayers(
                4 * in_channels, mid_channels, num_blocks, MaskChannelwise,
                init_zeros, st_type, use_batch_norm, width, skip, latent_dim)))
            layers.append(keepChannels(2 * in_channels))
            in_channels *= 2

        layers.append(FlatJoin())
        self.body = iSequential(*layers)
Beispiel #2
0
    def __init__(self,
                 num_scales=2,
                 in_channels=3,
                 mid_channels=64,
                 num_blocks=8):
        """Build a multi-scale RealNVP flow.

        At scale ``s`` the flow works on ``in_channels * 2**s`` channels with
        ``mid_channels * 2**s`` hidden channels. Each scale begins with a
        ``_threecouplinglayers`` group using ``MaskCheckerboard``.

        - Non-final scales then squeeze, apply a ``MaskChannelwise`` group on
          4x the channels with 2x the hidden width, and keep twice the scale's
          channel count.
        - The final scale instead ends with one reverse-checkerboard
          ``CouplingLayer``.
        """
        super(RealNVP, self).__init__()

        layers = [addZslot(), passThrough(iLogits())]
        chans, hidden = in_channels, mid_channels

        for scale in range(num_scales):
            layers.append(passThrough(*self._threecouplinglayers(
                chans, hidden, num_blocks, MaskCheckerboard)))

            is_last = scale == num_scales - 1
            if is_last:
                closing = CouplingLayer(chans, hidden, num_blocks,
                                        MaskCheckerboard(reverse_mask=True))
                layers.append(passThrough(closing))
            else:
                layers.append(passThrough(SqueezeLayer(2)))
                layers.append(passThrough(*self._threecouplinglayers(
                    4 * chans, 2 * hidden, num_blocks, MaskChannelwise)))
                layers.append(keepChannels(2 * chans))

            # Widths for the next scale.
            chans, hidden = 2 * chans, 2 * hidden

        layers.append(FlatJoin())
        self.body = iSequential(*layers)
Beispiel #3
0
    def __init__(self,
                 num_scales=2,
                 in_channels=3,
                 mid_channels=64,
                 num_blocks=8):
        """Build a multi-scale Glow flow.

        Each scale performs these steps in order:

        1. ``num_in`` glow steps (4 on the last scale, otherwise 3) on
           ``in_channels`` channels.
        2. A 2x ``SqueezeLayer``.
        3. ``num_out`` glow steps (0 on the last scale, otherwise 3) on
           ``4 * in_channels`` channels with ``2 * mid_channels`` hidden
           channels.
        4. ``keepChannels(2 * in_channels)``.

        ``in_channels`` and ``mid_channels`` both double after every scale.

        Args:
            num_scales: number of scales to stack.
            in_channels: channel count entering the first scale.
            mid_channels: hidden width of the first scale's pre-squeeze steps.
            num_blocks: forwarded to every ``self._glow_step`` call.
        """
        super(Glow, self).__init__()

        layers = [addZslot(), passThrough(iLogits())]

        for scale in range(num_scales):
            is_last = scale == num_scales - 1
            num_in = 4 if is_last else 3
            num_out = 0 if is_last else 3

            for _ in range(num_in):
                layers.append(
                    passThrough(*self._glow_step(in_channels, mid_channels,
                                                 num_blocks)))
            layers.append(passThrough(SqueezeLayer(2)))
            # Post-squeeze steps are counted by ``num_out`` (not ``num_in``),
            # so the last scale gets none here.
            for _ in range(num_out):
                layers.append(
                    passThrough(*self._glow_step(4 * in_channels,
                                                 2 * mid_channels,
                                                 num_blocks)))
            # NOTE(review): on the last scale the squeeze and keepChannels
            # still run with no steps in between. The RealNVP constructors skip
            # both on their final scale -- confirm which is intended here.
            layers.append(keepChannels(2 * in_channels))

            in_channels *= 2
            mid_channels *= 2

        layers.append(FlatJoin())
        self.body = iSequential(*layers)
Beispiel #4
0
    def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8, init_zeros=False,
            st_type='resnet', use_batch_norm=True, img_shape=(1, 28, 28), skip=True, latent_dim=None):
        """Build a single-scale RealNVP from eight checkerboard couplings.

        The couplings alternate ``reverse_mask`` False, True, False, ...,
        starting with False. They all act on ``in_channels`` channels, and each
        receives ``init_zeros``, ``st_type``, ``use_batch_norm`` and
        ``img_width=img_shape[2]``.

        ``num_scales``, ``skip`` and ``latent_dim`` are accepted but not used
        here (presumably to keep the signature in line with the other flow
        constructors).
        """
        super(RealNVP8Layers, self).__init__()

        # Layers are created in the same order they appear in the flow.
        stack = [addZslot(), passThrough(iLogits())]
        for idx in range(8):
            flip = idx % 2 == 1
            stack.append(passThrough(CouplingLayer(
                in_channels, mid_channels, num_blocks, MaskCheckerboard(reverse_mask=flip),
                init_zeros=init_zeros, st_type=st_type, use_batch_norm=use_batch_norm,
                img_width=img_shape[2])))
        stack.append(FlatJoin())
        self.body = iSequential(*stack)
Beispiel #5
0
    def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8, init_zeros=False,
            st_type='resnet', use_batch_norm=True, img_shape=(1, 28, 28), skip=True, latent_dim=None):
        """Build a single-scale RealNVP from eight quadrant-mask couplings.

        The masks walk the quadrant cycle 0->1, 1->2, 2->3, 3->0 twice. Each
        ``CouplingLayer`` gets ``init_zeros``, ``st_type`` and
        ``use_batch_norm``.

        ``num_scales``, ``img_shape``, ``skip`` and ``latent_dim`` are accepted
        but not used here.
        """
        super(RealNVPCycleMask, self).__init__()

        stack = [addZslot(), passThrough(iLogits())]
        for step in range(8):
            src = step % 4
            dst = (src + 1) % 4
            stack.append(passThrough(CouplingLayer(
                in_channels, mid_channels, num_blocks,
                MaskQuadrant(input_quadrant=src, output_quadrant=dst),
                init_zeros=init_zeros, st_type=st_type, use_batch_norm=use_batch_norm)))
        stack.append(FlatJoin())
        self.body = iSequential(*stack)
Beispiel #6
0
    def __init__(self,
                 image_shape,
                 mid_channels=64,
                 num_scales=2,
                 num_coupling_layers_per_scale=8,
                 num_layers=3,
                 actnorm_scale=1.,
                 multi_scale=True,
                 st_type='glow_convnet'):
        """Build a multi-scale Glow flow and record per-layer output shapes.

        Each scale performs these steps in order:

        1. A 2x squeeze (channels x4, height and width halved).
        2. ``num_coupling_layers_per_scale`` glow steps at that shape.
        3. When ``multi_scale`` is true and this is not the last scale,
           ``keepChannels`` with half the channels.

        ``self.output_shapes`` receives one ``[-1, C, H, W]`` entry per
        squeeze, per glow step and per channel split, in layer order.
        ``image_shape`` must unpack into ``(C, H, W)``.
        """
        super(Glow, self).__init__()

        layers = [addZslot(), passThrough(iLogits())]
        self.output_shapes = []
        shapes = self.output_shapes

        channels, height, width = image_shape
        final_scale = num_scales - 1

        for scale in range(num_scales):
            # Squeeze: 4x the channels at half the spatial resolution.
            channels *= 4
            height //= 2
            width //= 2
            layers.append(passThrough(SqueezeLayer(downscale_factor=2)))
            shapes.append([-1, channels, height, width])

            # Glow steps keep the shape unchanged.
            for _ in range(num_coupling_layers_per_scale):
                step = self._glow_step(in_channels=channels,
                                       mid_channels=mid_channels,
                                       actnorm_scale=actnorm_scale,
                                       st_type=st_type,
                                       num_layers=num_layers)
                layers.append(passThrough(*step))
                shapes.append([-1, channels, height, width])

            # Factor out half of the channels, except after the final scale.
            if multi_scale and scale < final_scale:
                channels //= 2
                layers.append(keepChannels(channels))
                shapes.append([-1, channels, height, width])

        layers.append(FlatJoin())
        self.body = iSequential(*layers)
Beispiel #7
0
 def _threecouplinglayers(in_channels,
                          mid_channels,
                          num_blocks,
                          mask_class,
                          num_classes=10):
     """Return a flat list of six layers forming three coupling stages.

     Each stage contributes two layers:

     - a ``passThrough`` of ``iConv1x1(in_channels)``, a ``CouplingLayer``
       masked by ``mask_class`` and ``ActNorm2d(in_channels)``;
     - an ``iCategoricalFiLM(num_classes, in_channels)``.

     The three masks use ``reverse_mask`` True, False, True.

     NOTE(review): there is no ``self``/``cls`` parameter, so this is
     presumably decorated as a staticmethod just above -- confirm.
     """
     out = []
     for reverse in (True, False, True):
         block = passThrough(
             iConv1x1(in_channels),
             CouplingLayer(in_channels, mid_channels, num_blocks,
                           mask_class(reverse_mask=reverse)),
             ActNorm2d(in_channels),
         )
         film = iCategoricalFiLM(num_classes, in_channels)
         out.extend((block, film))
     return out
Beispiel #8
0
    def __init__(self, in_channels=1, mid_channels=64, num_blocks=4):
        """Build a fixed three-level RealNVP (defaults to 1 input channel).

        The layer sequence is:

        - Level 1: 3 checkerboard couplings on ``C`` channels, squeeze,
          3 channel-wise couplings on ``4C``, keep ``2C``.
        - Level 2: 3 checkerboard couplings on ``2C``, squeeze,
          3 channel-wise couplings on ``8C``, keep ``4C``.
        - Level 3: 4 checkerboard couplings on ``4C``.

        ``C`` is ``in_channels``. Mask flags in each group alternate
        ``reverse_mask`` False/True starting at False. Every coupling uses the
        same ``mid_channels`` and ``num_blocks``.
        """
        super(RealNVPMNIST, self).__init__()

        def add_couplings(dest, channels, mask_cls, flags):
            # Append one passThrough-wrapped coupling per reverse_mask flag.
            for flag in flags:
                dest.append(passThrough(CouplingLayer(
                    channels, mid_channels, num_blocks, mask_cls(reverse_mask=flag))))

        f_t_f = (False, True, False)
        layers = [addZslot(), passThrough(iLogits())]

        # Level 1
        add_couplings(layers, in_channels, MaskCheckerboard, f_t_f)
        layers.append(passThrough(SqueezeLayer(2)))
        add_couplings(layers, 4 * in_channels, MaskChannelwise, f_t_f)
        layers.append(keepChannels(2 * in_channels))

        # Level 2
        add_couplings(layers, 2 * in_channels, MaskCheckerboard, f_t_f)
        layers.append(passThrough(SqueezeLayer(2)))
        add_couplings(layers, 8 * in_channels, MaskChannelwise, f_t_f)
        layers.append(keepChannels(4 * in_channels))

        # Level 3
        add_couplings(layers, 4 * in_channels, MaskCheckerboard, (False, True, False, True))

        layers.append(FlatJoin())
        self.body = iSequential(*layers)
Beispiel #9
0
    def __init__(self, in_channels=1, mid_channels=64, num_blocks=4):
        """Build a fixed three-level Glow (defaults to 1 input channel).

        The layer sequence is:

        - Level 1: 3 checkerboard steps on ``C``, squeeze, 2 channel-wise
          steps on ``4C``, keep ``2C``.
        - Level 2: 3 checkerboard steps on ``2C``, squeeze, 2 channel-wise
          steps on ``8C``, keep ``4C``.
        - Level 3: 3 checkerboard steps then 1 channel-wise step on ``4C``.

        ``C`` is ``in_channels``. Every step is built by ``self._glow_step``
        with the same ``mid_channels`` and ``num_blocks``.

        NOTE(review): the post-squeeze groups have 2 steps where RealNVPMNIST
        uses 3 couplings -- confirm this asymmetry is intended.
        """
        super(GlowMNIST, self).__init__()

        def steps(channels, mask_cls, count):
            # ``count`` glow steps, each wrapped in passThrough, built in order.
            return [passThrough(*self._glow_step(channels, mid_channels,
                                                 num_blocks, mask_cls))
                    for _ in range(count)]

        layers = [addZslot(), passThrough(iLogits())]

        # Level 1
        layers += steps(in_channels, MaskCheckerboard, 3)
        layers.append(passThrough(SqueezeLayer(2)))
        layers += steps(4 * in_channels, MaskChannelwise, 2)
        layers.append(keepChannels(2 * in_channels))

        # Level 2
        layers += steps(2 * in_channels, MaskCheckerboard, 3)
        layers.append(passThrough(SqueezeLayer(2)))
        layers += steps(8 * in_channels, MaskChannelwise, 2)
        layers.append(keepChannels(4 * in_channels))

        # Level 3
        layers += steps(4 * in_channels, MaskCheckerboard, 3)
        layers += steps(4 * in_channels, MaskChannelwise, 1)

        layers.append(FlatJoin())
        self.body = iSequential(*layers)