Exemple #1
0
        def f(x, *, vcfg: VarConfig, context=None, dropout_p=0., verbose=True):
            """Map `x` to the parameters of a mixture-logistic coupling transform.

            `x` is projected to `filters` channels, run through `blocks`
            rounds of gated_conv -> layernorm -> gated_nin -> layernorm, and
            projected to `C * (2 + 3 * components)` channels, which are then
            sliced into the individual transform parameters.

            Args:
                x: tensor whose static shape unpacks as (B, H, W, C).
                vcfg: variable config forwarded to every layer. When
                    `vcfg.init` is truthy and `verbose` is set, summary
                    statistics of `x` are attached through `tf.Print`.
                context: forwarded to `gated_conv` as its `a` argument.
                dropout_p: forwarded to `gated_conv` and `gated_nin`.
                verbose: enables the debug print described above.

            Returns:
                `Compose` of `MixLogisticCDF`, `Inverse(Sigmoid())` and
                `ElemwiseAffine`, in that order, built from the sliced
                parameters.
            """
            if vcfg.init and verbose:
                # Debug aid: wrap x in a tf.Print reporting its
                # shape/mean/std/min/max, with moments taken over every axis.
                every_axis = list(range(len(x.shape)))
                mean, variance = tf.nn.moments(x, axes=every_axis)
                stats = [
                    tf.shape(x),
                    mean,
                    tf.sqrt(variance),
                    tf.reduce_min(x),
                    tf.reduce_max(x),
                ]
                label = '{} (shape/mean/std/min/max) '.format(
                    self.template.variable_scope.name)
                x = tf.Print(x, stats, message=label, summarize=10)

            batch, height, width, channels = x.shape.as_list()

            # Created once here and handed to the gated_nin call of every
            # block below.
            pos_emb = get_var(
                'pos_emb',
                shape=[height, width, filters],
                initializer=tf.random_normal_initializer(stddev=0.01),
                vcfg=vcfg,
            )

            x = conv2d(x, name='proj_in', num_units=filters, vcfg=vcfg)
            for block_idx in range(blocks):
                with tf.variable_scope(f'block{block_idx}'):
                    x = gated_conv(
                        x,
                        name='conv',
                        a=context,
                        use_nin=True,
                        dropout_p=dropout_p,
                        vcfg=vcfg,
                    )
                    x = layernorm(x, name='ln1', vcfg=vcfg)
                    x = gated_nin(
                        x,
                        name='attn',
                        pos_emb=pos_emb,
                        dropout_p=dropout_p,
                        vcfg=vcfg,
                    )
                    x = layernorm(x, name='ln2', vcfg=vcfg)

            # Per input channel: index 0 feeds `s`, index 1 feeds `t`, and the
            # remaining 3 * components values are split three ways below.
            params_per_channel = 2 + 3 * components
            flat_shape = [batch, height, width, channels * params_per_channel]
            x = conv2d(
                x,
                name='proj_out',
                num_units=flat_shape[-1],
                init_scale=init_scale,
                vcfg=vcfg,
            )
            assert x.shape == flat_shape
            x = tf.reshape(
                x, [batch, height, width, channels, params_per_channel])

            # `s` is squashed by tanh and used as a log-scale (see the
            # ElemwiseAffine call); `t` is used as the bias.
            s = tf.tanh(x[:, :, :, :, 0])
            t = x[:, :, :, :, 1]
            ml_logits, ml_means, ml_logscales = tf.split(
                x[:, :, :, :, 2:], 3, axis=4)

            assert s.shape == t.shape == [batch, height, width, channels]
            mixture_shape = [batch, height, width, channels, components]
            assert (ml_logits.shape == ml_means.shape == ml_logscales.shape ==
                    mixture_shape)

            cdf_flow = MixLogisticCDF(
                logits=ml_logits, means=ml_means, logscales=ml_logscales)
            logit_flow = Inverse(Sigmoid())
            affine_flow = ElemwiseAffine(
                scales=tf.exp(s), logscales=s, biases=t)
            return Compose([cdf_flow, logit_flow, affine_flow])
Exemple #2
0
def construct(*, filters, components, blocks):
    """Build a uniform-noise dequantizer and the main flow.

    Args:
        filters: forwarded to every `MixLogisticAttnCoupling` as `filters`.
        components: forwarded to every `MixLogisticAttnCoupling` as
            `components`.
        blocks: forwarded to every `MixLogisticAttnCoupling` as `blocks`.

    Returns:
        Tuple `(dequant_flow, flow)`: a `UnifDequant` instance and the
        `Compose` holding the main layer stack.
    """
    # see MixLogisticAttnCoupling constructor
    coupling_kwargs = {
        'filters': filters,
        'blocks': blocks,
        'components': components,
    }

    class UnifDequant(Flow):
        """Flow whose forward pass adds `tf.random_uniform` noise to `x`."""

        def forward(self, x, **kwargs):
            """Return `(x + noise, zeros)`; `zeros` has shape `[x.shape[0]]`."""
            noise = tf.random_uniform(x.shape.as_list())
            noisy = x + noise
            return noisy, tf.zeros([int(x.shape[0])])

    def coupling_steps(count):
        """Return `count` repeats of Norm, Pointwise, coupling, TupleFlip."""
        steps = []
        for _ in range(count):
            steps.extend((
                Norm(),
                Pointwise(),
                MixLogisticAttnCoupling(**coupling_kwargs),
                TupleFlip(),
            ))
        return steps

    dequant_flow = UnifDequant()

    # Layers are instantiated strictly in stack order.
    layers = [ImgProc(), CheckerboardSplit()]
    layers += coupling_steps(4)
    layers += [Inverse(CheckerboardSplit()), SpaceToDepth(), ChannelSplit()]
    layers += coupling_steps(2)
    layers += [Inverse(ChannelSplit()), CheckerboardSplit()]
    layers += coupling_steps(3)
    layers.append(Inverse(CheckerboardSplit()))

    flow = Compose(layers)
    return dequant_flow, flow
Exemple #3
0
        def __init__(self):
            """Create the context-processing template and the dequant flow.

            Reads `dequant_coupling_kwargs` from an enclosing scope that is
            not part of this method.
            """
            def shallow_processor(x, *, dropout_p, vcfg):
                """Turn `x` into a context tensor.

                Computes `x / 256.0 - 0.5`, checkerboard-splits the result,
                concatenates the two parts on axis 3, projects to 32 units
                and applies three `gated_conv` layers named c0..c2.
                """
                rescaled = x / 256.0 - 0.5
                (first, second), _ = CheckerboardSplit().forward(rescaled)
                merged = tf.concat([first, second], 3)
                h = conv2d(merged, name='proj', num_units=32, vcfg=vcfg)
                for layer in range(3):
                    h = gated_conv(
                        h,
                        name=f'c{layer}',
                        vcfg=vcfg,
                        dropout_p=dropout_p,
                        use_nin=False,
                        a=None,
                    )
                return h

            self.context_proc = tf.make_template("context_proc",
                                                 shallow_processor)

            # Layers are instantiated strictly in stack order: split, four
            # Norm/Pointwise/coupling/TupleFlip groups, un-split, Sigmoid.
            layers = [CheckerboardSplit()]
            for _ in range(4):
                layers.extend((
                    Norm(),
                    Pointwise(),
                    AffineAttnCoupling(**dequant_coupling_kwargs),
                    TupleFlip(),
                ))
            layers.append(Inverse(CheckerboardSplit()))
            layers.append(Sigmoid())
            self.dequant_flow = Compose(layers)
Exemple #4
0
    class Dequant(Flow):
        """Flow that adds flow-transformed noise to its input.

        `forward` draws `eps` with `gaussian_sample_logp`, pushes it through
        `self.dequant_flow` conditioned on features of `x` computed by
        `self.context_proc`, and returns `x` plus the transformed sample.
        """

        def __init__(self):
            """Create the context-processing template and the dequant flow.

            Reads `dequant_coupling_kwargs` from an enclosing scope that is
            not part of this class.
            """
            def shallow_processor(x, *, dropout_p, vcfg):
                """Turn `x` into a context tensor.

                Computes `x / 256.0 - 0.5`, checkerboard-splits the result,
                concatenates the two parts on axis 3, projects to 32 units
                and applies three `gated_conv` layers named c0..c2.
                """
                rescaled = x / 256.0 - 0.5
                (first, second), _ = CheckerboardSplit().forward(rescaled)
                merged = tf.concat([first, second], 3)
                h = conv2d(merged, name='proj', num_units=32, vcfg=vcfg)
                for layer in range(3):
                    h = gated_conv(
                        h,
                        name=f'c{layer}',
                        vcfg=vcfg,
                        dropout_p=dropout_p,
                        use_nin=False,
                        a=None,
                    )
                return h

            self.context_proc = tf.make_template("context_proc",
                                                 shallow_processor)

            # Layers are instantiated strictly in stack order: split, four
            # Norm/Pointwise/coupling/TupleFlip groups, un-split, Sigmoid.
            layers = [CheckerboardSplit()]
            for _ in range(4):
                layers.extend((
                    Norm(),
                    Pointwise(),
                    AffineAttnCoupling(**dequant_coupling_kwargs),
                    TupleFlip(),
                ))
            layers.append(Inverse(CheckerboardSplit()))
            layers.append(Sigmoid())
            self.dequant_flow = Compose(layers)

        def forward(self,
                    x,
                    *,
                    vcfg,
                    dropout_p=0.,
                    verbose=True,
                    context=None):
            """Return `(x + xd, logd - eps_logp)`.

            Args:
                x: input tensor; its static shape sizes the noise sample.
                vcfg: forwarded to the context processor and the flow.
                dropout_p: forwarded to the context processor and the flow.
                verbose: forwarded to the flow.
                context: must be None; the context is derived from `x`.

            Returns:
                Tuple of `x + xd`, where `xd` is the flow output for the
                sampled `eps`, and `logd - eps_logp`.
            """
            assert context is None
            eps, eps_logp = gaussian_sample_logp(x.shape.as_list())
            features = self.context_proc(x, dropout_p=dropout_p, vcfg=vcfg)
            xd, logd = self.dequant_flow.forward(
                eps,
                context=features,
                dropout_p=dropout_p,
                verbose=verbose,
                vcfg=vcfg,
            )
            # The noise matches x in shape; both log terms are one value per
            # leading-axis entry.
            assert eps.shape == x.shape
            assert logd.shape == eps_logp.shape == [x.shape[0]]
            dequantized = x + xd
            return dequantized, logd - eps_logp
Exemple #5
0
def construct(*, filters, dequant_filters, components, blocks):
    """Build the dequantization flow and the main flow.

    Args:
        filters: `filters` for the couplings of the main flow.
        dequant_filters: `filters` for the couplings of the dequantization
            flow (which always use `blocks=2`).
        components: `components` for every coupling in both flows.
        blocks: `blocks` for the couplings of the main flow.

    Returns:
        Tuple `(dequant_flow, flow)`: a `Dequant` instance and the `Compose`
        holding the main layer stack.
    """
    # see MixLogisticAttnCoupling constructor
    # NOTE(review): this function instantiates MixLogisticCoupling, not
    # MixLogisticAttnCoupling -- confirm the reference above is current.
    dequant_coupling_kwargs = {
        'filters': dequant_filters,
        'blocks': 2,
        'components': components,
    }
    coupling_kwargs = {
        'filters': filters,
        'blocks': blocks,
        'components': components,
    }

    def coupling_steps(kwargs, count):
        """Return `count` repeats of Norm, Pointwise, coupling, TupleFlip."""
        steps = []
        for _ in range(count):
            steps.extend((
                Norm(),
                Pointwise(),
                MixLogisticCoupling(**kwargs),
                TupleFlip(),
            ))
        return steps

    class Dequant(Flow):
        """Flow that adds flow-transformed noise to its input.

        `forward` draws `eps` with `gaussian_sample_logp`, pushes it through
        `self.dequant_flow` conditioned on features of `x` computed by
        `self.context_proc`, and returns `x` plus the transformed sample.
        """

        def __init__(self):
            """Create the context-processing template and the dequant flow."""
            def shallow_processor(x, *, dropout_p, vcfg):
                """Turn `x` into a context tensor.

                Computes `x / 256.0 - 0.5`, checkerboard-splits the result,
                concatenates the two parts on axis 3, projects to 32 units
                and applies three `gated_conv` layers named c0..c2.
                """
                rescaled = x / 256.0 - 0.5
                (first, second), _ = CheckerboardSplit().forward(rescaled)
                merged = tf.concat([first, second], 3)
                h = conv2d(merged, name='proj', num_units=32, vcfg=vcfg)
                for layer in range(3):
                    h = gated_conv(
                        h,
                        name=f'c{layer}',
                        vcfg=vcfg,
                        dropout_p=dropout_p,
                        use_nin=False,
                        a=None,
                    )
                return h

            self.context_proc = tf.make_template("context_proc",
                                                 shallow_processor)

            # Layers are instantiated strictly in stack order: split, four
            # coupling groups, un-split, Sigmoid.
            layers = [CheckerboardSplit()]
            layers += coupling_steps(dequant_coupling_kwargs, 4)
            layers.append(Inverse(CheckerboardSplit()))
            layers.append(Sigmoid())
            self.dequant_flow = Compose(layers)

        def forward(self,
                    x,
                    *,
                    vcfg,
                    dropout_p=0.,
                    verbose=True,
                    context=None):
            """Return `(x + xd, logd - eps_logp)`.

            Args:
                x: input tensor; its static shape sizes the noise sample.
                vcfg: forwarded to the context processor and the flow.
                dropout_p: forwarded to the context processor and the flow.
                verbose: forwarded to the flow.
                context: must be None; the context is derived from `x`.

            Returns:
                Tuple of `x + xd`, where `xd` is the flow output for the
                sampled `eps`, and `logd - eps_logp`.
            """
            assert context is None
            eps, eps_logp = gaussian_sample_logp(x.shape.as_list())
            features = self.context_proc(x, dropout_p=dropout_p, vcfg=vcfg)
            xd, logd = self.dequant_flow.forward(
                eps,
                context=features,
                dropout_p=dropout_p,
                verbose=verbose,
                vcfg=vcfg,
            )
            # The noise matches x in shape; both log terms are one value per
            # leading-axis entry.
            assert eps.shape == x.shape
            assert logd.shape == eps_logp.shape == [x.shape[0]]
            dequantized = x + xd
            return dequantized, logd - eps_logp

    dequant_flow = Dequant()

    # Main stack, instantiated strictly in stack order.
    layers = [ImgProc(), CheckerboardSplit()]
    layers += coupling_steps(coupling_kwargs, 4)
    layers += [Inverse(CheckerboardSplit()), SpaceToDepth(), ChannelSplit()]
    layers += coupling_steps(coupling_kwargs, 2)
    layers += [Inverse(ChannelSplit()), CheckerboardSplit()]
    layers += coupling_steps(coupling_kwargs, 3)
    layers.append(Inverse(CheckerboardSplit()))

    flow = Compose(layers)
    return dequant_flow, flow