Code example #1
0
def sampling_model(sess, model, gen_model, data_set, step, seq_len, subset_str=''):
    """Sample one random photo/sketch pair and dump visualizations to disk.

    Picks a random photo from the data set, samples a sketch for it with
    `model`, renders the generated and ground-truth sketches as SVGs, runs
    `gen_model` to synthesize a photo back from the sketch, and writes the
    generated and ground-truth photos as PNGs under FLAGS.img_dir.
    """
    sketch_size, photo_size = data_set.sketch_size, data_set.image_size

    # Choose a random photo and look up its paired ground-truth sketch.
    photo_id = np.random.randint(0, photo_size)
    gt_sketch_id = data_set.get_corr_sketch_id(photo_id)
    gt_strokes = data_set.sketch_strokes[gt_sketch_id]

    # Sample a sketch conditioned on the photo features and save it as SVG.
    photo_feat, enc_seq_len = data_set.get_input_image(photo_id)
    sampled_strokes, m = sample(sess, model, photo_feat, seq_len=seq_len, rnn_enc_seq_len=enc_seq_len)
    gen_strokes = utils.to_normal_strokes(sampled_strokes)
    gen_sketch_path = os.path.join(FLAGS.img_dir, '%s/%s/gensketch_for_photo%d_step%d.svg' % (data_set.dataset, subset_str, photo_id, step))
    utils.draw_strokes(gen_strokes, svg_filename=gen_sketch_path)

    # Save the ground-truth sketch next to it for visual comparison.
    gt_sketch_path = os.path.join(FLAGS.img_dir, '%s/%s/gt_sketch%d_for_photo%d.svg' % (data_set.dataset, subset_str, gt_sketch_id, photo_id))
    utils.draw_strokes(gt_strokes, svg_filename=gt_sketch_path)

    # Run the generator on the padded sketch to synthesize a photo back.
    padded_sketch = data_set.pad_single_sketch(photo_id)
    feed = {gen_model.input_sketch: padded_sketch, gen_model.input_photo: photo_feat, gen_model.sequence_lengths: [seq_len]}
    generated_photo = sess.run(gen_model.gen_photo, feed)
    gen_photo_path = os.path.join(FLAGS.img_dir, '%s/%s/gen_photo%d_step%d.png' % (data_set.dataset, subset_str, photo_id, step))
    cv2.imwrite(gen_photo_path, cv2.cvtColor(generated_photo[0, ::].astype(np.uint8), cv2.COLOR_RGB2BGR))

    # Save the ground-truth photo; a 2-D (grayscale) image needs no
    # RGB->BGR conversion before cv2.imwrite.
    gt_photo_path = os.path.join(FLAGS.img_dir, '%s/%s/gt_photo%d.png' % (data_set.dataset, subset_str, photo_id))
    if len(photo_feat[0].shape) == 2:
        cv2.imwrite(gt_photo_path, photo_feat[0])
    else:
        cv2.imwrite(gt_photo_path, cv2.cvtColor(photo_feat[0].astype(np.uint8), cv2.COLOR_RGB2BGR))
Code example #2
0
def sampling_model_eval(sess, model, gen_model, data_set, seq_len):
    """Sample a sketch for every photo in the test set and save SVG/PNG files.

    First creates the output folder hierarchy under
    FLAGS.img_dir/<dataset>/, then samples one sketch per photo with
    `model` and writes it both as SVG and as PNG.
    """
    photo_size = data_set.image_size

    folders_to_create = ['gen_test', 'gen_test_png', 'gt_test', 'gt_test_png', 'gt_test_photo', 'gt_test_sketch_image',
                         'gen_test_s', 'gen_test_s_png', 'gen_test_inter', 'gen_test_inter_png', 'gen_test_inter_sep',
                         'gen_test_inter_sep_png', 'gen_photo', 'gen_test_inter_with_photo', 'recon_test',
                         'recon_test_png', 'recon_photo']
    for folder_to_create in folders_to_create:
        folder_path = os.path.join(FLAGS.img_dir, '%s/%s' % (data_set.dataset, folder_to_create))
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
        # creates missing parent directories (img_dir/<dataset>).
        os.makedirs(folder_path, exist_ok=True)

    for image_index in range(photo_size):

        # One-line console progress indicator (clears the line, then rewrites).
        sys.stdout.write('\x1b[2K\r>> Sampling test set, [%d/%d]' % (image_index + 1, photo_size))
        sys.stdout.flush()

        image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index)
        sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len)
        strokes = utils.to_normal_strokes(sample_strokes)
        svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test/gen_sketch%d.svg' % (data_set.dataset, image_index))
        png_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test_png/gen_sketch%d.png' % (data_set.dataset, image_index))
        utils.sv_svg_png_from_strokes(strokes, svg_filename=svg_gen_sketch, png_filename=png_gen_sketch)

    print("\nSampling finished")
Code example #3
0
def decode(session, sample_model, max_seq_len, z_input=None, temperature=0.1):
    """Sample a sketch from the model, optionally conditioned on a latent z.

    Returns the sampled sequence converted from stroke-5 to stroke-3 format.
    """
    # Wrap the single latent vector into a batch of one when provided.
    z = None if z_input is None else [z_input]

    sample_strokes, m = sketch_rnn_model.sample(session, sample_model,
                                                seq_len=max_seq_len, temperature=temperature, z=z)
    # sample() yields stroke-5; convert to the compact stroke-3 encoding.
    return utils.to_normal_strokes(sample_strokes)
Code example #4
0
def batch_rasterize_relative(sketch):
    """Rasterize a batch of relative-offset sketches into image tensors.

    Args:
        sketch: batch tensor whose last dimension is either 5 (stroke-5:
            dx, dy, p1, p2, p3) or 3 (stroke-3: dx, dy, pen-up).
    Returns:
        Float tensor of shape (batch, C, H, W) holding the rasterized images.
    """
    def to_stroke_list(sketch):
        ## sketch: an `.npz` style sketch from QuickDraw
        # Prepend the origin and turn relative offsets into absolute coords.
        sketch = np.vstack((np.array([0, 0, 0]), sketch))
        sketch[:, :2] = np.cumsum(sketch[:, :2], axis=0)

        # Range-normalize each axis into [30, 225] so the drawing fits the
        # canvas with a margin. Guard against a zero range (single-point or
        # axis-aligned sketches), which previously divided by zero and
        # produced NaN/inf coordinates.
        xmin, xmax = sketch[:, 0].min(), sketch[:, 0].max()
        ymin, ymax = sketch[:, 1].min(), sketch[:, 1].max()
        x_range = float(xmax - xmin) or 1.0
        y_range = float(ymax - ymin) or 1.0

        sketch[:, 0] = ((sketch[:, 0] - xmin) / x_range) * (255. - 60.) + 30.
        sketch[:, 1] = ((sketch[:, 1] - ymin) / y_range) * (255. - 60.) + 30.
        sketch = sketch.astype(np.int64)

        # Split into individual strokes at pen-up positions.
        stroke_list = np.split(sketch[:, :2],
                               np.where(sketch[:, 2])[0] + 1,
                               axis=0)

        # Drop the trailing empty segment left by a final pen-up.
        if stroke_list[-1].size == 0:
            stroke_list = stroke_list[:-1]

        if len(stroke_list) == 0:
            # No pen-up at all: treat the whole sketch as a single stroke.
            stroke_list = [sketch[:, :2]]
        return stroke_list

    batch_redraw = []
    if sketch.shape[-1] == 5:
        # Stroke-5 input: convert each sample to stroke-3 before drawing.
        for data in sketch:
            image = mydrawPNG_from_list(
                to_stroke_list(to_normal_strokes(data.cpu().numpy())))
            batch_redraw.append(
                torch.from_numpy(np.array(image)).permute(2, 0, 1))
    elif sketch.shape[-1] == 3:
        # Stroke-3 input can be rasterized directly.
        for data in sketch:
            image = mydrawPNG_from_list(to_stroke_list(data.cpu().numpy()))
            batch_redraw.append(
                torch.from_numpy(np.array(image)).permute(2, 0, 1))

    return torch.stack(batch_redraw).float()
Code example #5
0
    # --- Fragment of a training loop; the enclosing function's def is not
    # --- visible in this excerpt, so its signature/inputs are unknown.
    # KL divergence of N(mean, exp(logstd)) against the unit Gaussian,
    # summed over all dimensions: 0.5 * sum(sigma^2 + mu^2 - log sigma^2 - 1).
    loss_kl = (logstd.exp() + mean.pow(2) - logstd - 1).sum() / 2.0

    # NOTE(review): Python 2 `print` statement, while other snippets in this
    # file use Python 3 `print()` — confirm which interpreter this targets.
    print 'loss-cons:', loss_cons.item(), 'loss-stroke:', loss_stroke.item(
    ), 'loss-kl:', loss_kl.item()
    # # zero the parameter gradients
    # targetVar = nn.utils.rnn.pack_padded_sequence(inputVar[1:,:,:], seq_len, batch_first=False) # first in sequence is S0 used by decoder
    # targetVar = targetVar.data

    # # import ipdb;ipdb.set_trace()

    # loss_cons = criterion_mse(outputVar, targetVar.cuda())
    # loss_kl = (logstd.exp()+mean.pow(2) - logstd - 1).sum()/2.0
    # loss =  loss_cons + loss_kl  #

    # visualize the output
    # Draw the free-running sample for the first batch element.
    small_stroke = to_normal_strokes(sample[:, 0, :])
    sample_denorm = dataset.denormalize(small_stroke)
    drawFig(sample_denorm)

    # Draw the reconstruction output for the same batch element.
    small_stroke = to_normal_strokes(outStroke[:, 0, :])
    sample_denorm = dataset.denormalize(small_stroke)
    # NOTE(review): the final pen-up flag is set on small_stroke AFTER it was
    # passed to denormalize(); if denormalize() copies its input, this line
    # has no effect on the figure drawn below — confirm intent.
    small_stroke[-1, -1] = 1
    drawFig(sample_denorm)

    # running_loss_cons += loss_cons.item()
    # running_loss_kl += loss_kl.item()
    # running_loss += loss.item()
    # if count % Showiter == 0:    # print every 20 mini-batches
    #     timestr = time.strftime('%m/%d %H:%M:%S',time.localtime())
    #     print(exp_prefix[0:-1] + ' [%d %s] loss: %.5f, cons_loss: %.5f, kl_loss: %.5f, lr: %f' %
    #     (count , timestr, running_loss / Showiter, running_loss_cons / Showiter,
Code example #6
0
File: test.py  Project: sqj2015/final_project
def sample(sess, model, pix_h, seq_len=250, temperature=1.0, greedy_mode=False):
    """Autoregressively sample a sketch from a pre-trained model.

    Starting from the S0 token, feeds the model one point at a time; at
    every step a mixture component is drawn for the pen offset and a
    categorical pen state is drawn, both temperature-adjusted. Returns the
    sampled sequence converted from stroke-5 to stroke-3 format.
    """

    def reweight(probs, temp):
        # Sharpen/flatten a categorical pdf by temperature, then renormalize.
        probs = np.log(probs) / temp
        probs -= probs.max()
        probs = np.exp(probs)
        probs /= probs.sum()
        return probs

    def pick_index(x, pdf, temp=1.0, greedy=False):
        """Sample an index from a pdf (argmax when greedy)."""
        if greedy:
            return np.argmax(pdf)
        pdf = reweight(np.copy(pdf), temp)
        total = 0
        for idx in range(0, pdf.size):
            total += pdf[idx]
            if total >= x:
                return idx
        print('Error with sampling ensemble.')
        return -1

    def draw_offset(mu1, mu2, s1, s2, rho, temp=1.0, greedy=False):
        # Draw (dx, dy) from a correlated bivariate Gaussian; greedy = mean.
        if greedy:
            return mu1, mu2
        center = [mu1, mu2]
        s1 *= temp * temp
        s2 *= temp * temp
        cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
        point = np.random.multivariate_normal(center, cov, 1)
        return point[0][0], point[0][1]

    # S0 start token: [0, 0, 1, 0, 0] (pen down, no offset).
    cur_input = np.zeros((1, 1, 5), dtype=np.float32)
    cur_input[0, 0, 2] = 1

    state = sess.run(model.initial_state_p2s, feed_dict={model.pix_h: pix_h})

    result = np.zeros((seq_len, 5), dtype=np.float32)
    mixture_params = []

    greedy = greedy_mode
    temp = temperature

    for step in range(seq_len):
        feed = {
            model.input_x: cur_input,
            model.sequence_lengths: [1],
            model.initial_state_p2s: state,
            model.pix_h: pix_h
        }

        gmm_coef, state = sess.run([model.gmm_output_p2s, model.final_state_p2s], feed_dict=feed)

        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = gmm_coef
        # top 6 param: [1, 20], o_pen: [1, 3], next_state: [1, 1024]

        # Choose a mixture component, then a pen state.
        comp = pick_index(random.random(), o_pi[0], temp, greedy)
        pen_idx = pick_index(random.random(), o_pen[0], temp, greedy)
        pen = [0, 0, 0]
        pen[pen_idx] = 1

        # Offsets use sqrt(temp) so the covariance scales linearly with temp.
        dx, dy = draw_offset(o_mu1[0][comp], o_mu2[0][comp],
                             o_sigma1[0][comp], o_sigma2[0][comp],
                             o_corr[0][comp], np.sqrt(temp), greedy)

        result[step, :] = [dx, dy, pen[0], pen[1], pen[2]]

        mixture_params.append([
            o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0],
            o_pen[0]
        ])

        # Feed this step's output back in as the next input point.
        cur_input = np.zeros((1, 1, 5), dtype=np.float32)
        cur_input[0][0] = np.array(
            [dx, dy, pen[0], pen[1], pen[2]], dtype=np.float32)

    # Convert the stroke-5 sequence to stroke-3 before returning.
    return utils.to_normal_strokes(result)