def _process_image(image_name, product_image_name, sess,
                   resize_width=192, resize_height=256):
  image_id = image_name[:-4]  # image id without the extension (currently unused)
  image = imageio.imread(FLAGS.image_dir + image_name)
  prod_image = imageio.imread(FLAGS.image_dir + product_image_name)
  # NOTE: the checkpoint step (00015000) is hard-coded into the coarse result paths.
  coarse_image = imageio.imread(FLAGS.coarse_result_dir + "/images/00015000_" +
                                image_name + "_" + product_image_name + ".png")
  mask_output = imageio.imread(FLAGS.coarse_result_dir + "/images/00015000_" +
                               image_name + "_" + product_image_name +
                               "_mask.png")

  image = process_one_image(image, resize_height, resize_width)
  prod_image = process_one_image(prod_image, resize_height, resize_width)
  coarse_image = process_one_image(coarse_image, resize_height, resize_width)
  mask_output = process_one_image(mask_output, resize_height, resize_width, True)

  # TPS transform.
  # Here we use precomputed control points to generate the warp.
  # We tried to learn the control points directly, but the network refused to
  # converge.
  tps_control_points = sio.loadmat(FLAGS.coarse_result_dir + "/tps/00015000_" +
                                   image_name + "_" + product_image_name +
                                   "_tps.mat")
  v = tps_control_points["control_points"]
  nx = v.shape[1]
  ny = v.shape[2]
  # Flatten the (2, nx, ny) grid to (1, nx * ny, 2) and rescale the
  # coordinates from [0, 1] to the [-1, 1] frame expected by tps_stn.
  v = np.reshape(v, -1)
  v = np.transpose(v.reshape([1, 2, nx * ny]), [0, 2, 1]) * 2 - 1
  p = tf.convert_to_tensor(v, dtype=tf.float32)
  img = tf.reshape(prod_image, [1, 256, 192, 3])
  tps_image = tps_stn(img, nx, ny, p, [256, 192, 3])
  # Pixels whose channel sum falls below the near-white threshold are treated
  # as product foreground.
  tps_mask = tf.cast(tf.less(tf.reduce_sum(tps_image, -1), 3 * 0.95), tf.float32)

  [image, prod_image, coarse_image, tps_image, mask_output, tps_mask] = sess.run(
      [image, prod_image, coarse_image, tps_image, mask_output, tps_mask])

  return image, prod_image, coarse_image, tps_image, mask_output, tps_mask
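
# A minimal, self-contained sketch (not part of the original file) of the
# control-point reshaping performed in _process_image above: a (2, nx, ny)
# grid of coordinates stored in [0, 1] is flattened to (1, nx * ny, 2) and
# rescaled into the [-1, 1] frame that the spatial transformer expects. The
# random grid is a hypothetical stand-in for the contents of the .mat file.
def _demo_control_point_normalization():
  import numpy as np
  nx, ny = 10, 10
  grid = np.random.rand(2, nx, ny)  # hypothetical stand-in for "control_points"
  pts = np.transpose(grid.reshape([1, 2, nx * ny]), [0, 2, 1]) * 2 - 1
  assert pts.shape == (1, nx * ny, 2)
  assert pts.min() >= -1.0 and pts.max() <= 1.0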
def create_model(product_image, body_seg, skin_seg, pose_map, prod_seg, image,
                 tps_points):
  """Build the refinement model.

  Given the product image, body/skin segmentation, and pose map, compose the
  coarse generator output with the TPS-warped product image through a learned
  selection mask.
  """
  with tf.variable_scope("generator") as scope:
    # Downsample the product image and feed it into the coarse model.
    downsample_prod_image = tf.image.resize_images(
        product_image,
        size=[256, 192],
        method=tf.image.ResizeMethod.BILINEAR)
    out_channels = int(prod_seg.get_shape()[-1] + image.get_shape()[-1])
    gen_image_outputs = create_generator(downsample_prod_image, body_seg,
                                         skin_seg, pose_map, out_channels)
    # Keep only the image channels; the leading prod_seg channels are dropped.
    gen_image_outputs = gen_image_outputs[:, :, :, prod_seg.get_shape()[-1]:]
    gen_image_outputs = tf.image.resize_area(gen_image_outputs,
                                             (FINAL_HEIGHT, FINAL_WIDTH),
                                             align_corners=False)

  with tf.variable_scope("stn_generator") as scope:
    prod_image_fg = extract_product_fg(product_image)
    # Warp the product foreground mask and the product image with a single TPS
    # transform (10x10 control points) so the two stay aligned.
    stn_outputs = tps_stn(
        tf.concat([prod_image_fg, product_image], axis=-1), 10, 10,
        tps_points, product_image.get_shape()[1:3])
    prod_mask_outputs = stn_outputs[:, :, :, :1]
    stn_image_outputs = stn_outputs[:, :, :, 1:]

  with tf.variable_scope("refine_generator") as scope:
    select_mask = create_refine_generator(stn_image_outputs, gen_image_outputs)
    # Only look at the prod_seg region in select_mask.
    select_mask = prod_seg * select_mask
    # Per-pixel alpha composition of the warped product and the coarse output.
    image_outputs = (select_mask * stn_image_outputs +
                     (1 - select_mask) * gen_image_outputs)

  with tf.name_scope("generator_loss"):
    gen_loss_content_L1 = tf.reduce_mean(tf.abs(image - image_outputs))
    with tf.variable_scope("vgg_19"):
      vgg_real = build_vgg19(image, FLAGS.vgg_model_path)
      vgg_fake = build_vgg19(image_outputs, FLAGS.vgg_model_path, reuse=True)
      # The per-layer scale factors roughly balance the magnitudes of the VGG
      # feature maps. p1 and p2 are computed but not used in the final loss.
      p1 = compute_error(vgg_real['conv1_2'],
                         vgg_fake['conv1_2']) / 5.3 * 2.5   # 128x128x64
      p2 = compute_error(vgg_real['conv2_2'],
                         vgg_fake['conv2_2']) / 2.7 / 1.2   # 64x64x128
      p3 = compute_error(vgg_real['conv3_2'],
                         vgg_fake['conv3_2']) / 1.35 / 2.3  # 32x32x256
      p4 = compute_error(vgg_real['conv4_2'],
                         vgg_fake['conv4_2']) / 0.67 / 8.2  # 16x16x512
      p5 = compute_error(vgg_real['conv5_2'],
                         vgg_fake['conv5_2']) / 0.16        # 8x8x512
      # Divide by 128 to normalize the loss to roughly [0, 1].
      perceptual_loss = (p3 + p4 + p5) / 3.0 / 128.0

    # mask_loss enters gen_loss with a negative sign, so minimizing gen_loss
    # encourages a larger select_mask, i.e., using more of the warped product.
    mask_loss = tf.reduce_mean(select_mask)
    # Total variation regularizer keeps the selection mask smooth.
    tv_loss = tf.reduce_mean(tf.image.total_variation(select_mask))

    gen_loss = (FLAGS.content_l1_weight * gen_loss_content_L1 +
                FLAGS.perceptual_weight * perceptual_loss -
                FLAGS.mask_weight * mask_loss +
                FLAGS.tv_weight * tv_loss)

  with tf.name_scope("generator_train"):
    # Only the refinement network is trained here; the coarse generator and
    # STN variables are left untouched.
    gen_tvars = [
        var for var in tf.trainable_variables()
        if var.name.startswith("refine_generator")
    ]
    gen_optim = tf.train.AdamOptimizer(FLAGS.lr, FLAGS.beta1)
    gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)

  global_step = tf.contrib.framework.get_or_create_global_step()
  incr_global_step = tf.assign(global_step, global_step + 1)

  return Model(gen_loss_GAN=gen_loss,
               gen_loss_content_L1=gen_loss_content_L1,
               perceptual_loss=perceptual_loss,
               mask_loss=mask_loss,
               tv_loss=tv_loss,
               gen_image_outputs=gen_image_outputs,
               stn_image_outputs=stn_image_outputs,
               select_mask=select_mask,
               image_outputs=image_outputs,
               prod_mask_outputs=prod_mask_outputs,
               train=tf.group(incr_global_step, gen_train),
               global_step=global_step)
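
# A toy numpy illustration (a sketch, not part of the original code) of the
# composition create_model performs: select_mask blends, per pixel, between
# the TPS-warped product image and the coarse generator output, matching the
# image_outputs formula above. All array names here are hypothetical stand-ins.
def _demo_mask_composition():
  import numpy as np
  h, w = 4, 3
  stn = np.ones((1, h, w, 3))        # stand-in for stn_image_outputs
  gen = np.zeros((1, h, w, 3))       # stand-in for gen_image_outputs
  mask = np.random.rand(1, h, w, 1)  # stand-in for select_mask, values in [0, 1]
  out = mask * stn + (1 - mask) * gen
  # With all-ones vs. all-zeros inputs, the blend reduces to the mask itself.
  assert np.allclose(out, np.broadcast_to(mask, out.shape))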