def ops():
    """Exercise a few ``tfops`` helpers on a small random tensor.

    Initializes Horovod, then builds a (2, 2) uniform random tensor. It calls
    ``allreduce_sum``, ``allreduce_mean``, ``int_shape`` and ``actnorm`` on
    that tensor. Every result is discarded and the function returns ``None``,
    so it only checks that these calls run.
    """
    hvd.init()
    x = tf.random.uniform((2, 2))

    # Combine both allreduce flavours and divide by the sum of the static
    # dims. The value is intentionally unused; only the call path is checked.
    _ = (tfops.allreduce_sum(x) * tfops.allreduce_mean(x)) / sum(x.shape)

    tfops.int_shape(x)
    tfops.actnorm("by2", x)
def _flow_step(self, name, z, logdet):
    """Run one forward flow step inside variable scope ``name``.

    The step is actnorm, then an invertible 1x1 convolution (forward
    direction), then an additive coupling. Only the first two stages update
    ``logdet``; the coupling adds ``Z.f(...)`` of one half to the other half
    and leaves ``logdet`` unchanged.

    Args:
        name: variable-scope name for this step.
        z: input tensor (shape/layout defined by the ``Z`` helpers; not
            visible here).
        logdet: running log-determinant, threaded through ``Z.actnorm`` and
            ``Z.invertible_1x1_conv``.

    Returns:
        Tuple ``(z, logdet)`` after the step.
    """
    with tf.compat.v1.variable_scope(name):
        z, logdet = Z.actnorm('actnorm', z, logdet)
        z, logdet = Z.invertible_1x1_conv('invconv', z, logdet, reverse=False)

        # Z.split / Z.unsplit presumably halve and re-join z; confirm in Z.
        first_half, second_half = Z.split(z)
        shift = Z.f('f', first_half, self.hps.window_size, self.hps.width)
        second_half += shift
        z = Z.unsplit(first_half, second_half)
    return z, logdet
def revnet2d_step(name, z, logdet, hps, reverse):
    """One invertible flow step: actnorm -> permutation -> coupling.

    Forward (``reverse=False``) applies actnorm, then the permutation selected
    by ``hps.flow_permutation``, then the coupling selected by
    ``hps.flow_coupling``. Reverse (``reverse=True``) undoes the same three
    stages in the opposite order and subtracts their log-determinant terms.

    Args:
        name: variable-scope name for this step.
        z: 4-D tensor split along axis 3 (presumably NHWC channels; confirm).
            ``Z.int_shape(z)[3]`` must be even.
        logdet: running log-determinant. The coupling terms are summed over
            axes 1-3 (presumably one value per example).
        hps: hyperparameters. Reads ``flow_permutation`` (0 = reverse
            features, 1 = shuffle features, 2 = invertible 1x1 conv),
            ``flow_coupling`` (0 = additive, 1 = affine) and ``width``.
        reverse: if True, compute the inverse of the step.

    Returns:
        Tuple ``(z, logdet)`` after the step.

    Raises:
        ValueError: if ``hps.flow_permutation`` or ``hps.flow_coupling`` is
            not one of the supported values.
    """
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        n_z = shape[3]
        # z is split into two equal halves along axis 3.
        assert n_z % 2 == 0

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise ValueError(
                    "Unsupported hps.flow_permutation: %r"
                    % (hps.flow_permutation,))

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # A sigmoid keeps the scale in (0, 1). The +2 offset gives
                # scale = sigmoid(2) ~= 0.88 when h is near 0.
                scale_logits = h[:, :, :, 1::2] + 2.
                scale = tf.nn.sigmoid(scale_logits)
                # log_sigmoid is the numerically stable log(sigmoid(x)). It
                # does not collapse to log(0) = -inf for very negative x.
                log_scale = tf.log_sigmoid(scale_logits)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(log_scale, axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported hps.flow_coupling: %r" % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

        else:
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                scale_logits = h[:, :, :, 1::2] + 2.
                scale = tf.nn.sigmoid(scale_logits)
                log_scale = tf.log_sigmoid(scale_logits)
                # Inverse of the forward affine map: undo scale, then shift.
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(log_scale, axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported hps.flow_coupling: %r" % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise ValueError(
                    "Unsupported hps.flow_permutation: %r"
                    % (hps.flow_permutation,))

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
def revnet2d_step(name, z, logdet, hps, reverse):
    """One invertible flow step: actnorm -> channel mixing -> coupling.

    Forward (``reverse=False``) applies actnorm, then the mixing layer(s)
    selected by ``hps.flow_permutation``, then the coupling selected by
    ``hps.flow_coupling``. Reverse (``reverse=True``) undoes the same stages
    in the opposite order, inverting multi-layer mixings last-to-first.

    ``hps.flow_permutation`` values:
        0: ``Z.reverse_features``
        1: ``Z.shuffle_features``
        2: ``invertible_1x1_conv``
        3: ``invertible_1x1_conv`` then ``invertible_conv2D_emerging``
        4: ``fourier_conv``
        5: ``invertible_1x1_conv`` then ``maf_three`` 'maf1'
           (``is_upper=False``) then ``maf_three`` 'maf2' (``is_upper=True``)

    ``hps.flow_coupling`` values: 0 = additive, 1 = affine with sigmoid scale.

    Args:
        name: variable-scope name for this step.
        z: 4-D tensor split along axis 3 (presumably NHWC channels; confirm).
            ``Z.int_shape(z)[3]`` must be even.
        logdet: running log-determinant. The coupling terms are summed over
            axes 1-3 (presumably one value per example).
        hps: hyperparameters. Reads ``flow_permutation``, ``flow_coupling``,
            ``width`` and ``decomposition`` (forwarded to
            ``invertible_1x1_conv``).
        reverse: if True, compute the inverse of the step.

    Returns:
        Tuple ``(z, logdet)`` after the step.

    Raises:
        ValueError: if ``hps.flow_permutation`` or ``hps.flow_coupling`` is
            not one of the supported values.
    """
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        n_z = shape[3]
        # z is split into two equal halves along axis 3.
        assert n_z % 2 == 0

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, checkpoint_fn=checkpoint)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet)
            elif hps.flow_permutation == 5:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = maf_three(
                    'maf1', z, logdet, depth=96, is_upper=False)
                z, logdet = maf_three(
                    'maf2', z, logdet, depth=96, is_upper=True)
            else:
                raise ValueError(
                    "Unsupported hps.flow_permutation: %r"
                    % (hps.flow_permutation,))

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # A sigmoid keeps the scale in (0, 1). The +2 offset gives
                # scale = sigmoid(2) ~= 0.88 when h is near 0. The logits are
                # computed once and shared by sigmoid and log_sigmoid.
                scale_logits = h[:, :, :, 1::2] + 2.
                scale = tf.nn.sigmoid(scale_logits)
                log_scale = tf.log_sigmoid(scale_logits)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(log_scale, axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported hps.flow_coupling: %r" % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

        else:
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                scale_logits = h[:, :, :, 1::2] + 2.
                scale = tf.nn.sigmoid(scale_logits)
                log_scale = tf.log_sigmoid(scale_logits)
                # Inverse of the forward affine map: undo scale, then shift.
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(log_scale, axis=[1, 2, 3])
            else:
                raise ValueError(
                    "Unsupported hps.flow_coupling: %r" % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

            # Multi-layer mixings are inverted in reverse order of the
            # forward pass.
            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                # NOTE(review): the forward call passes checkpoint_fn=checkpoint
                # but this reverse call does not. Confirm this is intended.
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet, reverse=True)
            elif hps.flow_permutation == 5:
                z, logdet = maf_three(
                    'maf2', z, logdet, depth=96, is_upper=True, reverse=True)
                z, logdet = maf_three(
                    'maf1', z, logdet, depth=96, is_upper=False, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition,
                    reverse=True)
            else:
                raise ValueError(
                    "Unsupported hps.flow_permutation: %r"
                    % (hps.flow_permutation,))

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet