def dcgan_discrim(x_hat_batch, pilot, hparams):
    """Build the pilot-conditioned DCGAN discriminator graph and its restore info.

    The estimates are reshaped to ``(-1, 64, 16, 2)`` and added to a zero
    tensor of batch 64. This broadcasts a single estimate up to 64 rows and
    leaves a full batch of 64 unchanged. The pilot is then attached with
    ``conv_cond_concat``, and the result is scored by the discriminator.
    Only the first ``hparams.batch_size`` probabilities are kept.

    NOTE(review): the fixed batch of 64 presumably matches what the
    pretrained model definition expects; confirm against the model def.

    NOTE(review): this module also defines a two-argument ``dcgan_discrim``.
    Whichever ``def`` executes last replaces the other at module level, so
    confirm which one callers expect.

    Args:
        x_hat_batch: Tensor of estimates. Its element count must be divisible
            by 64*16*2 so it can be reshaped to ``(-1, 64, 16, 2)``.
            Presumably a 64x16 grid with 2 channels (TODO confirm layout).
        pilot: Conditioning tensor. It is reshaped to
            ``(hparams.batch_size, 1, 1, hparams.pilot_dim)``.
        hparams: Object providing ``batch_size`` (must be 1 or 64),
            ``pilot_dim`` and ``pretrained_model_dir``.

    Returns:
        A tuple ``(prob, restore_dict, restore_path)``:

        * ``prob``: the flattened discriminator output, truncated to the
          first ``hparams.batch_size`` entries.
        * ``restore_dict``: maps ``var.op.name`` to the variable, for every
          global variable whose op name is listed by
          ``celebA_dcgan_model_def.gen_restore_vars()``.
          NOTE(review): this is the *generator* restore list used inside a
          discriminator builder; verify that it covers the discriminator's
          variables.
        * ``restore_path``: the result of
          ``tf.train.latest_checkpoint(hparams.pretrained_model_dir)``,
          which is None when no checkpoint is found.

    NOTE(review): unlike the image variant, this builder does not call
    ``tf.get_variable_scope().reuse_variables()`` before returning. Confirm
    whether callers rely on that.
    """
    assert hparams.batch_size in (1, 64), 'batch size should be either 64 or 1'

    padded_batch = 64
    # Op creation order below matches the original exactly: TF auto-names
    # ops (Reshape, Reshape_1, ...) in the order they are created.
    estimate = tf.reshape(x_hat_batch, [-1, 64, 16, 2])
    padded = tf.zeros([padded_batch, 64, 16, 2]) + estimate

    # Conditioning input. conv_cond_concat is defined elsewhere; it
    # presumably appends the pilot along the channel axis (TODO confirm).
    pilot_map = tf.reshape(
        pilot, [hparams.batch_size, 1, 1, hparams.pilot_dim])
    padded = conv_cond_concat(padded, pilot_map)

    scores, _ = celebA_dcgan_model_def.discriminator(
        celebA_dcgan_model_def.Hparams(), padded, train=False, reuse=False)
    # Drop the broadcast padding rows and keep only the real batch.
    prob = tf.reshape(scores, [-1])[:hparams.batch_size]

    # Build the restore dict only after the discriminator exists, so its
    # variables are visible to tf.global_variables().
    wanted = celebA_dcgan_model_def.gen_restore_vars()
    restore_dict = {}
    for variable in tf.global_variables():
        op_name = variable.op.name
        if op_name in wanted:
            restore_dict[op_name] = variable

    checkpoint = tf.train.latest_checkpoint(hparams.pretrained_model_dir)
    return prob, restore_dict, checkpoint
def dcgan_discrim(x_hat_batch, hparams):
    """Build the DCGAN discriminator graph for a batch of estimates.

    The estimates are reshaped to ``(-1, 64, 64, 3)`` and added to a zero
    tensor of batch 64. This broadcasts a single estimate up to 64 rows and
    leaves a full batch of 64 unchanged. The result is scored by the
    discriminator, and only the first ``hparams.batch_size`` probabilities
    are kept.

    NOTE(review): the inputs are presumably 64x64 3-channel images, and the
    fixed batch of 64 presumably matches the pretrained model definition.
    Confirm both against the caller and the model def.

    Args:
        x_hat_batch: Tensor of estimates. Its element count must be divisible
            by 64*64*3 so it can be reshaped to ``(-1, 64, 64, 3)``.
        hparams: Object providing ``batch_size`` (must be 1 or 64) and
            ``pretrained_model_dir``.

    Returns:
        A tuple ``(prob, restore_dict, restore_path)``:

        * ``prob``: the flattened discriminator output, truncated to the
          first ``hparams.batch_size`` entries.
        * ``restore_dict``: maps ``var.op.name`` to the variable, for every
          global variable whose op name is listed by
          ``celebA_dcgan_model_def.gen_restore_vars()``.
          NOTE(review): this is the *generator* restore list used inside a
          discriminator builder; verify that it covers the discriminator's
          variables.
        * ``restore_path``: the result of
          ``tf.train.latest_checkpoint(hparams.pretrained_model_dir)``,
          which is None when no checkpoint is found.

    Side effect:
        Before returning, the current variable scope is switched to reuse
        mode via ``tf.get_variable_scope().reuse_variables()``.
    """
    assert hparams.batch_size in (1, 64), 'batch size should be either 64 or 1'

    padded_batch = 64
    # Op creation order below matches the original exactly: TF auto-names
    # ops in the order they are created.
    estimate = tf.reshape(x_hat_batch, [-1, 64, 64, 3])
    padded = tf.zeros([padded_batch, 64, 64, 3]) + estimate

    scores, _ = celebA_dcgan_model_def.discriminator(
        celebA_dcgan_model_def.Hparams(), padded, train=False, reuse=False)
    # Drop the broadcast padding rows and keep only the real batch.
    prob = tf.reshape(scores, [-1])[:hparams.batch_size]

    # Build the restore dict only after the discriminator exists, so its
    # variables are visible to tf.global_variables().
    wanted = celebA_dcgan_model_def.gen_restore_vars()
    restore_dict = {}
    for variable in tf.global_variables():
        op_name = variable.op.name
        if op_name in wanted:
            restore_dict[op_name] = variable

    checkpoint = tf.train.latest_checkpoint(hparams.pretrained_model_dir)

    # Switch the enclosing scope to reuse mode. This is presumably so that a
    # later build in the same scope shares these weights (TODO confirm).
    tf.get_variable_scope().reuse_variables()
    return prob, restore_dict, checkpoint