def revnet2d_step(name, z, logdet, hps, reverse):
    """One invertible flow step: actnorm, channel mixing, then a coupling layer.

    Forward (``reverse=False``) applies actnorm, the channel permutation
    selected by ``hps.flow_permutation``, and the coupling selected by
    ``hps.flow_coupling``. Reverse undoes the same operations in the opposite
    order, calling each sub-layer under the same name.

    NOTE(review): this file defines ``revnet2d_step`` a second time directly
    after this one. If both definitions live in the same module, the later one
    rebinds the name and this version is never called -- confirm and delete
    whichever is stale.

    Args:
        name: variable-scope name under which all sub-layers are created.
        z: input tensor, sliced as ``z[:, :, :, c]``; assumed channels-last
            (NHWC) -- TODO confirm with callers. The channel count must be
            even because it is split in half for the coupling.
        logdet: running log-determinant. Sub-layers return an updated value,
            and the affine coupling adds/subtracts a sum over axes [1, 2, 3].
        hps: hyperparameters. Reads ``flow_permutation`` (0: reverse channels,
            1: shuffle channels, 2: invertible 1x1 conv), ``flow_coupling``
            (0: additive, 1: affine with a sigmoid scale) and ``width``
            (passed to ``f``).
        reverse: False for the forward direction, True for the inverse.

    Returns:
        Tuple ``(z, logdet)`` after the step.

    Raises:
        ValueError: if ``flow_permutation`` or ``flow_coupling`` is not one of
            the values listed above.
    """
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        assert n_z % 2 == 0, (
            "revnet2d_step needs an even channel count, got %s" % (n_z,))
        half = n_z // 2

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise ValueError(
                    "Unsupported flow_permutation: %r" % (hps.flow_permutation,))

            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]
            if hps.flow_coupling == 0:
                z2 = z2 + f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # The scale is sigmoid(raw + 2) instead of exp(raw), so it is
                # bounded in (0, 1) and equals sigmoid(2) ~= 0.88 when raw is 0.
                pre = h[:, :, :, 1::2] + 2.
                z2 = (z2 + shift) * tf.nn.sigmoid(pre)
                # log_sigmoid is the stable form of log(sigmoid(x)); the naive
                # form underflows to -inf for very negative x.
                logdet += tf.reduce_sum(tf.log_sigmoid(pre), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported flow_coupling: %r" % (hps.flow_coupling,))
            z = tf.concat([z1, z2], 3)
        else:
            # Inverse: undo the coupling first, then the permutation, then
            # actnorm.
            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]
            if hps.flow_coupling == 0:
                z2 = z2 - f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                pre = h[:, :, :, 1::2] + 2.
                z2 = z2 / tf.nn.sigmoid(pre) - shift
                logdet -= tf.reduce_sum(tf.log_sigmoid(pre), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported flow_coupling: %r" % (hps.flow_coupling,))
            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise ValueError(
                    "Unsupported flow_permutation: %r" % (hps.flow_permutation,))

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

        return z, logdet
def revnet2d_step(name, z, logdet, hps, reverse):
    """One invertible flow step: actnorm, channel mixing, then a coupling layer.

    Forward (``reverse=False``) applies, in order:
      1. actnorm;
      2. the channel-mixing layer(s) selected by ``hps.flow_permutation``;
      3. the coupling layer selected by ``hps.flow_coupling``.
    Reverse undoes exactly those operations in the opposite order, calling
    every sub-layer under the same name it had in the forward direction.

    Args:
        name: variable-scope name under which all sub-layers are created.
        z: input tensor, sliced as ``z[:, :, :, c]``; assumed channels-last
            (NHWC) -- TODO confirm with callers. The channel count must be
            even because it is split in half for the coupling.
        logdet: running log-determinant. Sub-layers return an updated value,
            and the affine coupling adds/subtracts a sum over axes [1, 2, 3].
        hps: hyperparameters. Reads:
            ``flow_permutation``:
                0 -- reverse channels;
                1 -- shuffle channels;
                2 -- invertible 1x1 conv;
                3 -- 1x1 conv, then ``invertible_conv2D_emerging``;
                4 -- ``fourier_conv``;
                5 -- 1x1 conv, then ``maf_three`` with ``is_upper=False``
                     and ``maf_three`` with ``is_upper=True`` (depth 96).
            ``decomposition``: forwarded to ``invertible_1x1_conv``.
            ``flow_coupling``: 0 -- additive; 1 -- affine with sigmoid scale.
            ``width``: passed to ``f``.
        reverse: False for the forward direction, True for the inverse.

    Returns:
        Tuple ``(z, logdet)`` after the step.

    Raises:
        ValueError: if ``flow_permutation`` or ``flow_coupling`` is not one of
            the values listed above.
    """
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        assert n_z % 2 == 0, (
            "revnet2d_step needs an even channel count, got %s" % (n_z,))
        half = n_z // 2

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, checkpoint_fn=checkpoint)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet)
            elif hps.flow_permutation == 5:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = maf_three(
                    'maf1', z, logdet, depth=96, is_upper=False)
                z, logdet = maf_three(
                    'maf2', z, logdet, depth=96, is_upper=True)
            else:
                raise ValueError(
                    "Unsupported flow_permutation: %r" % (hps.flow_permutation,))

            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]
            if hps.flow_coupling == 0:
                z2 = z2 + f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # The scale is sigmoid(raw + 2) instead of exp(raw), so it is
                # bounded in (0, 1) and equals sigmoid(2) ~= 0.88 when raw is 0.
                # The pre-activation is built once and shared by the scale and
                # its log.
                pre = h[:, :, :, 1::2] + 2.
                z2 = (z2 + shift) * tf.nn.sigmoid(pre)
                # log_sigmoid is the stable form of log(sigmoid(x)).
                logdet += tf.reduce_sum(tf.log_sigmoid(pre), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported flow_coupling: %r" % (hps.flow_coupling,))
            z = tf.concat([z1, z2], 3)
        else:
            # Inverse: undo the coupling first, then the channel mixing, then
            # actnorm.
            z1 = z[:, :, :, :half]
            z2 = z[:, :, :, half:]
            if hps.flow_coupling == 0:
                z2 = z2 - f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                pre = h[:, :, :, 1::2] + 2.
                z2 = z2 / tf.nn.sigmoid(pre) - shift
                logdet -= tf.reduce_sum(tf.log_sigmoid(pre), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported flow_coupling: %r" % (hps.flow_coupling,))
            z = tf.concat([z1, z2], 3)

            # Multi-layer modes are undone last-in, first-out.
            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                # NOTE(review): forward passes checkpoint_fn=checkpoint here but
                # the inverse does not; presumably intentional -- confirm.
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet, reverse=True)
            elif hps.flow_permutation == 5:
                z, logdet = maf_three(
                    'maf2', z, logdet, depth=96, is_upper=True, reverse=True)
                z, logdet = maf_three(
                    'maf1', z, logdet, depth=96, is_upper=False, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition,
                    reverse=True)
            else:
                raise ValueError(
                    "Unsupported flow_permutation: %r" % (hps.flow_permutation,))

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

        return z, logdet