예제 #1
0
def _process_image(image_name,
                   product_image_name,
                   sess,
                   resize_width=192,
                   resize_height=256):
    """Load and preprocess one (person, product) pair and its coarse results.

    Reads the person image and product image from ``FLAGS.image_dir``, and the
    coarse-stage outputs (generated image, mask and TPS control points) from
    ``FLAGS.coarse_result_dir``. The product image is warped with a thin-plate
    spline driven by the saved control points, and a mask is derived from the
    warped result. Everything is evaluated in ``sess`` and returned.

    Args:
      image_name: file name of the person image, relative to
        ``FLAGS.image_dir``.
      product_image_name: file name of the product image, relative to
        ``FLAGS.image_dir``.
      sess: session used to evaluate the preprocessing graph.
      resize_width: target width for ``process_one_image`` and the TPS warp.
      resize_height: target height for ``process_one_image`` and the TPS warp.

    Returns:
      Tuple ``(image, prod_image, coarse_image, tps_image, mask_output,
      tps_mask)`` as returned by ``sess.run``.
    """

    def _coarse_path(subdir, suffix):
        # Coarse-stage outputs are stored as
        # <coarse_result_dir>/<subdir>/00015000_<image>_<product><suffix>.
        # NOTE(review): "00015000" looks like a hard-coded checkpoint step of
        # the coarse model — confirm it matches the run that produced the dir.
        return (FLAGS.coarse_result_dir + "/" + subdir + "/00015000_" +
                image_name + "_" + product_image_name + suffix)

    image = imageio.imread(FLAGS.image_dir + image_name)
    prod_image = imageio.imread(FLAGS.image_dir + product_image_name)
    coarse_image = imageio.imread(_coarse_path("images", ".png"))
    mask_output = imageio.imread(_coarse_path("images", "_mask.png"))

    image = process_one_image(image, resize_height, resize_width)
    prod_image = process_one_image(prod_image, resize_height, resize_width)
    coarse_image = process_one_image(coarse_image, resize_height, resize_width)
    # The trailing True selects process_one_image's alternate (mask) handling;
    # its exact meaning is defined by that helper.
    mask_output = process_one_image(mask_output, resize_height, resize_width,
                                    True)

    # TPS transform: warp the product image with control points saved by the
    # coarse stage (learning the control points directly did not converge).
    tps_control_points = sio.loadmat(_coarse_path("tps", "_tps.mat"))
    v = tps_control_points["control_points"]
    nx = v.shape[1]
    ny = v.shape[2]
    # Reshape to (1, 2, nx*ny), move the coordinate axis last -> (1, nx*ny, 2),
    # then map to [-1, 1] (assumes stored points are in [0, 1] — TODO confirm).
    v = np.transpose(np.reshape(v, [1, 2, nx * ny]), [0, 2, 1]) * 2 - 1
    p = tf.convert_to_tensor(v, dtype=tf.float32)

    # Use the requested size instead of a hard-coded 256x192 so that
    # non-default resize_height/resize_width actually work.
    img = tf.reshape(prod_image, [1, resize_height, resize_width, 3])
    tps_image = tps_stn(img, nx, ny, p, [resize_height, resize_width, 3])

    # 1.0 where the channel sum is below 3 * 0.95 (i.e. not near-white,
    # assuming pixel values in [0, 1] — TODO confirm), else 0.0.
    tps_mask = tf.cast(tf.less(tf.reduce_sum(tps_image, -1), 3 * 0.95),
                       tf.float32)

    [image, prod_image, coarse_image, tps_image, mask_output,
     tps_mask] = sess.run(
         [image, prod_image, coarse_image, tps_image, mask_output, tps_mask])

    return image, prod_image, coarse_image, tps_image, mask_output, tps_mask
예제 #2
0
def create_model(product_image, body_seg, skin_seg, pose_map, prod_seg, image,
                 tps_points):
    """Build the refinement model, its losses and its training op.

    The graph has three parts:
      1. ``generator``: a coarse generator fed the product image (resized to
         256x192), body/skin segmentations and pose map. Only the channels
         after the first ``prod_seg`` channels are kept as the generated
         image, which is then resized to (FINAL_HEIGHT, FINAL_WIDTH).
      2. ``stn_generator``: the product foreground and product image are
         concatenated and warped by a TPS transform on a 10x10 grid using
         ``tps_points``. The first warped channel is the product mask; the
         rest is the warped product image.
      3. ``refine_generator``: predicts a selection mask (restricted to
         ``prod_seg``) that blends the warped product image with the
         generated image.

    The generator loss is
      content_l1_weight * L1(image, output)
      + perceptual_weight * VGG-19 perceptual loss (conv3_2..conv5_2)
      - mask_weight * mean(select_mask)
      + tv_weight * total_variation(select_mask).
    Only variables under the ``refine_generator`` scope are optimised.

    Args:
      product_image: product image tensor.
      body_seg: body segmentation tensor.
      skin_seg: skin segmentation tensor.
      pose_map: pose map tensor.
      prod_seg: product segmentation; its last dimension gives the number of
        leading generator channels that are discarded.
      image: ground-truth target image.
      tps_points: TPS control points for ``tps_stn``.

    Returns:
      A ``Model`` with the losses, intermediate outputs, the grouped training
      op (which also increments the global step) and the global step.
    """

    with tf.variable_scope("generator"):
        # Downsample the product image and feed it to the coarse model.
        downsample_prod_image = tf.image.resize_images(
            product_image,
            size=[256, 192],
            method=tf.image.ResizeMethod.BILINEAR)
        out_channels = int(prod_seg.get_shape()[-1] + image.get_shape()[-1])
        gen_image_outputs = create_generator(downsample_prod_image, body_seg,
                                             skin_seg, pose_map, out_channels)
        # Drop the segmentation channels; keep only the image channels.
        gen_image_outputs = gen_image_outputs[:, :, :,
                                              prod_seg.get_shape()[-1]:]

        gen_image_outputs = tf.image.resize_area(gen_image_outputs,
                                                 (FINAL_HEIGHT, FINAL_WIDTH),
                                                 align_corners=False)
    with tf.variable_scope("stn_generator"):
        prod_image_fg = extract_product_fg(product_image)

        # Warp the foreground and the image together so they stay aligned.
        stn_outputs = tps_stn(
            tf.concat([prod_image_fg, product_image], axis=-1), 10, 10,
            tps_points,
            product_image.get_shape()[1:3])

        # NOTE(review): assumes extract_product_fg returns a single channel.
        prod_mask_outputs = stn_outputs[:, :, :, :1]
        stn_image_outputs = stn_outputs[:, :, :, 1:]
    with tf.variable_scope("refine_generator"):
        select_mask = create_refine_generator(stn_image_outputs,
                                              gen_image_outputs)
        # Only look at the prod_seg region in select_mask.
        select_mask = prod_seg * select_mask
        image_outputs = select_mask * stn_image_outputs + (
            1 - select_mask) * gen_image_outputs
    with tf.name_scope("generator_loss"):
        gen_loss_content_L1 = tf.reduce_mean(tf.abs(image - image_outputs))
        with tf.variable_scope("vgg_19"):
            vgg_real = build_vgg19(image, FLAGS.vgg_model_path)
            vgg_fake = build_vgg19(image_outputs,
                                   FLAGS.vgg_model_path,
                                   reuse=True)
            # Per-layer normalisation constants are empirical. Only
            # conv3_2..conv5_2 contribute; the conv1_2/conv2_2 terms that
            # were previously computed were never used in the loss.
            p3 = compute_error(vgg_real['conv3_2'],
                               vgg_fake['conv3_2']) / 1.35 / 2.3
            p4 = compute_error(vgg_real['conv4_2'],
                               vgg_fake['conv4_2']) / 0.67 / 8.2
            p5 = compute_error(vgg_real['conv5_2'],
                               vgg_fake['conv5_2']) / 0.16
            # Average of the three terms; the extra / 128 was intended by the
            # original authors to normalise the loss to roughly [0, 1].
            perceptual_loss = (p3 + p4 + p5) / 3.0 / 128.0

        mask_loss = tf.reduce_mean(select_mask)
        tv_loss = tf.reduce_mean(tf.image.total_variation(select_mask))

        # mask_loss is subtracted: with a positive mask_weight this rewards
        # relying more on the warped product image.
        gen_loss = (
            FLAGS.content_l1_weight * gen_loss_content_L1 +
            FLAGS.perceptual_weight * perceptual_loss -
            FLAGS.mask_weight * mask_loss +
            FLAGS.tv_weight * tv_loss  # TV loss
        )

    with tf.name_scope("generator_train"):
        # Train only the refinement network; generator and STN are frozen.
        gen_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("refine_generator")
        ]
        gen_optim = tf.train.AdamOptimizer(FLAGS.lr, FLAGS.beta1)
        gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)

    global_step = tf.contrib.framework.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step + 1)

    # gen_loss_GAN is a historical field name; it holds the total generator
    # loss (there is no adversarial term here).
    return Model(gen_loss_GAN=gen_loss,
                 gen_loss_content_L1=gen_loss_content_L1,
                 perceptual_loss=perceptual_loss,
                 mask_loss=mask_loss,
                 tv_loss=tv_loss,
                 gen_image_outputs=gen_image_outputs,
                 stn_image_outputs=stn_image_outputs,
                 select_mask=select_mask,
                 image_outputs=image_outputs,
                 prod_mask_outputs=prod_mask_outputs,
                 train=tf.group(incr_global_step, gen_train),
                 global_step=global_step)