def checkpoint(z, logdet):
    zshape = Z.int_shape(z)
    z = tf.reshape(z, [-1, zshape[1] * zshape[2] * zshape[3]])
    logdet = tf.reshape(logdet, [-1, 1])
    combined = tf.concat([z, logdet], axis=1)
    tf.add_to_collection('checkpoints', combined)
    logdet = combined[:, -1]
    z = tf.reshape(combined[:, :-1], [-1, zshape[1], zshape[2], zshape[3]])
    return z, logdet
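# The tensors collected under 'checkpoints' above are the hook consumed by
# memory-saving gradient rewrites, which recompute activations between
# checkpoints instead of storing them all. A hypothetical usage sketch,
# assuming the openai gradient-checkpointing package (not part of this repo):
#
#   import memory_saving_gradients
#   grads = memory_saving_gradients.gradients(
#       loss, tf.trainable_variables(), checkpoints='collection')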
def split2d(name, z, objective=0.):
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        z1 = z[:, :, :, :n_z // 2]
        z2 = z[:, :, :, n_z // 2:]
        pz = split2d_prior(z1)
        objective += pz.logp(z2)
        z1 = Z.squeeze2d(z1)
        eps = pz.get_eps(z2)
        return z1, objective, eps
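# For reference, the matching reverse pass. This is a sketch consistent with
# the Glow reference implementation and with split3d_reverse below, not
# verbatim from this file:
def split2d_reverse(name, z, eps, eps_std):
    with tf.variable_scope(name):
        z1 = Z.unsqueeze2d(z)
        pz = split2d_prior(z1)
        if eps is not None:
            # Reuse an encoded eps for exact reconstruction.
            z2 = pz.sample2(eps)
        elif eps_std is not None:
            # Sample at a given temperature.
            z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1]))
        else:
            z2 = pz.sample
        z = tf.concat([z1, z2], 3)
        return z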
def _f_loss(x, y, is_training, reuse=False):

    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        z, objective, _ = encoder(z, objective)

        # Prior
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, _ = prior("prior", y_onehot, hps)
        objective += logp(z)

        # Generative loss
        nobj = - objective
        bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
            x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

        # Predictive loss
        if hps.weight_y > 0 and hps.ycond:
            # Classification loss
            h_y = tf.reduce_mean(z, axis=[1, 2])
            y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
            bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=y_onehot, logits=y_logits) / np.log(2.)

            # Classification accuracy
            y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
            classification_error = 1 - \
                tf.cast(tf.equal(y_predicted, y), tf.float32)
        else:
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)

    return bits_x, bits_y, classification_error
def f_encode(x, y, reuse=True):

    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        z, objective, eps = encoder(z, objective)

        # Prior
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, _eps = prior("prior", y_onehot, hps)
        objective += logp(z)
        eps.append(_eps(z))

    return eps
def split3d(name, level, z, y_onehot, z_prior=None, objective=0.):
    with tf.variable_scope(name + str(level)):
        n_z = Z.int_shape(z)[4]
        z1 = z[:, :, :, :, :n_z // 2]
        z2 = z[:, :, :, :, n_z // 2:]
        shape = [tf.shape(z1)[0]] + Z.int_shape(z1)[1:]
        # (A disabled experiment conditioned the split on z_prior here,
        # e.g. via z_p += Z.myMLP(3, z_prior, n_z_prior, n_z_p).)
        pz = split3d_prior(y_onehot, shape, z_prior, level)
        objective += pz.logp(z2)
        z1 = Z.squeeze3d(z1)
        eps = pz.get_eps(z2)
        return z1, z2, objective, eps
def split2d_prior(z, hps):
    shape = Z.int_shape(z)
    n_z2 = int(z.get_shape()[3])
    n_z1 = n_z2

    h = tf.zeros([tf.shape(z)[0]] + shape[1:3] + [2 * n_z1])
    if hps.learnprior:
        h = Z.conv2d_zeros("conv", z, 2 * n_z1)

    mean = h[:, :, :, 0::2]
    logs = h[:, :, :, 1::2]
    return Z.gaussian_diag(mean, logs)
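# Z.gaussian_diag returns a diagonal-Gaussian helper object. A sketch of its
# interface, consistent with the Glow reference implementation (the actual
# helper lives in this repo's Z module); assumes 4D [b, h, w, c] inputs:
def gaussian_diag(mean, logsd):
    class o(object):
        pass
    o.mean = mean
    o.logsd = logsd
    o.eps = tf.random_normal(tf.shape(mean))
    o.sample = mean + tf.exp(logsd) * o.eps
    o.sample2 = lambda eps: mean + tf.exp(logsd) * eps
    # Per-element log-density of N(mean, exp(logsd)^2):
    o.logps = lambda x: -0.5 * (np.log(2 * np.pi) + 2. * logsd
                                + (x - mean) ** 2 / tf.exp(2. * logsd))
    o.logp = lambda x: tf.reduce_sum(o.logps(x), axis=[1, 2, 3])
    o.get_eps = lambda x: (x - mean) / tf.exp(logsd)
    return o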
def split3d_reverse(name, level, z, y_onehot, z_provided, eps, eps_std,
                    z_prior=None):
    with tf.variable_scope(name + str(level)):
        z1 = Z.unsqueeze3d(z)
        shape = [tf.shape(z1)[0]] + Z.int_shape(z1)[1:]
        # (A disabled experiment conditioned z1 on z_prior here, mirroring
        # the one in split3d.)
        pz = split3d_prior(y_onehot, shape, z_prior, level)

        if z_provided is not None:
            # Flip the one-hot labels: (y - 0.5) * (-1) + 0.5 == 1 - y.
            y_onehot2 = (y_onehot - 0.5) * (-1) + 0.5
            pz2_ = split3d_prior(y_onehot2, shape, z_prior, level)
            # Translate the provided code from the original class-conditional
            # prior to the label-flipped one.
            z2 = z_provided - pz.mean + pz2_.mean
        else:
            if eps is not None:
                # Already sampled eps
                z2 = pz.sample2(eps)
            elif eps_std is not None:
                # Sample with given eps_std
                z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1, 1]))
            else:
                # Sample normally
                z2 = pz.sample

        z = tf.concat([z1, z2], 4)
        return z
def test_derivative_fourier_conv():
    print('Testing gradients')
    shape = [128, 32, 32, 3]
    x = tf.placeholder(tf.float32, shape, name='image')
    x_np = np.random.randn(*shape).astype('float32')

    logdet = tf.zeros_like(x)[:, 0, 0, 0]

    with tf.variable_scope('test'):
        z = x
        z, logdet = fourier_conv('layer', z, logdet, reverse=False)

    with tf.variable_scope('test', reuse=True):
        w = tf.get_variable('layer/W')

    f = tf.reduce_sum(logdet)
    grad = tf.gradients(f, [w])[0]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        w_np = sess.run(w)
        grad_np = sess.run(grad, feed_dict={x: x_np})

        # The central finite difference of f along a random direction v
        # should match the directional derivative <grad f, v>.
        delta = 0.0001
        v = np.random.randn(*Z.int_shape(w))
        finite_diff = (sess.run(f, feed_dict={x: x_np, w: w_np + delta * v})
                       - sess.run(f, feed_dict={x: x_np, w: w_np - delta * v})) \
            / 2 / delta
        other_side = np.sum(grad_np * v)

        print(finite_diff, other_side, finite_diff - other_side)
        print(finite_diff / other_side)
def _f_loss(x_A, y_A, x_B, y_B, is_training, reuse=False, init=False):

    with tf.variable_scope('model_A', reuse=reuse):
        y_onehot_A = tf.cast(tf.one_hot(y_A, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective_A = tf.zeros_like(x_A, dtype='float32')[:, 0, 0, 0]
        z_A = preprocess(x_A)
        z_A = z_A + tf.random_uniform(tf.shape(z_A), 0, 1. / hps.n_bins)
        objective_A += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_A)[1:])

        # Encode
        z_A = Z.squeeze2d(z_A, 2)  # > 16x16x12
        z_A, objective_A, eps_A = encoder_A(z_A, objective_A)

        # Prior
        hps.top_shape = Z.int_shape(z_A)[1:]
        logp_A, _, _eps_A = prior("prior", y_onehot_A, hps)
        objective_A += logp_A(z_A)

        # Note that we learn the top layer so need to process z
        z_A = _eps_A(z_A)
        eps_A.append(z_A)

        # Flatten the latent code for comparison with the other model
        eps_flatten_A = tf.concat(
            [tf.contrib.layers.flatten(e) for e in eps_A], axis=-1)

    with tf.variable_scope('model_B', reuse=reuse):
        y_onehot_B = tf.cast(tf.one_hot(y_B, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective_B = tf.zeros_like(x_B, dtype='float32')[:, 0, 0, 0]
        z_B = preprocess(x_B)
        z_B = z_B + tf.random_uniform(tf.shape(z_B), 0, 1. / hps.n_bins)
        objective_B += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_B)[1:])

        # Encode
        z_B = Z.squeeze2d(z_B, 2)  # > 16x16x12
        z_B, objective_B, eps_B = encoder_B(z_B, objective_B)

        # Prior
        hps.top_shape = Z.int_shape(z_B)[1:]
        logp_B, _, _eps_B = prior("prior", y_onehot_B, hps)
        objective_B += logp_B(z_B)

        # Note that we learn the top layer so need to process z
        z_B = _eps_B(z_B)
        eps_B.append(z_B)

        # Flatten the latent code for comparison with the other model
        eps_flatten_B = tf.concat(
            [tf.contrib.layers.flatten(e) for e in eps_B], axis=-1)

    code_loss = 0.0
    code_shapes = [[16, 16, 6], [8, 8, 12], [4, 4, 48]]

    if hps.code_loss_type == 'B_all':
        if not init:
            # Decode the code from the other model and compute the loss at
            # pixel level.
            def unflatten_code(fcode, code_shapes):
                index = 0
                code = []
                bs = tf.shape(fcode)[0]
                # bs = hps.local_batch_train
                for shape in code_shapes:
                    code.append(tf.reshape(
                        fcode[:, index:index + np.prod(shape)],
                        tf.convert_to_tensor([bs] + shape)))
                    index += np.prod(shape)
                return code

            code_others = unflatten_code(eps_flatten_A, code_shapes)
            # code_others[-1] is z, and code_others[:-1] is eps
            with tf.variable_scope('model_B', reuse=True):
                _, sample, _ = prior("prior", y_onehot_B, hps)
                code_last_others = sample(eps=code_others[-1])
                code_decoded_others = decoder_B(
                    code_last_others, code_others[:-1])

            code_decoded = Z.unsqueeze2d(code_decoded_others, 2)
            x_B_recon = postprocess(code_decoded)
            x_B_scaled = 1 / 255.0 * tf.cast(x_B, tf.float32)
            x_B_recon_scaled = 1 / 255.0 * tf.cast(x_B_recon, tf.float32)

            if hps.code_loss_fn == 'l1':
                code_loss = tf.reduce_mean(tf.losses.absolute_difference(
                    x_B_scaled, x_B_recon_scaled))
            elif hps.code_loss_fn == 'l2':
                code_loss = tf.reduce_mean(tf.squared_difference(
                    x_B_scaled, x_B_recon_scaled))
            else:
                raise NotImplementedError()
    elif hps.code_loss_type == 'code_all':
        code_loss = tf.reduce_mean(
            tf.squared_difference(eps_flatten_A, eps_flatten_B))
    elif hps.code_loss_type == 'code_last':
        dim = np.prod(code_shapes[-1])
        code_loss = tf.reduce_mean(tf.squared_difference(
            eps_flatten_A[:, -dim:], eps_flatten_B[:, -dim:]))
    else:
        raise NotImplementedError()

    with tf.variable_scope('model_A', reuse=True):
        # Generative loss
        nobj_A = - objective_A
        bits_x_A = nobj_A / (np.log(2.) * int(x_A.get_shape()[1]) * int(
            x_A.get_shape()[2]) * int(x_A.get_shape()[3]))  # bits per subpixel
        bits_y_A = tf.zeros_like(bits_x_A)
        classification_error_A = tf.ones_like(bits_x_A)

    with tf.variable_scope('model_B', reuse=True):
        # Generative loss
        nobj_B = - objective_B
        bits_x_B = nobj_B / (np.log(2.) * int(x_B.get_shape()[1]) * int(
            x_B.get_shape()[2]) * int(x_B.get_shape()[3]))  # bits per subpixel
        bits_y_B = tf.zeros_like(bits_x_B)
        classification_error_B = tf.ones_like(bits_x_B)

    return (bits_x_A, bits_y_A, classification_error_A, eps_flatten_A,
            bits_x_B, bits_y_B, classification_error_B, eps_flatten_B,
            code_loss)
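# Sanity check on the hard-coded code_shapes above: a bijective flow preserves
# dimensionality, so for 32x32x3 inputs the flattened code must satisfy
# 16*16*6 + 8*8*12 + 4*4*48 = 1536 + 768 + 768 = 3072 = 32*32*3.
# Hypothetical assertion:
#
#   assert sum(int(np.prod(s)) for s in [[16, 16, 6], [8, 8, 12], [4, 4, 48]]) \
#       == 32 * 32 * 3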
def invertible_conv2D_emerging(name, z, logdet, ksize=3, dilation=1,
                               reverse=False, checkpoint_fn=None):
    batchsize, height, width, n_channels = Z.int_shape(z)

    assert (ksize - 1) % 2 == 0
    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_square_ar_mask(
            ksize, ksize, n_channels, n_channels,
            zerodiagonal=True)[::-1, ::-1, ::-1, ::-1].copy()
        mask = tf.constant(mask_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask

        s_np = (1 + np.random.randn(n_channels) * 0.02).astype('float32')
        s = tf.get_variable('scale', dtype=tf.float32, initializer=s_np)
        s = tf.reshape(s, [1, 1, 1, n_channels])

        def flat(z):
            return tf.reshape(z, [batchsize, height * width * n_channels])

        def unflat(z):
            return tf.reshape(z, [batchsize, height, width, n_channels])

        # Shift-only (volume-preserving) autoregressive transforms, with the
        # shift computed by a masked convolution.
        def shift_and_log_scale_fn_volume_preserving_1(z_flat):
            z = unflat(z_flat)
            shift = tf.nn.conv2d(
                z, w1, [1, 1, 1, 1], dilations=[1, dilation, dilation, 1],
                padding='SAME', data_format='NHWC')
            shift_flat = flat(shift)
            return shift_flat, tf.zeros_like(shift_flat)

        def shift_and_log_scale_fn_volume_preserving_2(z_flat):
            z = unflat(z_flat)
            shift = tf.nn.conv2d(
                z, w2, [1, 1, 1, 1], dilations=[1, dilation, dilation, 1],
                padding='SAME', data_format='NHWC')
            shift_flat = flat(shift)
            return shift_flat, tf.zeros_like(shift_flat)

        flow1 = tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn_volume_preserving_1)
        flow2 = tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn_volume_preserving_2)

        def flip(z_flat):
            z = unflat(z_flat)
            z = z[:, ::-1, ::-1, ::-1]
            return flat(z)

        def forward(z, logdet):
            z = z * s
            logdet += tf.reduce_sum(tf.log(tf.abs(s))) * (height * width)

            z_flat = flat(z)
            z_flat = flow1.forward(z_flat)
            z_flat = flip(z_flat)
            z_flat = flow2.forward(z_flat)
            z_flat = flip(z_flat)
            z = unflat(z_flat)

            z = z + b
            return z, logdet

        def inverse(z, logdet):
            z = z - b

            z_flat = flat(z)
            z_flat = flip(z_flat)
            z_flat = flow2.inverse(z_flat)
            z_flat = flip(z_flat)
            z_flat = flow1.inverse(z_flat)
            z = unflat(z_flat)

            z = z / s
            logdet -= tf.reduce_sum(tf.log(tf.abs(s))) * (height * width)
            return z, logdet

        if not reverse:
            return forward(z, logdet)
        else:
            return inverse(z, logdet)
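# A hypothetical round-trip smoke test for the bijection above: applying the
# layer with reverse=False and then reverse=True under a reused variable
# scope should reconstruct the input up to numerical error, and the logdet
# contributions should cancel:
#
#   z_in = tf.placeholder(tf.float32, [16, 8, 8, 4])
#   logdet = tf.zeros([16])
#   with tf.variable_scope('m'):
#       y, ld = invertible_conv2D_emerging('conv', z_in, logdet)
#   with tf.variable_scope('m', reuse=True):
#       z_rec, ld_rec = invertible_conv2D_emerging('conv', y, ld,
#                                                  reverse=True)
#   # check: z_rec ~ z_in and ld_rec ~ 0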
def fourier_conv(name, z, logdet, ksize=3, reverse=False, checkpoint_fn=None,
                 use_fourier_forward=False):
    batchsize, height, width, n_channels = Z.int_shape(z)
    assert (ksize - 1) % 2 == 0

    with tf.variable_scope(name):
        filter_shape = [ksize, ksize, n_channels, n_channels]
        w_np = get_conv_weight_np(filter_shape)
        w = tf.get_variable('W', dtype=tf.float32, initializer=w_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        f_shape = [height, width]

        def forward(z, w, logdet):
            padsize = (ksize - 1) // 2
            # Circular padding.
            z = tf.concat((z[:, -padsize:, :], z, z[:, :padsize, :]), axis=1)
            z = tf.concat((z[:, :, -padsize:], z, z[:, :, :padsize]), axis=2)

            # Circular convolution (due to the padding).
            z = tf.nn.conv2d(z, w, [1, 1, 1, 1], padding='VALID',
                             data_format='NHWC')

            # Fourier transform of the filter for the log determinant.
            w_fft = tf.spectral.rfft2d(
                tf.transpose(w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                fft_length=f_shape, name=None)
            dlogdet = compute_logdet(w_fft, width)
            logdet += dlogdet

            z = z + b
            return z, logdet

        def forward_fourier(x, w, logdet):
            # Dimension [b, c, v, u]
            x_fft = tf.spectral.rfft2d(tf.transpose(x, [0, 3, 1, 2]),
                                       fft_length=f_shape, name=None)
            # Dimension [b, 1, c_in, v, u]
            x_fft = tf.expand_dims(x_fft, 1)

            # Dimension [c_out, c_in, v, u]
            w_fft = tf.spectral.rfft2d(
                tf.transpose(w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                fft_length=f_shape, name=None)
            logdet += compute_logdet(w_fft, width)

            # Dimension [1, c_out, c_in, v, u]
            w_fft = tf.expand_dims(w_fft, 0)

            z_fft = tf.reduce_sum(tf.multiply(x_fft, w_fft), axis=2)
            z = tf.spectral.irfft2d(z_fft, fft_length=f_shape)
            z = tf.transpose(z, [0, 2, 3, 1])
            z = reindex(z)

            z = z + b
            return z, logdet

        def inverse(z, logdet):
            z = z - b
            z = reindex(z, reverse=True)

            # Dimension [b, c_out, v, u]
            z_fft = tf.spectral.rfft2d(tf.transpose(z, [0, 3, 1, 2]),
                                       fft_length=f_shape, name=None)
            # Dimension [b, 1, c_out, v, u]
            z_fft = tf.expand_dims(z_fft, 1)

            # Dimension [c_out, c_in, v, u]
            w_fft = tf.spectral.rfft2d(
                tf.transpose(w, [3, 2, 0, 1])[:, :, ::-1, ::-1],
                fft_length=f_shape, name=None)
            dlogdet = compute_logdet(w_fft, width)
            logdet -= dlogdet

            # Dimension [v, u, c_in, c_out]; channels switched because of
            # the inverse.
            w_fft_inv = tf.linalg.inv(tf.transpose(w_fft, [2, 3, 0, 1]))
            # Dimension [c_in, c_out, v, u]
            w_fft_inv = tf.transpose(w_fft_inv, [2, 3, 0, 1])
            # Dimension [1, c_in, c_out, v, u]
            w_fft_inv = tf.expand_dims(w_fft_inv, 0)

            x_fft = tf.reduce_sum(tf.multiply(z_fft, w_fft_inv), axis=2)
            x = tf.spectral.irfft2d(x_fft, fft_length=f_shape)
            x = tf.transpose(x, [0, 2, 3, 1])
            return x, logdet

        if not reverse:
            if use_fourier_forward:
                return forward_fourier(z, w, logdet)
            else:
                return forward(z, w, logdet)
        else:
            return inverse(z, logdet)
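# compute_logdet (defined elsewhere in the repo) sums log|det| of the
# per-frequency channel-mixing matrices: a circular convolution is
# block-diagonal in the Fourier basis. A hypothetical numpy cross-check using
# the full (complex) FFT, which avoids the conjugate-bin bookkeeping that the
# rfft2d-based version above needs:
def compute_logdet_np(w, height, width):
    # w: spatial filter [ksize, ksize, c_in, c_out]
    w_t = np.transpose(w, (3, 2, 0, 1))[:, :, ::-1, ::-1]  # [c_out, c_in, kv, ku]
    w_fft = np.fft.fft2(w_t, s=(height, width))            # [c_out, c_in, v, u]
    w_fft = np.transpose(w_fft, (2, 3, 0, 1))              # [v, u, c_out, c_in]
    # log|det J| = sum over frequencies (v, u) of log|det W_fft(v, u)|
    _, logabsdet = np.linalg.slogdet(w_fft)
    return np.sum(logabsdet)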
def revnet2d_step(name, z, logdet, hps, reverse):
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, checkpoint_fn=checkpoint)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet)
            elif hps.flow_permutation == 5:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition)
                z, logdet = maf_three('maf1', z, logdet, depth=96,
                                      is_upper=False)
                z, logdet = maf_three('maf2', z, logdet, depth=96,
                                      is_upper=True)
            else:
                raise Exception()

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                logscale = tf.log_sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(logscale, axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)
        else:
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                logscale = tf.log_sigmoid(h[:, :, :, 1::2] + 2.)
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(logscale, axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 3:
                z, logdet = invertible_conv2D_emerging(
                    "emerging", z, logdet, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True,
                    decomposition=hps.decomposition)
            elif hps.flow_permutation == 4:
                z, logdet = fourier_conv('fourier', z, logdet, reverse=True)
            elif hps.flow_permutation == 5:
                z, logdet = maf_three('maf2', z, logdet, depth=96,
                                      is_upper=True, reverse=True)
                z, logdet = maf_three('maf1', z, logdet, depth=96,
                                      is_upper=False, reverse=True)
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, decomposition=hps.decomposition,
                    reverse=True)
            else:
                raise Exception()

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
def invertible_1x1_conv(name, z, logdet, reverse=False):

    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):
            shape = Z.int_shape(z)
            w_shape = [shape[3], shape[3]]

            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                *w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1] * shape[2]

            if not reverse:
                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet
                return z, logdet
            else:
                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet
                return z, logdet

    else:

        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):
            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]
            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            # Inverting the triangular factors is cheaper and better
            # conditioned than inverting w directly.
            u_inv = tf.matrix_inverse(u)
            l_inv = tf.matrix_inverse(l)
            p_inv = tf.matrix_inverse(p)
            w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:
                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet
            else:
                w_inv = tf.reshape(w_inv, [1, 1] + w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet
def invertible_1x1_conv(name, z, logdet, decomposition=None, reverse=False,
                        unit_testing=False):
    shape = Z.int_shape(z)
    w_shape = [shape[3], shape[3]]

    if decomposition is None or decomposition == '':
        with tf.variable_scope(name):
            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(
                np.random.randn(*w_shape))[0].astype('float32')
            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            dlogdet = tf.cast(
                tf.log(abs(tf.matrix_determinant(tf.cast(w, 'float64')))),
                'float32') * shape[1] * shape[2]

            if not reverse:
                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet += dlogdet
                return z, logdet
            else:
                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet -= dlogdet
                return z, logdet

    elif decomposition == 'PLU' or decomposition == 'LU':
        # LU-decomposed version
        with tf.variable_scope(name):
            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable("sign_S", initializer=np_sign_s,
                                     trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            # Inverting the triangular factors is cheaper and better
            # conditioned than inverting w directly.
            u_inv = tf.matrix_inverse(u)
            l_inv = tf.matrix_inverse(l)
            p_inv = tf.matrix_inverse(p)
            w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:
                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet
            else:
                w_inv = tf.reshape(w_inv, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w_inv, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet

    elif decomposition == 'QR':
        with tf.variable_scope(name):
            np_s = np.ones(shape[3], dtype='float32')
            np_u = np.zeros((shape[3], shape[3]), dtype='float32')
            if unit_testing:
                np_s = 1 + 0.02 * np.random.randn(shape[3]).astype('float32')
                np_u = np.random.randn(shape[3], shape[3]).astype('float32')
                np_u = np.triu(np_u, k=1).astype('float32')

            u_mask = np.triu(np.ones(w_shape, dtype='float32'), 1)

            s = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            log_s = tf.log(tf.abs(s))
            r = u * u_mask + tf.diag(s)

            # Householder transformations
            I = tf.eye(shape[3])
            q = I
            for i in range(shape[3]):
                v_np = np.random.randn(shape[3], 1).astype('float32')
                v = tf.get_variable("v_{}".format(i), initializer=v_np)
                vT = tf.transpose(v)
                q_i = I - 2 * tf.matmul(v, vT) / tf.matmul(vT, v)
                q = tf.matmul(q, q_i)

            # Modified Gram–Schmidt process (alternative construction of q):
            # def inner(a, b):
            #     return tf.reduce_sum(a * b)
            # def proj(v, u):
            #     return u * inner(v, u) / inner(u, u)
            # q = []
            # for i in range(shape[3]):
            #     v_np = np.random.randn(shape[3], 1).astype('float32')
            #     v = tf.get_variable("v_{}".format(i), initializer=v_np)
            #     for j in range(i):
            #         v = v - proj(v, q[j])
            #     q.append(v)
            # q = tf.concat(q, axis=1)
            # q = q / tf.norm(q, axis=0, keepdims=True)

            q_inv = tf.transpose(q)
            r_inv = tf.matrix_inverse(r)

            w = tf.matmul(q, r)
            w_inv = tf.matmul(r_inv, q_inv)

            if not reverse:
                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet
            else:
                w_inv = tf.reshape(w_inv, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w_inv, [1, 1, 1, 1], 'SAME',
                                 data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1] * shape[2])
                return z, logdet

    else:
        raise ValueError('Unknown decomposition: {}'.format(decomposition))
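# A hypothetical numpy check of the QR parameterization above: a product of
# Householder reflections is orthogonal, so |det W| = |det R| = prod_i |s_i|,
# and the logdet contribution reduces to sum_i log|s_i|.
def test_qr_parameterization(c=4):
    q = np.eye(c)
    for _ in range(c):
        v = np.random.randn(c, 1)
        q = q @ (np.eye(c) - 2 * v @ v.T / (v.T @ v))
    # Products of Householder reflections stay orthogonal.
    assert np.allclose(q @ q.T, np.eye(c), atol=1e-5)
    s = 1 + 0.02 * np.random.randn(c)
    r = np.triu(np.random.randn(c, c), 1) + np.diag(s)
    # |det(QR)| = prod |s_i| because |det Q| = 1.
    assert np.allclose(np.log(abs(np.linalg.det(q @ r))),
                       np.sum(np.log(abs(s))), atol=1e-5)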
def revnet2d_step(name, z, logdet, hps, reverse):
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise Exception()

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)
        else:
            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise Exception()

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise Exception()

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
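# Note on the coupling above: logdet accumulates tf.log(scale) with
# scale = sigmoid(h + 2.), which can underflow for very negative activations.
# The revnet2d_step variant earlier in this section computes the same
# quantity in a numerically stable form:
#
#   scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
#   logscale = tf.log_sigmoid(h[:, :, :, 1::2] + 2.)
#   logdet += tf.reduce_sum(logscale, axis=[1, 2, 3])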
def invertible_conv2D_emerging_1x1(name, z, logdet, ksize=3, dilation=1,
                                   reverse=False, checkpoint_fn=None,
                                   decomposition=None, unit_testing=False):
    shape = Z.int_shape(z)
    batchsize, height, width, n_channels = shape

    assert (ksize - 1) % 2 == 0
    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        if decomposition is None or decomposition == '':
            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')
            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)
            dlogdet = tf.cast(
                tf.log(abs(tf.matrix_determinant(tf.cast(w, 'float64')))),
                'float32') * shape[1] * shape[2]
            w_inv = tf.matrix_inverse(w)

        elif decomposition == 'PLU' or decomposition == 'LU':
            # LU-decomposed version
            dtype = 'float64'

            # Random orthogonal matrix:
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(
                shape[3], shape[3]))[0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable("sign_S", initializer=np_sign_s,
                                     trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            l_mask = np.tril(np.ones([shape[3], shape[3]], dtype=dtype), -1)
            l = l * l_mask + tf.eye(shape[3], dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            u_inv = tf.matrix_inverse(u)
            l_inv = tf.matrix_inverse(l)
            p_inv = tf.matrix_inverse(p)
            w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            dlogdet = tf.reduce_sum(log_s) * (shape[1] * shape[2])

        elif decomposition == 'QR':
            np_s = np.ones(shape[3], dtype='float32')
            np_u = np.zeros((shape[3], shape[3]), dtype='float32')
            if unit_testing:
                np_s = 1 + 0.02 * np.random.randn(shape[3]).astype('float32')
                np_u = np.random.randn(shape[3], shape[3]).astype('float32')
                np_u = np.triu(np_u, k=1).astype('float32')

            u_mask = np.triu(np.ones([shape[3], shape[3]],
                                     dtype='float32'), 1)

            s = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            log_s = tf.log(tf.abs(s))
            r = u * u_mask + tf.diag(s)

            # Householder transformations
            I = tf.eye(shape[3])
            q = I
            for i in range(shape[3]):
                v_np = np.random.randn(shape[3], 1).astype('float32')
                v = tf.get_variable("v_{}".format(i), initializer=v_np)
                vT = tf.transpose(v)
                q_i = I - 2 * tf.matmul(v, vT) / tf.matmul(vT, v)
                q = tf.matmul(q, q_i)

            # (A Modified Gram–Schmidt construction of q, identical to the
            # commented block in invertible_1x1_conv above, was tried here
            # as well.)

            q_inv = tf.transpose(q)
            r_inv = tf.matrix_inverse(r)

            w = tf.matmul(q, r)
            w_inv = tf.matmul(r_inv, q_inv)

            dlogdet = tf.reduce_sum(log_s) * (shape[1] * shape[2])

        else:
            raise ValueError(
                'Unknown decomposition: {}'.format(decomposition))

        mask_np = get_conv_square_ar_mask(ksize, ksize, n_channels,
                                          n_channels)
        mask_upsidedown_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()

        mask = tf.constant(mask_np)
        mask_upsidedown = tf.constant(mask_upsidedown_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask_upsidedown

        def log_abs_diagonal(w):
            return tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        def forward(z, logdet):
            w_ = tf.reshape(w, [1, 1] + [shape[3], shape[3]])
            z = tf.nn.conv2d(z, w_, [1, 1, 1, 1], 'SAME',
                             data_format='NHWC')
            logdet += dlogdet

            z = tf.nn.conv2d(z, w1, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.nn.conv2d(z, w2, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        def forward_fast(z, logdet):
            """Convolution with [(k+1) // 2]^2 filters."""
            # Smaller versions of w1, w2.
            w1_s = w1[kcent:, kcent:, :, :]
            w2_s = w2[:-kcent, :-kcent, :, :]

            pad = kcent * dilation

            # Fold the 1x1 convolution w into w1_s so that the first two
            # convolutions of forward() collapse into a single one.
            # standard filter shape: [v, u, c_in, c_out]
            # standard fmap shape: [b, h, w, c]
            w_ = tf.transpose(tf.reshape(w, [1, 1] + [shape[3], shape[3]]),
                              (0, 1, 3, 2))
            w_equiv = tf.nn.conv2d(tf.transpose(w1_s, (3, 0, 1, 2)), w_,
                                   [1, 1, 1, 1], padding='SAME')
            w_equiv = tf.transpose(w_equiv, (1, 2, 3, 0))

            z = tf.pad(z, [[0, 0], [0, pad], [0, pad], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z, w_equiv, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID', data_format='NHWC')
            # The folded 1x1 convolution contributes its logdet here,
            # matching forward() above.
            logdet += dlogdet
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.pad(z, [[0, 0], [pad, 0], [pad, 0], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z, w2_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        if not reverse:
            x, logdet = forward_fast(z, logdet)
            # Optional consistency check against the slow reference:
            # x_, _ = forward(z, logdet)
            # x = tf.Print(
            #     x, data=[tf.reduce_mean(tf.square(x - x_))], message='diff')
            return x, logdet
        else:
            logdet -= tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            x = tf.py_func(
                Inverse(is_upper=1, dilation=dilation),
                inp=[z, w2, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse2',
            )

            logdet -= tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            x = tf.py_func(
                Inverse(is_upper=0, dilation=dilation),
                inp=[x, w1, tf.zeros_like(b)],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse1',
            )
            x.set_shape(z.get_shape())

            w_inv = tf.reshape(w_inv, [1, 1] + [shape[3], shape[3]])
            x = tf.nn.conv2d(x, w_inv, [1, 1, 1, 1], 'SAME',
                             data_format='NHWC')
            # Subtract the 1x1 contribution exactly once, mirroring the
            # single dlogdet added in the forward direction.
            logdet -= dlogdet

            # Optional reconstruction check:
            # z_recon, _ = forward_fast(x, tf.zeros_like(logdet))
            # mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))
            # x = tf.Print(x, data=[mse], message='RMSE of inverse')

            return x, logdet
def invertible_ar_conv2D(name, z, logdet, is_upper, ksize=3, dilation=1,
                         reverse=False):
    shape = Z.int_shape(z)
    n_channels = shape[3]
    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_ar_mask(ksize, ksize, n_channels, n_channels)
        if is_upper:
            mask_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()
        mask = tf.constant(mask_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        weight_np = get_conv_weight_np(filter_shape)
        w = tf.get_variable('W', dtype=tf.float32, initializer=weight_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w = mask * w

        log_abs_diagonal = tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        if not reverse:
            z = tf.nn.conv2d(z, w, strides=[1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME', data_format='NHWC') + b
            logdet += tf.reduce_sum(log_abs_diagonal) * (shape[1] * shape[2])
            return z, logdet
        else:
            logdet -= tf.reduce_sum(log_abs_diagonal) * (shape[1] * shape[2])
            x = tf.py_func(
                Inverse(is_upper=is_upper, dilation=dilation),
                inp=[z, w, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse',
            )

            # Reconstruction check, printed at run time.
            z_recon = tf.nn.conv2d(x, w, [1, 1, 1, 1], padding='SAME',
                                   data_format='NHWC') + b
            mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))
            x = tf.Print(x, data=[mse], message='RMSE of inverse')

            x.set_shape(z.get_shape())
            return x, logdet
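# The Inverse callable wrapped in tf.py_func above is defined elsewhere in
# the repo. A hypothetical pure-numpy sketch of the same back-substitution,
# assuming dilation=1 and the autoregressive mask convention used above
# (nonzero taps only at raster positions up to the center, with a triangular
# channel mixing at the center tap):
class InverseSketch:
    def __init__(self, is_upper, dilation=1):
        assert dilation == 1, 'sketch covers dilation=1 only'
        self.is_upper = is_upper

    def __call__(self, z, w, b):
        # z: [B, H, W, C] conv output, w: [k, k, C, C], b: [1, 1, 1, C]
        if self.is_upper:
            # An upper AR conv is a lower AR conv on fully flipped arrays.
            x = InverseSketch(is_upper=0)(z[:, ::-1, ::-1, ::-1],
                                          w[::-1, ::-1, ::-1, ::-1],
                                          b[..., ::-1])
            return x[:, ::-1, ::-1, ::-1].copy()
        ksize = w.shape[0]
        kc = (ksize - 1) // 2
        B, H, W, C = z.shape
        x = np.zeros_like(z)
        z = z - b
        w_center = w[kc, kc]  # [C_in, C_out], triangular channel mixing
        for i in range(H):
            for j in range(W):
                # Accumulate contributions of already-solved pixels; taps
                # after the center are zero under the mask, so they add 0.
                acc = np.zeros((B, C), dtype=z.dtype)
                for di in range(ksize):
                    for dj in range(ksize):
                        if di == kc and dj == kc:
                            continue
                        ii, jj = i + di - kc, j + dj - kc
                        if 0 <= ii < H and 0 <= jj < W:
                            acc += x[:, ii, jj, :] @ w[di, dj]
                # Solve the per-pixel channel system x_c @ w_center = rhs.
                rhs = z[:, i, j, :] - acc
                x[:, i, j, :] = np.linalg.solve(w_center.T, rhs.T).T
        return x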
def invertible_conv2D_emerging(name, z, logdet, ksize=3, dilation=1,
                               reverse=False, checkpoint_fn=None):
    batchsize, height, width, n_channels = Z.int_shape(z)

    assert (ksize - 1) % 2 == 0
    kcent = (ksize - 1) // 2

    with tf.variable_scope(name):
        mask_np = get_conv_square_ar_mask(ksize, ksize, n_channels,
                                          n_channels)
        mask_upsidedown_np = mask_np[::-1, ::-1, ::-1, ::-1].copy()

        mask = tf.constant(mask_np)
        mask_upsidedown = tf.constant(mask_upsidedown_np)

        filter_shape = [ksize, ksize, n_channels, n_channels]

        w1_np = get_conv_weight_np(filter_shape)
        w2_np = get_conv_weight_np(filter_shape)
        w1 = tf.get_variable('W1', dtype=tf.float32, initializer=w1_np)
        w2 = tf.get_variable('W2', dtype=tf.float32, initializer=w2_np)
        b = tf.get_variable('b', [n_channels],
                            initializer=tf.zeros_initializer())
        b = tf.reshape(b, [1, 1, 1, -1])

        w1 = w1 * mask
        w2 = w2 * mask_upsidedown

        def log_abs_diagonal(w):
            return tf.log(tf.abs(tf.diag_part(w[kcent, kcent])))

        def forward(z, logdet):
            z = tf.nn.conv2d(z, w1, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.nn.conv2d(z, w2, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='SAME', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        def forward_fast(z, logdet):
            """Convolution with [(k+1) // 2]^2 filters."""
            # The masked kernels are zero on one side of the center tap, so
            # the all-zero rows/columns can be cropped away if the input is
            # padded asymmetrically instead.
            w1_s = w1[kcent:, kcent:, :, :]
            w2_s = w2[:-kcent, :-kcent, :, :]

            pad = kcent * dilation

            z = tf.pad(z, [[0, 0], [0, pad], [0, pad], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z, w1_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = tf.pad(z, [[0, 0], [pad, 0], [pad, 0], [0, 0]], 'CONSTANT')
            z = tf.nn.conv2d(z, w2_s, [1, 1, 1, 1],
                             dilations=[1, dilation, dilation, 1],
                             padding='VALID', data_format='NHWC')
            logdet += tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            if checkpoint_fn is not None:
                checkpoint_fn(z, logdet)

            z = z + b
            return z, logdet

        if not reverse:
            return forward_fast(z, logdet)
        else:
            logdet -= tf.reduce_sum(log_abs_diagonal(w2)) * (height * width)
            x = tf.py_func(
                Inverse(is_upper=1, dilation=dilation),
                inp=[z, w2, b],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse2',
            )

            logdet -= tf.reduce_sum(log_abs_diagonal(w1)) * (height * width)
            x = tf.py_func(
                Inverse(is_upper=0, dilation=dilation),
                inp=[x, w1, tf.zeros_like(b)],
                Tout=tf.float32,
                stateful=True,
                name='conv2dinverse1',
            )
            x.set_shape(z.get_shape())

            # Optional reconstruction check:
            # z_recon, _ = forward_fast(x, tf.zeros_like(logdet))
            # mse = tf.sqrt(tf.reduce_mean(tf.pow(z_recon - z, 2)))
            # x = tf.Print(x, data=[mse], message='RMSE of inverse')

            return x, logdet