import numpy as np
import tensorflow as tf

import tfops as Z


def split2d(name, z, objective=0.):
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        # Split channels in half: z1 continues through the flow,
        # z2 is factored out under a learned prior.
        z1 = z[:, :, :, :n_z // 2]
        z2 = z[:, :, :, n_z // 2:]
        pz = split2d_prior(z1)
        objective += pz.logp(z2)
        z1 = Z.squeeze2d(z1)
        return z1, objective
# Variant of split2d that additionally returns the standardized noise
# eps = (z2 - mean) / std, so the factored-out latents can be recovered
# and reused.
def split2d(name, z, objective=0.):
    with tf.variable_scope(name):
        n_z = Z.int_shape(z)[3]
        z1 = z[:, :, :, :n_z // 2]
        z2 = z[:, :, :, n_z // 2:]
        pz = split2d_prior(z1)
        objective += pz.logp(z2)
        z1 = Z.squeeze2d(z1)
        eps = pz.get_eps(z2)
        return z1, objective, eps
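# `split2d_prior` is not shown in this excerpt. In Glow's reference
# implementation it is a zero-initialized convolution of z1 that
# parameterizes a diagonal Gaussian over z2; the sketch below follows
# that convention and is an assumption, not part of this file.
def split2d_prior(z):
    n_z2 = int(z.get_shape()[3])
    # Predict a per-pixel mean and log-scale for the factored-out half.
    h = Z.conv2d_zeros("conv", z, 2 * n_z2)
    mean = h[:, :, :, 0::2]
    logs = h[:, :, :, 1::2]
    # Returns an object exposing logp(x), sample, and get_eps(x),
    # matching the pz.logp / pz.get_eps calls above.
    return Z.gaussian_diag(mean, logs)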
def _f_loss(x, y, is_training, reuse=False):
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        z, objective = encoder(z, objective)

        # Prior
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _ = prior("prior", y_onehot, hps)
        objective += logp(z)

        # Generative loss
        nobj = - objective
        bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
            x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

        # Predictive loss
        if hps.weight_y > 0 and hps.ycond:
            # Classification loss
            h_y = tf.reduce_mean(z, axis=[1, 2])
            y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
            bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=y_onehot, logits=y_logits) / np.log(2.)

            # Classification accuracy
            y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
            classification_error = 1 - tf.cast(
                tf.equal(y_predicted, y), tf.float32)
        else:
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)

    return bits_x, bits_y, classification_error
def _f_loss(x, y, is_training, reuse=False):
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        z, objective, _ = encoder(z, objective)

        # Prior
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, _ = prior("prior", y_onehot, hps)
        objective += logp(z)

        # Generative loss
        nobj = - objective
        bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
            x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

        # Predictive loss
        if hps.weight_y > 0 and hps.ycond:
            # Classification loss
            h_y = tf.reduce_mean(z, axis=[1, 2])
            y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
            bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=y_onehot, logits=y_logits) / np.log(2.)

            # Classification accuracy
            y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
            classification_error = 1 - tf.cast(
                tf.equal(y_predicted, y), tf.float32)
        else:
            bits_y = tf.zeros_like(bits_x)
            classification_error = tf.ones_like(bits_x)

    return bits_x, bits_y, classification_error
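# A plausible wrapper showing how the terms returned by _f_loss are
# usually reduced to a scalar training loss, following Glow's convention
# of weighting the classification bits by hps.weight_y. This wrapper is
# illustrative and not part of the excerpt.
def f_loss(x, y, is_training, reuse=False):
    bits_x, bits_y, _ = _f_loss(x, y, is_training, reuse)
    # Total bits per example: generative bits plus weighted class bits.
    local_loss = bits_x + hps.weight_y * bits_y
    return tf.reduce_mean(local_loss)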
def f_encode(x, y, reuse=True):
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

        # Encode
        z = Z.squeeze2d(z, 2)  # > 16x16x12
        z, objective, eps = encoder(z, objective)

        # Prior
        hps.top_shape = Z.int_shape(z)[1:]
        logp, _, _eps = prior("prior", y_onehot, hps)
        objective += logp(z)

        # The returned list holds the standardized noise from every
        # split2d layer plus the top-level latent standardized under
        # the learned prior.
        eps.append(_eps(z))
    return eps
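# The inverse of f_encode is not shown here. A minimal sketch, assuming
# decoder / postprocess mirror the decoding path used in the 'B_all'
# branch of the joint loss below: sample the top latent from eps[-1],
# run the flow in reverse consuming the split-level eps, unsqueeze, and
# map back to pixels.
def f_decode(y, eps, reuse=True):
    with tf.variable_scope('model', reuse=reuse):
        y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
        _, sample, _ = prior("prior", y_onehot, hps)
        z = sample(eps=eps[-1])   # top-level latent from standardized noise
        z = decoder(z, eps[:-1])  # invert the flow, consuming split eps
        z = Z.unsqueeze2d(z, 2)   # undo the initial squeeze2d(z, 2)
        x = postprocess(z)        # continuous -> discrete pixel values
    return x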
def _f_loss(x_A, y_A, x_B, y_B, is_training, reuse=False, init=False):

    with tf.variable_scope('model_A', reuse=reuse):
        y_onehot_A = tf.cast(tf.one_hot(y_A, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective_A = tf.zeros_like(x_A, dtype='float32')[:, 0, 0, 0]
        z_A = preprocess(x_A)
        z_A = z_A + tf.random_uniform(tf.shape(z_A), 0, 1. / hps.n_bins)
        objective_A += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_A)[1:])

        # Encode
        z_A = Z.squeeze2d(z_A, 2)  # > 16x16x12
        z_A, objective_A, eps_A = encoder_A(z_A, objective_A)

        # Prior
        hps.top_shape = Z.int_shape(z_A)[1:]
        logp_A, _, _eps_A = prior("prior", y_onehot_A, hps)
        objective_A += logp_A(z_A)

        # The top-layer prior is learned, so z must be standardized
        # into eps as well.
        z_A = _eps_A(z_A)
        eps_A.append(z_A)

        # Flatten the latent code so it can be matched against the
        # other model's.
        eps_flatten_A = tf.concat(
            [tf.contrib.layers.flatten(e) for e in eps_A], axis=-1)

    with tf.variable_scope('model_B', reuse=reuse):
        y_onehot_B = tf.cast(tf.one_hot(y_B, hps.n_y, 1, 0), 'float32')

        # Discrete -> Continuous
        objective_B = tf.zeros_like(x_B, dtype='float32')[:, 0, 0, 0]
        z_B = preprocess(x_B)
        z_B = z_B + tf.random_uniform(tf.shape(z_B), 0, 1. / hps.n_bins)
        objective_B += - np.log(hps.n_bins) * np.prod(Z.int_shape(z_B)[1:])

        # Encode
        z_B = Z.squeeze2d(z_B, 2)  # > 16x16x12
        z_B, objective_B, eps_B = encoder_B(z_B, objective_B)

        # Prior
        hps.top_shape = Z.int_shape(z_B)[1:]
        logp_B, _, _eps_B = prior("prior", y_onehot_B, hps)
        objective_B += logp_B(z_B)

        # The top-layer prior is learned, so z must be standardized
        # into eps as well.
        z_B = _eps_B(z_B)
        eps_B.append(z_B)

        # Flatten the latent code so it can be matched against the
        # other model's.
        eps_flatten_B = tf.concat(
            [tf.contrib.layers.flatten(e) for e in eps_B], axis=-1)

    code_loss = 0.0
    code_shapes = [[16, 16, 6], [8, 8, 12], [4, 4, 48]]

    if hps.code_loss_type == 'B_all':
        if not init:
            # Decode model A's code through model B and compute a
            # reconstruction loss (L1 or L2) at the pixel level.
            def unflatten_code(fcode, code_shapes):
                index = 0
                code = []
                bs = tf.shape(fcode)[0]  # bs = hps.local_batch_train
                for shape in code_shapes:
                    code.append(tf.reshape(
                        fcode[:, index:index + np.prod(shape)],
                        tf.convert_to_tensor([bs] + shape)))
                    index += np.prod(shape)
                return code

            code_others = unflatten_code(eps_flatten_A, code_shapes)
            # code_others[-1] is z, and code_others[:-1] is eps
            with tf.variable_scope('model_B', reuse=True):
                _, sample, _ = prior("prior", y_onehot_B, hps)
                code_last_others = sample(eps=code_others[-1])
                code_decoded_others = decoder_B(
                    code_last_others, code_others[:-1])
                code_decoded = Z.unsqueeze2d(code_decoded_others, 2)
                x_B_recon = postprocess(code_decoded)
                x_B_scaled = 1 / 255.0 * tf.cast(x_B, tf.float32)
                x_B_recon_scaled = 1 / 255.0 * tf.cast(x_B_recon, tf.float32)

                if hps.code_loss_fn == 'l1':
                    code_loss = tf.reduce_mean(tf.losses.absolute_difference(
                        x_B_scaled, x_B_recon_scaled))
                elif hps.code_loss_fn == 'l2':
                    code_loss = tf.reduce_mean(tf.squared_difference(
                        x_B_scaled, x_B_recon_scaled))
                else:
                    raise NotImplementedError()
    elif hps.code_loss_type == 'code_all':
        code_loss = tf.reduce_mean(
            tf.squared_difference(eps_flatten_A, eps_flatten_B))
    elif hps.code_loss_type == 'code_last':
        dim = np.prod(code_shapes[-1])
        code_loss = tf.reduce_mean(tf.squared_difference(
            eps_flatten_A[:, -dim:], eps_flatten_B[:, -dim:]))
    else:
        raise NotImplementedError()

    with tf.variable_scope('model_A', reuse=True):
        # Generative loss
        nobj_A = - objective_A
        bits_x_A = nobj_A / (np.log(2.) * int(x_A.get_shape()[1]) * int(
            x_A.get_shape()[2]) * int(x_A.get_shape()[3]))  # bits per subpixel
        bits_y_A = tf.zeros_like(bits_x_A)
        classification_error_A = tf.ones_like(bits_x_A)

    with tf.variable_scope('model_B', reuse=True):
        # Generative loss
        nobj_B = - objective_B
        bits_x_B = nobj_B / (np.log(2.) * int(x_B.get_shape()[1]) * int(
            x_B.get_shape()[2]) * int(x_B.get_shape()[3]))  # bits per subpixel
        bits_y_B = tf.zeros_like(bits_x_B)
        classification_error_B = tf.ones_like(bits_x_B)

    return (bits_x_A, bits_y_A, classification_error_A, eps_flatten_A,
            bits_x_B, bits_y_B, classification_error_B, eps_flatten_B,
            code_loss)
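# Hypothetical usage sketch: combine both models' generative losses with
# the code-matching loss. `hps.code_loss_scale` (the weighting) and the
# input tensors x_A, y_A, x_B, y_B are assumptions introduced for
# illustration; they are not defined in the excerpt above. Note bits_y_*
# are constant zeros in this joint variant, so they are omitted here.
(bits_x_A, bits_y_A, _, eps_flatten_A,
 bits_x_B, bits_y_B, _, eps_flatten_B,
 code_loss) = _f_loss(x_A, y_A, x_B, y_B, is_training=True)

loss_A = tf.reduce_mean(bits_x_A)
loss_B = tf.reduce_mean(bits_x_B)
total_loss = loss_A + loss_B + hps.code_loss_scale * code_loss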