def f(x, *, vcfg: VarConfig, context=None, dropout_p=0., verbose=True):
    """Build the conditioning network and return the coupling's elementwise flow.

    The input is projected to ``filters`` channels. It then passes through
    ``blocks`` blocks of gated conv -> layernorm -> gated NIN (fed a learned
    positional embedding) -> layernorm. Finally it is projected to
    ``C * (2 + 3 * components)`` channels and split per input channel into:

    * a tanh-squashed log-scale,
    * a bias, and
    * ``components`` mixture-of-logistics logits / means / log-scales.

    NOTE(review): ``self``, ``filters``, ``blocks``, ``components`` and
    ``init_scale`` are free variables. Presumably they come from an enclosing
    coupling-layer constructor; confirm against the surrounding file.

    Args:
        x: 4-D tensor unpacked as ``(B, H, W, C)``.
        vcfg: variable configuration. While ``vcfg.init`` is set (and
            ``verbose`` is true), summary statistics of ``x`` are printed.
        context: optional conditioning tensor, forwarded to ``gated_conv``.
        dropout_p: dropout probability for the gated layers.
        verbose: enables the debug print above.

    Returns:
        ``Compose([MixLogisticCDF, Inverse(Sigmoid()), ElemwiseAffine])``
        built from the predicted parameters.
    """
    if vcfg.init and verbose:
        # Debug output: shape / mean / std / min / max of the incoming tensor.
        all_axes = list(range(len(x.shape)))
        mean, variance = tf.nn.moments(x, axes=all_axes)
        stats = [
            tf.shape(x),
            mean,
            tf.sqrt(variance),
            tf.reduce_min(x),
            tf.reduce_max(x),
        ]
        scope_name = self.template.variable_scope.name
        x = tf.Print(
            x,
            stats,
            message='{} (shape/mean/std/min/max) '.format(scope_name),
            summarize=10,
        )

    batch, height, width, channels = x.shape.as_list()
    # Per input channel: log-scale, bias, then logits/means/logscales for each component.
    params_per_channel = 2 + 3 * components

    pos_emb = get_var(
        'pos_emb',
        shape=[height, width, filters],
        initializer=tf.random_normal_initializer(stddev=0.01),
        vcfg=vcfg,
    )

    h = conv2d(x, name='proj_in', num_units=filters, vcfg=vcfg)
    for block_idx in range(blocks):
        with tf.variable_scope(f'block{block_idx}'):
            h = gated_conv(h, name='conv', a=context, use_nin=True,
                           dropout_p=dropout_p, vcfg=vcfg)
            h = layernorm(h, name='ln1', vcfg=vcfg)
            h = gated_nin(h, name='attn', pos_emb=pos_emb,
                          dropout_p=dropout_p, vcfg=vcfg)
            h = layernorm(h, name='ln2', vcfg=vcfg)

    h = conv2d(h, name='proj_out', num_units=channels * params_per_channel,
               init_scale=init_scale, vcfg=vcfg)
    assert h.shape == [batch, height, width, channels * params_per_channel]
    params = tf.reshape(h, [batch, height, width, channels, params_per_channel])

    log_scale = tf.tanh(params[:, :, :, :, 0])
    bias = params[:, :, :, :, 1]
    ml_logits, ml_means, ml_logscales = tf.split(params[:, :, :, :, 2:], 3, axis=4)

    assert log_scale.shape == bias.shape == [batch, height, width, channels]
    assert ml_logits.shape == ml_means.shape == ml_logscales.shape == [
        batch, height, width, channels, components
    ]

    return Compose([
        MixLogisticCDF(logits=ml_logits, means=ml_means, logscales=ml_logscales),
        Inverse(Sigmoid()),
        ElemwiseAffine(scales=tf.exp(log_scale), logscales=log_scale, biases=bias),
    ])
def construct(*, filters, components, blocks):
    """Build a uniform-noise dequantizer and the main MixLogisticAttnCoupling flow.

    Keyword Args:
        filters: hidden width passed to every MixLogisticAttnCoupling.
        components: number of logistic mixture components per coupling.
        blocks: number of blocks in each coupling network.

    Returns:
        ``(dequant_flow, flow)`` where ``dequant_flow`` is a ``UnifDequant``
        instance and ``flow`` is the composed main flow.
    """
    # Keyword arguments for MixLogisticAttnCoupling (see its constructor).
    coupling_kwargs = dict(filters=filters, blocks=blocks, components=components)

    def _coupling_steps(count):
        """Return ``count`` fresh Norm -> Pointwise -> coupling -> TupleFlip steps, in order."""
        steps = []
        for _ in range(count):
            steps.extend([
                Norm(),
                Pointwise(),
                MixLogisticAttnCoupling(**coupling_kwargs),
                TupleFlip(),
            ])
        return steps

    class UnifDequant(Flow):
        def forward(self, x, **kwargs):
            """Add uniform noise (``tf.random_uniform`` defaults) and report zero log-det per example."""
            noise = tf.random_uniform(x.shape.as_list())
            noisy = x + noise
            zero_logd = tf.zeros([int(x.shape[0])])
            return noisy, zero_logd

    dequant_flow = UnifDequant()
    # Steps are spliced in place with ``*`` so every sub-flow is constructed
    # in exactly the same left-to-right order as a fully unrolled list.
    flow = Compose([
        ImgProc(),
        CheckerboardSplit(),
        *_coupling_steps(4),
        Inverse(CheckerboardSplit()),
        SpaceToDepth(),
        ChannelSplit(),
        *_coupling_steps(2),
        Inverse(ChannelSplit()),
        CheckerboardSplit(),
        *_coupling_steps(3),
        Inverse(CheckerboardSplit()),
    ])
    return dequant_flow, flow
def __init__(self):
    """Create the context network and the AffineAttnCoupling dequantization flow.

    ``self.context_proc`` is a ``tf.make_template`` wrapper named
    ``"context_proc"`` around a small conv net applied to the raw input.
    ``self.dequant_flow`` is a checkerboard-split stack of four
    Norm/Pointwise/AffineAttnCoupling/TupleFlip steps followed by a Sigmoid.

    NOTE(review): ``dequant_coupling_kwargs`` is a free variable, presumably
    defined in an enclosing ``construct`` function; confirm.
    """
    def shallow_processor(x, *, dropout_p, vcfg):
        # Affine rescale: x / 256 - 0.5 (assumes 8-bit-range pixel inputs — TODO confirm).
        scaled = x / 256.0 - 0.5
        (this, that), _ = CheckerboardSplit().forward(scaled)
        h = conv2d(tf.concat([this, that], 3), name='proj', num_units=32, vcfg=vcfg)
        for layer in range(3):
            h = gated_conv(h, name=f'c{layer}', vcfg=vcfg, dropout_p=dropout_p,
                           use_nin=False, a=None)
        return h

    self.context_proc = tf.make_template("context_proc", shallow_processor)

    def _coupling_steps(count):
        """Return ``count`` fresh Norm -> Pointwise -> coupling -> TupleFlip steps, in order."""
        steps = []
        for _ in range(count):
            steps.extend([
                Norm(),
                Pointwise(),
                AffineAttnCoupling(**dequant_coupling_kwargs),
                TupleFlip(),
            ])
        return steps

    # Spliced with ``*`` so construction order matches the unrolled list exactly.
    self.dequant_flow = Compose([
        CheckerboardSplit(),
        *_coupling_steps(4),
        Inverse(CheckerboardSplit()),
        Sigmoid(),
    ])
class Dequant(Flow):
    """Flow-based dequantizer whose noise is conditioned on the input.

    Gaussian noise is pushed through an AffineAttnCoupling flow, conditioned
    on features from a small context network, and the result is added to
    ``x``.

    NOTE(review): ``dequant_coupling_kwargs`` is a free variable, presumably
    defined in an enclosing ``construct`` function; confirm.
    """

    def __init__(self):
        def shallow_processor(x, *, dropout_p, vcfg):
            # Affine rescale: x / 256 - 0.5 (assumes 8-bit-range pixel inputs — TODO confirm).
            scaled = x / 256.0 - 0.5
            (this, that), _ = CheckerboardSplit().forward(scaled)
            h = conv2d(tf.concat([this, that], 3), name='proj', num_units=32, vcfg=vcfg)
            for layer in range(3):
                h = gated_conv(h, name=f'c{layer}', vcfg=vcfg, dropout_p=dropout_p,
                               use_nin=False, a=None)
            return h

        self.context_proc = tf.make_template("context_proc", shallow_processor)

        def _coupling_steps(count):
            """Return ``count`` fresh Norm -> Pointwise -> coupling -> TupleFlip steps, in order."""
            steps = []
            for _ in range(count):
                steps.extend([
                    Norm(),
                    Pointwise(),
                    AffineAttnCoupling(**dequant_coupling_kwargs),
                    TupleFlip(),
                ])
            return steps

        # Spliced with ``*`` so construction order matches the unrolled list exactly.
        self.dequant_flow = Compose([
            CheckerboardSplit(),
            *_coupling_steps(4),
            Inverse(CheckerboardSplit()),
            Sigmoid(),
        ])

    def forward(self, x, *, vcfg, dropout_p=0., verbose=True, context=None):
        """Return ``(x + offset, logd - noise_logp)``.

        ``noise``/``noise_logp`` come from ``gaussian_sample_logp`` at x's
        shape. ``offset``/``logd`` come from running ``noise`` through
        ``self.dequant_flow`` with ``self.context_proc(x)`` as context.
        External ``context`` must be ``None``.
        """
        assert context is None
        noise, noise_logp = gaussian_sample_logp(x.shape.as_list())
        ctx = self.context_proc(x, dropout_p=dropout_p, vcfg=vcfg)
        offset, logd = self.dequant_flow.forward(
            noise, context=ctx, dropout_p=dropout_p, verbose=verbose, vcfg=vcfg)
        assert noise.shape == x.shape and logd.shape == noise_logp.shape == [x.shape[0]]
        return x + offset, logd - noise_logp
def construct(*, filters, dequant_filters, components, blocks):
    """Build a learned, input-conditioned dequantizer and the main MixLogisticCoupling flow.

    Keyword Args:
        filters: hidden width of the main-flow couplings.
        dequant_filters: hidden width of the dequantization couplings.
        components: number of logistic mixture components per coupling.
        blocks: number of blocks in each main-flow coupling network
            (dequantization couplings always use 2).

    Returns:
        ``(dequant_flow, flow)`` where ``dequant_flow`` is a ``Dequant``
        instance and ``flow`` is the composed main flow.
    """
    # Keyword arguments for MixLogisticCoupling (see its constructor).
    dequant_coupling_kwargs = dict(filters=dequant_filters, blocks=2, components=components)
    coupling_kwargs = dict(filters=filters, blocks=blocks, components=components)

    def _coupling_steps(count, kwargs):
        """Return ``count`` fresh Norm -> Pointwise -> MixLogisticCoupling -> TupleFlip steps, in order."""
        steps = []
        for _ in range(count):
            steps.extend([
                Norm(),
                Pointwise(),
                MixLogisticCoupling(**kwargs),
                TupleFlip(),
            ])
        return steps

    class Dequant(Flow):
        def __init__(self):
            def shallow_processor(x, *, dropout_p, vcfg):
                # Affine rescale: x / 256 - 0.5 (assumes 8-bit-range pixel inputs — TODO confirm).
                scaled = x / 256.0 - 0.5
                (this, that), _ = CheckerboardSplit().forward(scaled)
                h = conv2d(tf.concat([this, that], 3), name='proj', num_units=32, vcfg=vcfg)
                for layer in range(3):
                    h = gated_conv(h, name=f'c{layer}', vcfg=vcfg, dropout_p=dropout_p,
                                   use_nin=False, a=None)
                return h

            self.context_proc = tf.make_template("context_proc", shallow_processor)
            self.dequant_flow = Compose([
                CheckerboardSplit(),
                *_coupling_steps(4, dequant_coupling_kwargs),
                Inverse(CheckerboardSplit()),
                Sigmoid(),
            ])

        def forward(self, x, *, vcfg, dropout_p=0., verbose=True, context=None):
            """Return ``(x + offset, logd - noise_logp)`` with noise shaped by the dequant flow."""
            assert context is None
            noise, noise_logp = gaussian_sample_logp(x.shape.as_list())
            ctx = self.context_proc(x, dropout_p=dropout_p, vcfg=vcfg)
            offset, logd = self.dequant_flow.forward(
                noise, context=ctx, dropout_p=dropout_p, verbose=verbose, vcfg=vcfg)
            assert noise.shape == x.shape and logd.shape == noise_logp.shape == [x.shape[0]]
            return x + offset, logd - noise_logp

    dequant_flow = Dequant()
    # Steps are spliced in place with ``*`` so every sub-flow is constructed
    # in exactly the same left-to-right order as a fully unrolled list.
    flow = Compose([
        ImgProc(),
        CheckerboardSplit(),
        *_coupling_steps(4, coupling_kwargs),
        Inverse(CheckerboardSplit()),
        SpaceToDepth(),
        ChannelSplit(),
        *_coupling_steps(2, coupling_kwargs),
        Inverse(ChannelSplit()),
        CheckerboardSplit(),
        *_coupling_steps(3, coupling_kwargs),
        Inverse(CheckerboardSplit()),
    ])
    return dequant_flow, flow