Example #1
def sampling_model(sess, model, gen_model, data_set, step, seq_len, subset_str=''):
    """Returns the average weighted cost, reconstruction cost and KL cost."""
    sketch_size, photo_size = data_set.sketch_size, data_set.image_size

    image_index = np.random.randint(0, photo_size)
    sketch_index = data_set.get_corr_sketch_id(image_index)
    gt_strokes = data_set.sketch_strokes[sketch_index]

    # Sample a sketch conditioned on the photo and convert it to stroke-3 format.
    image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index)
    sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len)
    strokes = utils.to_normal_strokes(sample_strokes)
    svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gensketch_for_photo%d_step%d.svg' % (data_set.dataset, subset_str, image_index, step))
    utils.draw_strokes(strokes, svg_filename=svg_gen_sketch)
    svg_gt_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gt_sketch%d_for_photo%d.svg' % (data_set.dataset, subset_str, sketch_index, image_index))
    utils.draw_strokes(gt_strokes, svg_filename=svg_gt_sketch)
    # Reconstruct a photo from the padded ground-truth sketch via the generative model.
    input_sketch = data_set.pad_single_sketch(image_index)
    feed = {gen_model.input_sketch: input_sketch, gen_model.input_photo: image_feat, gen_model.sequence_lengths: [seq_len]}
    gen_photo = sess.run(gen_model.gen_photo, feed)
    gen_photo_file = os.path.join(FLAGS.img_dir, '%s/%s/gen_photo%d_step%d.png' % (data_set.dataset, subset_str, image_index, step))
    cv2.imwrite(gen_photo_file, cv2.cvtColor(gen_photo[0, ::].astype(np.uint8), cv2.COLOR_RGB2BGR))
    gt_photo = os.path.join(FLAGS.img_dir, '%s/%s/gt_photo%d.png' % (data_set.dataset, subset_str, image_index))
    # Grayscale features can be written directly; RGB needs a BGR conversion for OpenCV.
    if len(image_feat[0].shape) == 2:
        cv2.imwrite(gt_photo, image_feat[0])
    else:
        cv2.imwrite(gt_photo, cv2.cvtColor(image_feat[0].astype(np.uint8), cv2.COLOR_RGB2BGR))
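A hedged sketch of how sampling_model might be driven from a training loop; `valid_set`, the 1000-step interval, and `hps.max_seq_len` below are illustrative assumptions, not part of the snippet:

# Hypothetical caller inside a training loop: dump qualitative samples
# for the validation set every 1000 steps.
if step % 1000 == 0:
    sampling_model(sess, model, gen_model, valid_set, step,
                   seq_len=hps.max_seq_len, subset_str='valid')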
Example #2
def sampling_model_eval(sess, model, gen_model, data_set, seq_len):
    """Returns the average weighted cost, reconstruction cost and KL cost."""
    sketch_size, photo_size = data_set.sketch_size, data_set.image_size

    folders_to_create = ['gen_test', 'gen_test_png', 'gt_test', 'gt_test_png', 'gt_test_photo', 'gt_test_sketch_image',
                         'gen_test_s', 'gen_test_s_png', 'gen_test_inter', 'gen_test_inter_png', 'gen_test_inter_sep',
                         'gen_test_inter_sep_png', 'gen_photo', 'gen_test_inter_with_photo', 'recon_test',
                         'recon_test_png', 'recon_photo']
    for folder_to_create in folders_to_create:
        folder_path = os.path.join(FLAGS.img_dir, '%s/%s' % (data_set.dataset, folder_to_create))
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)  # makedirs also creates missing parent directories

    # Sample one sketch per test photo and save SVG/PNG renderings.
    for image_index in range(photo_size):

        sys.stdout.write('\x1b[2K\r>> Sampling test set, [%d/%d]' % (image_index + 1, photo_size))
        sys.stdout.flush()

        image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index)
        sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len)
        strokes = utils.to_normal_strokes(sample_strokes)
        svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test/gen_sketch%d.svg' % (data_set.dataset, image_index))
        png_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test_png/gen_sketch%d.png' % (data_set.dataset, image_index))
        utils.sv_svg_png_from_strokes(strokes, svg_filename=svg_gen_sketch, png_filename=png_gen_sketch)

    print("\nSampling finished")
Example #3
def decode(session, sample_model, max_seq_len, z_input=None, temperature=0.1):
    """Sample a sketch, optionally conditioned on a latent vector z_input."""
    z = None
    if z_input is not None:
        z = [z_input]

    sample_strokes, m = sketch_rnn_model.sample(session, sample_model,
                                                seq_len=max_seq_len, temperature=temperature, z=z)
    strokes = utils.to_normal_strokes(sample_strokes)  # sample_strokes in stroke-5 format, strokes in stroke-3 format
    return strokes
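A hedged usage sketch for decode; `session`, `sample_model`, `max_seq_len`, the encoder output `z_vec`, and the `utils.draw_strokes` helper are assumed from the surrounding notebook code, and the temperatures are illustrative:

# Unconditional sample at a higher temperature, rendered as a sketch:
strokes = decode(session, sample_model, max_seq_len, temperature=0.5)
utils.draw_strokes(strokes)
# Sample conditioned on a latent vector z_vec (e.g. from an encoder):
strokes = decode(session, sample_model, max_seq_len, z_input=z_vec, temperature=0.1)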
Example #4
def evaluate(model, metrics, test_loader, vocab_desc, vocab_api, f_eval,
             repeat):
    ivocab_api = {v: k for k, v in vocab_api.items()}
    ivocab_desc = {v: k for k, v in vocab_desc.items()}

    recall_bleus, prec_bleus = [], []
    local_t = 0
    for descs, apiseqs, desc_lens, api_lens in tqdm(test_loader):

        if local_t > 2000:  # cap evaluation at roughly 2000 queries
            break

        desc_str, _ = indexes2sent(descs[0].numpy(), vocab_desc)

        descs, desc_lens = gVar(descs), gVar(desc_lens)
        sample_words, sample_lens = model.sample(descs, desc_lens, repeat)
        # nparray: [repeat x seq_len]
        pred_sents, _ = indexes2sent(sample_words, vocab_api)
        pred_tokens = [sent.split(' ') for sent in pred_sents]
        ref_str, _ = indexes2sent(apiseqs[0].numpy(), vocab_api,
                                  vocab_api["<s>"])
        ref_tokens = ref_str.split(' ')

        max_bleu, avg_bleu = metrics.sim_bleu(pred_tokens, ref_tokens)
        recall_bleus.append(max_bleu)
        prec_bleus.append(avg_bleu)

        local_t += 1
        f_eval.write("Batch %d \n" % (local_t))  # print the context
        f_eval.write("Query: {} \n".format(desc_str))
        f_eval.write("Target >> %s\n" %
                     (ref_str.replace(" ' ", "'")))  # print the true outputs
        for r_id, pred_sent in enumerate(pred_sents):
            f_eval.write("Sample %d >> %s\n" %
                         (r_id, pred_sent.replace(" ' ", "'")))
        f_eval.write("\n")

    recall_bleu = float(np.mean(recall_bleus))
    prec_bleu = float(np.mean(prec_bleus))
    # Harmonic mean of the precision/recall BLEU scores; the epsilon avoids division by zero.
    f1 = 2 * (prec_bleu * recall_bleu) / (prec_bleu + recall_bleu + 10e-12)

    report = "Avg recall BLEU %f, avg precision BLEU %f, F1 %f" % (
        recall_bleu, prec_bleu, f1)
    print(report)
    f_eval.write(report + "\n")
    print("Done testing")

    return recall_bleu, prec_bleu
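A minimal, hedged harness for evaluate; the log filename and repeat count below are assumptions:

# Hypothetical harness: sample 10 candidate API sequences per query and
# log queries, targets, and samples to a file.
with open('eval_samples.txt', 'w') as f_eval:
    recall_bleu, prec_bleu = evaluate(model, metrics, test_loader,
                                      vocab_desc, vocab_api, f_eval, repeat=10)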
Example #5
def test_flip_dist(sess, model, sample_model, hps, feed, s, a, p, n):
    tf_images = sess.run(model.tf_images, feed)
    if hps.vae_type == 'p2s':
        orig_sketch_png = np.array(p[0, ::], np.uint8)
    else:
        orig_sketch_png = np.array(
            np.stack([n[0, ::], n[0, ::], n[0, ::]], axis=-1), np.uint8)
    # Undo the basenet-specific input normalization to recover uint8 pixels.
    if hps.basenet in ['inceptionv1', 'inceptionv3']:
        dist_sketch_png = np.array((tf_images[0, ::] + 1) * 255 / 2, np.uint8)
    elif hps.basenet in ['sketchanet']:
        dist_sketch_png = np.array(tf_images[0, ::] + 250.42, np.uint8)
    else:
        raise Exception('basenet type error')
    import cv2
    cv2.imwrite('./orig_image.png', orig_sketch_png)
    cv2.imwrite('./dist_image.png', dist_sketch_png)

    import pdb
    pdb.set_trace()  # intentional breakpoint for inspecting the dumped images

    # Merge the stroke-5 pen states (p2 + p3) into a single lift flag; copy first
    # so the slice does not alias and mutate the input batch `a`.
    flipr_stroke = a[0, :, :4].copy()
    flipr_stroke[:, 2] = a[0, :, 3] + a[0, :, 4]
    flipr_stroke[:, 3] = a[0, :, 4]
    sv_svg_png_from_strokes(flipr_stroke,
                            svg_filename='./flipr_stroke.svg',
                            png_filename='./flipr_stroke.png')
    gen_strokes = sess.run(model.gen_strokes, feed)
    sv_svg_png_from_strokes(gen_strokes[0, :s[0], :],
                            svg_filename='./gen_stroke.svg',
                            png_filename='./gen_stroke.png')
    tar_strokes = sess.run(model.target_strokes, feed)
    sv_svg_png_from_strokes(tar_strokes[0, :s[0], :],
                            svg_filename='./tar_stroke.svg',
                            png_filename='./tar_stroke.png')
    # sample via the eval model
    from model import sample
    sample_strokes, m = sample(sess,
                               sample_model,
                               p[:1, ::],
                               seq_len=hps.max_seq_len,
                               greedy_mode=True)
    sample_strokes_normal = to_normal_strokes(sample_strokes)
    sv_svg_png_from_strokes(sample_strokes_normal,
                            svg_filename='./sample_stroke.svg',
                            png_filename='./sample_stroke.png')
Example #6
'''
MIT License
Copyright (c) 2017 Mat Leonard
'''

import argparse

from model import CharRNN, load_model, sample

parser = argparse.ArgumentParser(
                        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('checkpoint', type=str,
                    help='initialize network from checkpoint')
parser.add_argument('--gpu', action='store_true', default=False,
                    help='run the network on the GPU')
parser.add_argument('--num_samples', type=int, default=200,
                    help='number of samples for generating text')
parser.add_argument('--prime', type=str, default='From afar',
                    help='prime the network with characters for sampling')
parser.add_argument('--top_k', type=int, default=10,
                    help='sample from top K character probabilities')


args = parser.parse_args()

net = load_model(args.checkpoint)

print(sample(net, args.num_samples, cuda=args.gpu, top_k=args.top_k, prime=args.prime))
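Assuming the script above is saved as sample.py and a trained checkpoint exists (the path below is a placeholder), a typical invocation would be:

python sample.py checkpoints/charnn.ckpt --num_samples 500 --prime "From afar" --top_k 5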
Example #7
    sentence = ''
    diversity = 0.25
    sonnet = ''
    # want to generate the poem backwards
    for i in range(13, -1, -1):
        line = generated
        sentence = (sentence + '#\n' + generated[::-1])[-25:]
        syls = len(line.split(' '))
        while True:
            # convert sentence to data format for model
            x = np.zeros((1, seq_length, len(chars)))
            for t, char in enumerate(sentence):
                x[0, t, char_to_int[char]] = 1.

            preds = model.predict(x, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = int_to_char[next_index]

            # ignore special characters
            if (next_char == '\n') or (next_char == '#'):
                sentence = sentence[1:] + next_char
                continue

            # ideally I would have created a syllable counter, but we're just
            # using a word count for now
            if (next_char == ' '):
                syls += 1
                if syls >= 10:
                    break

            line = next_char + line
            # advance the context window so the next prediction sees this character
            sentence = sentence[1:] + next_char
Example #8
                    default=10000,
                    help='sample size for Monte Carlo model')
args = parser.parse_args()

with open(args.passwordfile, 'rt') as f:
    training = [w.strip('\r\n') for w in f]

models = {
    '{}-gram'.format(i): ngram_chain.NGramModel(training, i)
    for i in range(args.min_ngram, args.max_ngram + 1)
}
models['Backoff'] = backoff.BackoffModel(training, 10)
models['PCFG'] = pcfg.PCFG(training)

samples = {
    name: list(model.sample(args.samplesize))
    for name, model in models.items()
}

estimators = {
    name: model.PosEstimator(sample)
    for name, sample in samples.items()
}
modelnames = sorted(models)

writer = csv.writer(sys.stdout)
writer.writerow(['password'] + modelnames)

for password in sys.stdin:
    password = password.strip('\r\n')
    estimations = [