def checkpoint(z, logdet):
    """Pack ``z`` and ``logdet`` into one tensor registered as a checkpoint.

    ``z`` is flattened to ``(batch, D*H*W*C)`` using axes 1-4 of its static
    shape, and ``logdet`` is appended as the last column. The concatenated
    tensor is added to the ``'checkpoints'`` graph collection (presumably
    consumed by a gradient-checkpointing utility elsewhere — verify). Both
    values are then sliced back out of that single tensor, so downstream ops
    depend on the registered node.

    Args:
      z: rank-5 tensor whose static dims 1-4 are known.
      logdet: tensor reshapeable to ``(batch, 1)``.

    Returns:
      ``(z, logdet)``: ``z`` has its original static dims 1-4, and ``logdet``
      has shape ``(batch,)``.
    """
    dims = Z.int_shape(z)
    flat_size = dims[1] * dims[2] * dims[3] * dims[4]

    flat_z = tf.reshape(z, [-1, flat_size])
    logdet_column = tf.reshape(logdet, [-1, 1])
    packed = tf.concat([flat_z, logdet_column], axis=1)
    tf.add_to_collection('checkpoints', packed)

    # Unpack from the registered tensor rather than returning the inputs.
    unpacked_logdet = packed[:, -1]
    unpacked_z = tf.reshape(
        packed[:, :-1], [-1, dims[1], dims[2], dims[3], dims[4]])
    return unpacked_z, unpacked_logdet
def split3d(name, z, objective=0.):
    """Factor out half of the channels of ``z`` under a conditional prior.

    The first half of axis 4 is kept. The second half is scored by the prior
    returned from ``split3d_prior`` applied to the kept half, and its
    log-probability is added to ``objective``.

    Args:
      name: variable scope for the prior's parameters.
      z: rank-5 tensor; channels are on axis 4.
      objective: running log-likelihood (float or tensor).

    Returns:
      ``(kept, objective, eps)``:
        - ``kept`` is the first channel half after ``Z.squeeze3d``.
        - ``objective`` is updated with the factored half's log-probability.
        - ``eps`` is whatever ``pz.get_eps`` returns for the factored half
          (presumably its standardized noise — confirm in ``split3d_prior``).
    """
    with tf.variable_scope(name):
        half = Z.int_shape(z)[4] // 2
        kept, factored = z[:, :, :, :, :half], z[:, :, :, :, half:]

        pz = split3d_prior(kept)
        objective += pz.logp(factored)

        kept = Z.squeeze3d(kept)
        eps = pz.get_eps(factored)
        return kept, objective, eps
def _f_loss(x, y, is_training, reuse=False):
    """Build the generative and (optional) classification losses.

    Args:
      x: rank-5 input batch (``Z.int_shape`` / ``get_shape`` dims 1-4 are used).
      y: integer class labels, compared against int32 argmax predictions.
      is_training: unused here; kept for interface compatibility.
      reuse: passed to the ``'model'`` variable scope.

    Returns:
      ``(bits_x, bits_y, classification_error)``, one value per example:
        - ``bits_x`` is the negative log-likelihood in bits per sub-pixel.
        - ``bits_y`` is the cross-entropy in bits. It is zeros unless
          ``hps.weight_y > 0 and hps.ycond``.
        - ``classification_error`` is 0/1 per example. It is all ones when
          classification is disabled.

    Side effect: sets ``hps.top_shape`` to the encoder output's static shape.
    """
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Dequantize: add uniform noise within one bin and charge the
        # log bin-width for every element to the objective.
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode (e.g. 8x8x8x1 ==> 4x4x4x8 after the first squeeze).
        z = Z.squeeze3d(z, 2)
        z, objective, _ = encoder(z, objective)

        # Top-level prior over the final latent.
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, _ = prior("prior", y_onehot, hps)
        objective += logp(z)

        # Generative loss in bits per sub-pixel. The denominator is built
        # with the same left-to-right product: log(2) * D * H * W * C.
        x_dims = x.get_shape()
        denom = np.log(2.)
        for axis in (1, 2, 3, 4):
            denom = denom * int(x_dims[axis])
        neg_objective = - objective
        bits_x = neg_objective / denom

        if not (hps.weight_y > 0 and hps.ycond):
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)
            return bits_x, bits_y, classification_error

        # Classification head on the latent averaged over axes 1-3 (the
        # spatial axes of a channels-last 5-D tensor).
        h_y = tf.reduce_mean(z, axis=[1, 2, 3])
        y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
        bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=y_onehot, logits=y_logits) / np.log(2.)

        y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
        is_correct = tf.cast(tf.equal(y_predicted, y), tf.float32)
        classification_error = 1 - is_correct
        return bits_x, bits_y, classification_error
def f_encode(x, y, reuse=True):
    """Encode ``x`` and return the list of per-level latent noise tensors.

    This mirrors the forward pass of the training loss: dequantize, squeeze,
    run the encoder, and apply the top prior. It returns the ``eps`` list
    produced by ``encoder`` with the top prior's ``eps`` for the final latent
    appended.

    Args:
      x: rank-5 input batch.
      y: integer class labels (one-hot encoded for the prior).
      reuse: passed to the ``'model'`` variable scope; defaults to True.

    Returns:
      The ``eps`` list from ``encoder``, with one extra entry appended.

    Side effect: sets ``hps.top_shape``. The log-likelihood is still
    accumulated into a local ``objective`` but is not returned.
    """
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Dequantization, identical to the training path.
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Squeeze (> 4x4x4x8) and encode, collecting per-level eps.
        z = Z.squeeze3d(z, 2)
        z, objective, eps_levels = encoder(z, objective)

        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, top_eps = prior("prior", y_onehot, hps)
        objective += logp(z)

        eps_levels.append(top_eps(z))
        return eps_levels
def invertible_1x1_conv(name, z, logdet, reverse=False, use_lu=False):
    """Invertible 1x1x1 convolution mixing the channels (axis 4) of ``z``.

    Args:
      name: variable scope holding the layer's weights.
      z: rank-5 NDHWC tensor (the conv3d calls use ``data_format='NDHWC'``).
      logdet: running log-determinant. The layer's contribution is added in
        the forward direction and subtracted in reverse.
      reverse: if True, apply the inverse transform.
      use_lu: if True, parameterize the weight by an LU decomposition
        ``W = P @ L @ (U + diag(sign_s * exp(log_s)))``, so that
        ``log|det W| = sum(log_s)``. The default (False) matches the previous
        hard-coded behavior: a dense weight whose log|det| is computed with
        ``tf.matrix_determinant``.

    Returns:
      ``(z, logdet)`` after the transform.
    """
    if use_lu:
        return _invertible_1x1_conv_lu(name, z, logdet, reverse)
    return _invertible_1x1_conv_dense(name, z, logdet, reverse)


def _invertible_1x1_conv_dense(name, z, logdet, reverse):
    """Dense-weight variant: log|det W| is computed directly (in float64)."""
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        w_shape = [shape[4], shape[4]]

        # Initialize with a random orthogonal matrix (the Q factor of a
        # Gaussian matrix), so |det W| = 1 at initialization.
        w_init = np.linalg.qr(np.random.randn(
            *w_shape))[0].astype('float32')
        w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

        # The same kernel is applied at each of the D*H*W positions, so the
        # per-position log|det| is scaled by that count.
        dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
            tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]*shape[3]

        if not reverse:
            _w = tf.reshape(w, [1, 1, 1] + w_shape)
            z = tf.nn.conv3d(z, _w, [1, 1, 1, 1, 1], 'SAME',
                             data_format='NDHWC')
            logdet += dlogdet
            return z, logdet

        _w = tf.matrix_inverse(w)
        _w = tf.reshape(_w, [1, 1, 1] + w_shape)
        z = tf.nn.conv3d(z, _w, [1, 1, 1, 1, 1], 'SAME', data_format='NDHWC')
        logdet -= dlogdet
        return z, logdet


def _invertible_1x1_conv_lu(name, z, logdet, reverse):
    """LU-parameterized variant: log|det W| = sum(log_s) per position."""
    # A bare `import scipy` does not guarantee the `linalg` subpackage is
    # loaded, so import it explicitly.
    import scipy.linalg

    shape = Z.int_shape(z)
    with tf.variable_scope(name):
        dtype = 'float64'

        # Random orthogonal matrix, factored as P @ L @ U.
        np_w = scipy.linalg.qr(np.random.randn(shape[4], shape[4]))[
            0].astype('float32')
        np_p, np_l, np_u = scipy.linalg.lu(np_w)
        np_s = np.diag(np_u)
        np_sign_s = np.sign(np_s)
        np_log_s = np.log(abs(np_s))
        np_u = np.triu(np_u, k=1)  # strictly-upper part; diagonal kept in s

        # P and sign(s) are fixed; L, log|s| and U are trained.
        perm = tf.get_variable("P", initializer=np_p, trainable=False)
        lower = tf.get_variable("L", initializer=np_l)
        sign_s = tf.get_variable(
            "sign_S", initializer=np_sign_s, trainable=False)
        log_s = tf.get_variable("log_S", initializer=np_log_s)
        upper = tf.get_variable("U", initializer=np_u)

        perm = tf.cast(perm, dtype)
        lower = tf.cast(lower, dtype)
        sign_s = tf.cast(sign_s, dtype)
        log_s = tf.cast(log_s, dtype)
        upper = tf.cast(upper, dtype)

        # Masks keep L unit-lower-triangular and U upper-triangular with
        # diagonal sign_s * exp(log_s), regardless of what training does.
        w_shape = [shape[4], shape[4]]
        l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
        lower = lower * l_mask + tf.eye(*w_shape, dtype=dtype)
        upper = upper * np.transpose(l_mask) + \
            tf.diag(sign_s * tf.exp(log_s))
        w = tf.matmul(perm, tf.matmul(lower, upper))

        # Invert via the triangular factors: W^-1 = U^-1 @ L^-1 @ P^-1.
        u_inv = tf.matrix_inverse(upper)
        l_inv = tf.matrix_inverse(lower)
        p_inv = tf.matrix_inverse(perm)
        w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))

        w = tf.cast(w, tf.float32)
        w_inv = tf.cast(w_inv, tf.float32)
        log_s = tf.cast(log_s, tf.float32)

        if not reverse:
            w = tf.reshape(w, [1, 1, 1] + w_shape)
            z = tf.nn.conv3d(z, w, [1, 1, 1, 1, 1], 'SAME',
                             data_format='NDHWC')
            logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2]*shape[3])
            return z, logdet

        w_inv = tf.reshape(w_inv, [1, 1, 1] + w_shape)
        z = tf.nn.conv3d(
            z, w_inv, [1, 1, 1, 1, 1], 'SAME', data_format='NDHWC')
        logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2]*shape[3])
        return z, logdet
def revnet3d_step(name, z, logdet, hps, reverse):
    """One flow step on a rank-5 tensor: actnorm, permutation, coupling.

    The forward step (``reverse=False``) applies actnorm, then the channel
    permutation, then the coupling. The reverse step undoes them in the
    opposite order: inverse coupling, inverse permutation, inverse actnorm.

    Args:
      name: variable scope of this step.
      z: rank-5 tensor. Axis 4 is split into two halves, so its size must
        be even.
      logdet: running log-determinant, updated by each invertible sub-layer.
      hps: hyper-parameters. This step reads:
        - ``flow_permutation``: 0 reverse features, 1 shuffle features,
          2 invertible 1x1 conv.
        - ``flow_coupling``: 0 additive, 1 affine.
        - ``width``.
      reverse: whether to run the inverse step.

    Returns:
      ``(z, logdet)`` after the step.

    Raises:
      ValueError: if the channel count is odd, or ``hps.flow_permutation`` /
        ``hps.flow_coupling`` is unsupported. This is an ``Exception``
        subclass, so existing ``except Exception`` handlers still apply.
    """
    with tf.variable_scope(name):
        shape = Z.int_shape(z)
        n_z = shape[4]
        if n_z % 2 != 0:
            raise ValueError(
                'revnet3d_step requires an even channel count, got {}'
                .format(n_z))

        if not reverse:
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)
            z, logdet = _revnet3d_permute(
                z, logdet, hps.flow_permutation, reverse=False)

            z1 = z[:, :, :, :, :n_z // 2]
            z2 = z[:, :, :, :, n_z // 2:]
            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                shift, scale = _revnet3d_affine_params(z1, hps.width, n_z)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3, 4])
            else:
                raise ValueError('Unsupported hps.flow_coupling: {!r}'
                                 .format(hps.flow_coupling))
            z = tf.concat([z1, z2], 4)
        else:
            z1 = z[:, :, :, :, :n_z // 2]
            z2 = z[:, :, :, :, n_z // 2:]
            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                shift, scale = _revnet3d_affine_params(z1, hps.width, n_z)
                # Exact inverse of the forward z2' = (z2 + shift) * scale.
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3, 4])
            else:
                raise ValueError('Unsupported hps.flow_coupling: {!r}'
                                 .format(hps.flow_coupling))
            z = tf.concat([z1, z2], 4)

            z, logdet = _revnet3d_permute(
                z, logdet, hps.flow_permutation, reverse=True)
            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet


def _revnet3d_permute(z, logdet, flow_permutation, reverse):
    """Apply (``reverse=False``) or undo the channel permutation selected."""
    # Forward calls omit the `reverse` kwarg entirely, as the original did.
    extra = {'reverse': True} if reverse else {}
    if flow_permutation == 0:
        z = Z.reverse_features("reverse", z, **extra)
    elif flow_permutation == 1:
        z = Z.shuffle_features("shuffle", z, **extra)
    elif flow_permutation == 2:
        z, logdet = invertible_1x1_conv("invconv", z, logdet, **extra)
    else:
        raise ValueError(
            'Unsupported hps.flow_permutation: {!r}'.format(flow_permutation))
    return z, logdet


def _revnet3d_affine_params(z1, width, n_z):
    """Compute the affine coupling's (shift, scale) from the conditioning half."""
    h = f("f1", z1, width, n_z)
    shift = h[:, :, :, :, 0::2]
    # Sigmoid keeps scale in (0, 1). The +2 offset gives scale ~ sigmoid(2)
    # ~ 0.88 when h ~ 0.
    scale = tf.nn.sigmoid(h[:, :, :, :, 1::2] + 2.)
    return shift, scale