def prior(name, y_onehot, hps):
    """Build the top-level diagonal-Gaussian prior and return helpers over it.

    Args:
        name: variable scope under which the prior's variables are created.
        y_onehot: tensor whose leading dimension supplies the batch size; when
            ``hps.ycond`` is set it is also fed to ``Z.linear_zeros("y_emb", ...)``
            to shift the prior parameters.
        hps: hyper-parameters. Reads ``top_shape`` (``[:2]`` and ``[-1]`` are
            used; presumably (H, W, C) of the top latent — confirm), ``learntop``
            and ``ycond``.

    Returns:
        A ``(logp, sample, eps)`` triple of closures:
        ``logp(z1)`` -> ``pz.logp(z1)``;
        ``sample(eps=None, eps_std=None)`` -> a draw from the prior;
        ``eps(z1)`` -> ``pz.get_eps(z1)``.
    """
    with tf.variable_scope(name):
        channels = hps.top_shape[-1]
        batch = tf.shape(y_onehot)[0]
        # Mean and log-std are packed along the last axis:
        # [..., :channels] is the mean, [..., channels:] the log-std.
        params = tf.zeros([batch] + hps.top_shape[:2] + [2 * channels])
        if hps.learntop:
            params = Z.conv2d_zeros('p', params, 2 * channels)
        if hps.ycond:
            y_shift = Z.linear_zeros("y_emb", y_onehot, 2 * channels)
            # Broadcast the per-example shift over axes 1 and 2.
            params += tf.reshape(y_shift, [-1, 1, 1, 2 * channels])
        mean = params[:, :, :, :channels]
        logsd = params[:, :, :, channels:]
        pz = Z.gaussian_diag(mean, logsd)

    def logp(z1):
        """Log-density of ``z1`` under the prior."""
        return pz.logp(z1)

    def sample(eps=None, eps_std=None):
        """Draw from the prior.

        Precedence: an explicit ``eps`` wins (``eps_std`` is ignored), then
        ``eps_std`` rescales ``pz.eps`` per example, otherwise ``pz.sample``.
        """
        if eps is not None:
            return pz.sample2(eps)
        if eps_std is not None:
            scaled_eps = pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1])
            return pz.sample2(scaled_eps)
        return pz.sample

    def eps(z1):
        """Recover the noise that maps to ``z1`` under the prior."""
        return pz.get_eps(z1)

    return logp, sample, eps
def _f_loss(x, y, is_training, reuse=False):
    """Compute per-example generative and classification losses.

    NOTE(review): another ``_f_loss`` with the same signature exists in this
    file; depending on the enclosing scope one definition shadows the other —
    confirm which one is intended to be live.

    Args:
        x: input batch, indexed as rank 4 (``[:, 0, 0, 0]`` and dims 1..3 are
            used for the per-subpixel normalisation).
        y: class labels; presumably int32, since they are compared with an
            int32 ``argmax`` — confirm.
        is_training: not used in this body; kept for interface compatibility.
        reuse: forwarded to ``tf.variable_scope('model', reuse=reuse)``.

    Returns:
        ``(bits_x, bits_y, classification_error)``, one value per example.
        Side effect: sets ``hps.top_shape`` from the encoded latent's shape.
    """
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Discrete -> continuous: add uniform noise in [0, 1/n_bins) and charge
        # -log(n_bins) per dimension to the objective.
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += -np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        # Take only the first two outputs so this works whether ``encoder``
        # returns (z, objective) or (z, objective, extra).
        encoded = encoder(z, objective)
        z, objective = encoded[0], encoded[1]
        hps.top_shape = Z.int_shape(z)[1:]

        # Prior: ``prior`` returns a (logp, sample, eps) triple and only logp
        # is needed here (unpacking it into two names raised ValueError).
        logp = prior("prior", y_onehot, hps)[0]
        objective += logp(z)

        # Generative loss
        nobj = -objective
        bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
            x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

        # Predictive loss
        if hps.weight_y > 0 and hps.ycond:

            # Classification loss
            h_y = tf.reduce_mean(z, axis=[1, 2])
            y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
            bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=y_onehot, logits=y_logits) / np.log(2.)

            # Classification error: 1 where the argmax prediction is wrong.
            y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
            classification_error = 1 - \
                tf.cast(tf.equal(y_predicted, y), tf.float32)
        else:
            # No classifier: zero label loss, error reported as 1 everywhere.
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)

    return bits_x, bits_y, classification_error
def _f_loss(x, y, is_training, reuse=False):
    """Return per-example (bits_x, bits_y, classification_error).

    ``bits_x`` is the negative log-likelihood of ``x`` in bits per subpixel.
    ``bits_y`` is the label cross-entropy in bits when a classifier head is
    enabled (``hps.weight_y > 0`` and ``hps.ycond``), otherwise zeros.
    ``classification_error`` is 1 for mispredicted examples, or all ones when
    no classifier head is built.

    ``is_training`` is not read here. ``reuse`` is passed to the 'model'
    variable scope. Side effect: writes ``hps.top_shape``.
    """
    with tf.variable_scope('model', reuse=reuse):
        labels_1h = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Dequantise: uniform noise of one bin width, with the matching
        # -log(n_bins) per-dimension correction on the log-likelihood.
        log_lik = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        latent = preprocess(x)
        noise = tf.random_uniform(tf.shape(latent), 0, 1./hps.n_bins)
        latent = latent + noise
        log_lik += - np.log(hps.n_bins) * np.prod(Z.int_shape(latent)[1:])

        # Flow encoder (third output unused here).
        latent = Z.squeeze2d(latent, 2)  # > 16x16x12
        latent, log_lik, _ = encoder(latent, log_lik)

        # Score the top latent under the learned prior.
        hps.top_shape = Z.int_shape(latent)[1:]
        prior_logp, _, _ = prior("prior", labels_1h, hps)
        log_lik += prior_logp(latent)

        # Normalise the negative log-likelihood to bits per subpixel.
        height = int(x.get_shape()[1])
        width = int(x.get_shape()[2])
        depth = int(x.get_shape()[3])
        bits_x = -log_lik / (np.log(2.) * height * width * depth)

        if not (hps.weight_y > 0 and hps.ycond):
            # No classifier head.
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)
        else:
            pooled = tf.reduce_mean(latent, axis=[1, 2])
            logits = Z.linear_zeros("classifier", pooled, hps.n_y)
            bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels_1h, logits=logits) / np.log(2.)

            predicted = tf.argmax(logits, 1, output_type=tf.int32)
            hits = tf.cast(tf.equal(predicted, y), tf.float32)
            classification_error = 1 - hits

    return bits_x, bits_y, classification_error
def split3d_prior(y, shape, z_prior, level):
    """Build a diagonal-Gaussian prior for a split-off rank-5 latent.

    Args:
        y: optional conditioning input. When not None,
            ``Z.linear_zeros("y_emb", y, n_z)`` is added to the mean only;
            logsd is not conditioned on ``y``.
        shape: static latent shape. ``shape[0]``, ``shape[1:4]`` and
            ``shape[-1]`` (= n_z) are used, so the parameter tensor is rank 5
            (presumably [batch, d, h, w, c] — confirm). Lists and tuples are
            both accepted.
        z_prior: optional tensor. When not None, (mean, logsd) are replaced by
            ``Z.condFun(mean, logsd, z_prior, level)``.
        level: forwarded to ``Z.condFun``.

    Returns:
        ``Z.gaussian_diag(mean, logsd)``.
    """
    n_z = shape[-1]
    # Parameters start at zero. The first n_z channels hold the mean and the
    # remaining n_z hold the log-std. list() lets tuple shapes concatenate.
    h = tf.zeros([shape[0]] + list(shape[1:4]) + [2 * n_z])
    mean = h[:, :, :, :, :n_z]
    logsd = h[:, :, :, :, n_z:]
    if y is not None:
        y_shift = Z.linear_zeros("y_emb", y, n_z)
        # Broadcast the per-example shift over axes 1..3.
        mean += tf.reshape(y_shift, [-1, 1, 1, 1, n_z])
    if z_prior is not None:
        mean, logsd = Z.condFun(mean, logsd, z_prior, level)
    return Z.gaussian_diag(mean, logsd)