Пример #1
0
def evaluate(seq_length, N, charmap):
    """Plot how the averaged generator prediction changes as more runs are averaged.

    The first ``BATCH_SIZE`` held-out lines are fed through
    ``train_pred_for_eval`` ``N`` times. The running mean over the first
    10, 20, 30, ... runs is then formed, and for each step the L-infinity norm
    (over the last axis) of the difference between consecutive running means
    is averaged and plotted against the number of runs.

    Args:
        seq_length: Sequence length used for the dataset and the placeholder.
        N: Number of times the prediction op is evaluated.
        charmap: Mapping from character to integer index.

    The figure is created on the current matplotlib backend; this function
    neither shows nor saves it.
    """
    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
                                                            n_examples=FLAGS.MAX_N_EXAMPLES, dataset='heldout')

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(charmap,
                                                                                                            real_inputs_discrete,
                                                                                                            seq_length)

    # The context manager guarantees the session is released even if the
    # restore or one of the runs raises.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # load checkpoints
        internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(0)
        model_and_data_serialization.optimistic_restore(sess,
                                                        latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
        restore_config.set_restore_dir(
            load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        print("N = %d\n" % (N))
        # One slot per run: [run, batch, position, last dim of train_pred].
        train_pred_all = np.zeros([N, BATCH_SIZE, seq_length, train_pred.shape[2]],
                                  dtype=np.float32)

        # Only the first batch of the held-out set is used.
        _data = np.array([[charmap[c] for c in l] for l in lines[0:BATCH_SIZE]])

        # Evaluate the same batch N times (the original author notes each run
        # uses a fresh random noise vector - presumably drawn inside the graph).
        for i in range(N):
            train_pred_all[i, :, :, :] = sess.run(train_pred_for_eval,
                                                  feed_dict={real_inputs_discrete: _data})

    # Running means are taken over the first 10, 20, ... runs; slices past N
    # are clamped by NumPy, so a final partial step reuses all N runs.
    N_run = range(10, N + 10, 10)
    train_pred_average_prev = np.mean(train_pred_all[0:N_run[0]], axis=0)

    Linf_norm_values = []
    N_graph_values = []
    for N_curr in N_run[1:]:
        N_graph_values.append(N_curr)
        train_pred_average = np.mean(train_pred_all[0:N_curr], axis=0)
        train_pred_average_diff = train_pred_average - train_pred_average_prev
        # Max absolute change along the last axis, per (batch, position).
        Linf_norm_diff = np.linalg.norm(train_pred_average_diff, axis=2, ord=np.inf)
        Linf_norm_values.append(np.mean(Linf_norm_diff))
        train_pred_average_prev = train_pred_average

    plt.figure()
    plt.plot(N_graph_values, Linf_norm_values)
    plt.title('Approximated error between empiric and real mean (Linf norm)', size=30)
    plt.xlabel('Number of random noise vectors (N)', size=16)
    plt.ylabel('Error (L_inf norm)', size=16)
Пример #2
0
def log_samples(samples, scores, iteration, seq_length, prefix):
    """Append samples, ordered by ascending score, to a per-iteration log file.

    Args:
        samples: Iterable of samples; each sample is an iterable of strings
            that is concatenated without a separator.
        scores: Iterable of numeric scores, paired positionally with samples.
        iteration: Iteration number, used in the file name.
        seq_length: Sequence length, used to pick the checkpoint directory.
        prefix: File-name prefix (the file is ``<prefix>_samples_<iteration>.txt``).
    """
    sample_scores = sorted(zip(samples, scores), key=lambda sample: sample[1])

    out_path = (model_and_data_serialization.get_internal_checkpoint_dir(seq_length)
                + '/{}_samples_{}.txt'.format(prefix, iteration))

    # The with-block closes the file; no explicit close() is needed afterwards.
    with open(out_path, 'a') as f:
        for s, score in sample_scores:
            f.write("%s||\t%f\n" % ("".join(s), score))
Пример #3
0
def log_samples(samples, scores, iteration, seq_length, prefix):
    '''
    Append generator samples to ``<checkpoint dir>/<prefix>_<iteration>.txt``.

    If scores are given, each line also carries the score of the sample and
    lines are ordered by ascending score; otherwise only the samples are
    written, in their original order.

    Every "<naw>" token is dropped from a sample, except that a single
    trailing "<naw>" is kept when the sample ends with one.

    samples = iterable of samples, each a sequence of word strings
    scores = sequence of numeric scores paired with samples, or None/empty
    iteration = used in the file name
    seq_length = selects the checkpoint directory
    prefix = file-name prefix
    '''
    # len() instead of bare truthiness: `if scores:` raises ValueError for a
    # NumPy array with more than one element.
    has_scores = scores is not None and len(scores) > 0

    if has_scores:
        rows = sorted(zip(samples, scores), key=lambda sample: sample[1])
    else:
        rows = [(s, None) for s in samples]

    out_path = (model_and_data_serialization.get_internal_checkpoint_dir(seq_length)
                + '/{}_{}.txt'.format(prefix, iteration))

    with open(out_path, 'a') as f:
        for s, score in rows:
            # don't print <naw>, apart from one trailing marker
            words = [w for w in s if w != "<naw>"]
            # Guard the index so an empty sample yields an empty line
            # instead of an IndexError.
            if len(s) > 0 and s[-1] == '<naw>':
                words.append(s[-1])
            text = " ".join(words)
            if has_scores:
                f.write("%s \t %f\n" % (text, score))
            else:
                f.write("%s \n" % (text))
Пример #4
0
def log_samples(samples, scores, iteration, seq_length, prefix, class_scores = None):
    """Append samples, ordered by ascending score, to a per-iteration log file.

    Args:
        samples: Sequence of samples; each sample is an iterable of strings
            that is concatenated without a separator.
        scores: Sequence of numeric scores, paired positionally with samples.
        iteration: Iteration number, used in the file name.
        seq_length: Sequence length, used to pick the checkpoint directory.
        prefix: File-name prefix (the file is ``<prefix>_samples_<iteration>.txt``).
        class_scores: Optional 2-D array-like of per-class logits, one row per
            sample. When given, a row-wise softmax is applied and the
            resulting probabilities are appended to each line.
    """
    has_class_scores = class_scores is not None
    if has_class_scores:
        # Row-wise softmax. Subtracting the row maximum leaves the result
        # unchanged mathematically but keeps exp() from overflowing.
        logits = np.asarray(class_scores)
        exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        class_scores = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    else:
        # Placeholder column so the three-way zip below still works.
        class_scores = [None] * len(samples)

    sample_scores = sorted(zip(samples, scores, class_scores), key=lambda sample: sample[1])

    out_path = (model_and_data_serialization.get_internal_checkpoint_dir(seq_length)
                + '/{}_samples_{}.txt'.format(prefix, iteration))

    # The with-block closes the file; no explicit close() is needed afterwards.
    with open(out_path, 'a', encoding='utf8') as f:
        for s, score, class_score in sample_scores:
            s = "".join(s)
            if not has_class_scores:
                f.write("%s||\t\t%0.3f\n" % (s, score))
            else:
                # Rendered as "[p0, p1, ...]" with three decimals each.
                class_score_for_print = "[%s]" % ", ".join("%0.3f" % p for p in class_score)
                f.write("%s||\t\t%0.3f ||\t%s\n" % (s, score, class_score_for_print))
Пример #5
0
def run(iterations, seq_length, is_first, charmap, inv_charmap,
        prev_seq_length):
    """Train the GAN for one sequence length, logging samples and checkpoints.

    Args:
        iterations: Number of outer training iterations. Each one runs
            ``CRITIC_ITERS`` critic steps followed by ``GEN_ITERS`` generator
            steps.
        seq_length: Sequence length for the placeholders, the data handler
            and the checkpoint directory.
        is_first: When false, weights are restored from the checkpoint
            directory of ``prev_seq_length`` before training.
        charmap: Character map passed to the objective builder.
        inv_charmap: Inverse character map used to decode samples for logging.
        prev_seq_length: Sequence length whose checkpoint is restored when
            ``is_first`` is false.

    Raises:
        Exception: If ``DATA_DIR`` is empty.
        ValueError: If ``FLAGS.ARCH`` is not 'default' or 'class_conditioned'.
    """
    if len(DATA_DIR) == 0:
        raise Exception(
            'Please specify path to data directory in single_length_train.py!')

    # instance data handler
    data_handler = Runtime_data_handler(
        h5_path=FLAGS.H5_PATH,
        json_path=FLAGS.H5_PATH.replace('.h5', '.json'),
        seq_len=seq_length,
        use_var_len=False,
        batch_size=BATCH_SIZE,
        use_labels=False)

    # define placeholders
    real_inputs_discrete = tf.placeholder(tf.int32,
                                          shape=[BATCH_SIZE, seq_length])
    real_classes_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
    global_step = tf.Variable(0, trainable=False)

    # build graph according to arch
    if FLAGS.ARCH == 'default':
        disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
            charmap, real_inputs_discrete, seq_length)
        # No class heads in the default architecture.
        disc_fake_class = None
        disc_real_class = None
        disc_on_inference_class = None

        visualize_text = generate_argmax_samples_and_gt_samples
    elif FLAGS.ARCH == 'class_conditioned':
        disc_cost, gen_cost, fake_inputs, disc_fake, disc_fake_class, disc_real, disc_real_class,\
        disc_on_inference, disc_on_inference_class, inference_op = define_class_objective(charmap,
                                                                                            real_inputs_discrete,
                                                                                            real_classes_discrete,
                                                                                            seq_length,
                                                                                            num_classes=len(data_handler.class_dict))
        visualize_text = generate_argmax_samples_and_gt_samples_class
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # the first use of disc_cost below.
        raise ValueError("Unsupported FLAGS.ARCH: {}".format(FLAGS.ARCH))

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost,
                                                       global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    # The with-block closes the session on exit.
    with tf.Session() as session:

        session.run(tf.initialize_all_variables())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(
                prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True
            )  # global param, always load from curr session after finishing the first seq

        data_handler.epoch_start(seq_len=seq_length)

        for iteration in range(iterations):
            # Train critic
            for i in range(CRITIC_ITERS):
                _data, _labels = data_handler.get_batch()
                if FLAGS.ARCH == 'class_conditioned':
                    _disc_cost, _, real_scores, real_class_scores = session.run(
                        [disc_cost, disc_train_op, disc_real, disc_real_class],
                        feed_dict={
                            real_inputs_discrete: _data,
                            real_classes_discrete: _labels
                        })
                else:
                    _disc_cost, _, real_scores = session.run(
                        [disc_cost, disc_train_op, disc_real],
                        feed_dict={
                            real_inputs_discrete: _data,
                            real_classes_discrete: _labels
                        })
                    real_class_scores = None

            # Train G
            for i in range(GEN_ITERS):
                _data, _labels = data_handler.get_batch()
                _ = session.run(gen_train_op,
                                feed_dict={
                                    real_inputs_discrete: _data,
                                    real_classes_discrete: _labels
                                })

            # _disc_cost is the value from the last critic step of this iteration.
            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)

            # Summaries and sample logs, once every 1000 iterations.
            if iteration % 1000 == 999:
                _data, _labels = data_handler.get_batch()
                summary_str = session.run(merged,
                                          feed_dict={
                                              real_inputs_discrete: _data,
                                              real_classes_discrete: _labels
                                          })

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilites, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    fake_inputs,
                    disc_fake,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=True,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_fake_class)

                log_samples(fake_samples,
                            fake_scores,
                            iteration,
                            seq_length,
                            "train",
                            class_scores=fake_class_scores)
                # NOTE(review): _data is the summary batch fetched above while
                # real_scores comes from the last critic batch, so samples and
                # scores in the "gt" log may not correspond - confirm intent.
                log_samples(decode_indices_to_string(_data, inv_charmap),
                            real_scores,
                            iteration,
                            seq_length,
                            "gt",
                            class_scores=real_class_scores)

                # inference
                test_samples, _, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    inference_op,
                    disc_on_inference,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=False,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_on_inference_class)
                # Inference samples are only logged for the default architecture.
                if not FLAGS.ARCH == 'class_conditioned':
                    log_samples(test_samples, fake_scores, iteration,
                                seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(
                    session,
                    model_and_data_serialization.get_internal_checkpoint_dir(
                        seq_length) + "/ckp")

        data_handler.epoch_end()

        # Final checkpoint after the last iteration.
        saver.save(
            session,
            model_and_data_serialization.get_internal_checkpoint_dir(
                seq_length) + "/ckp")
Пример #6
0
def run(iterations, seq_length, is_first, wordmap, inv_wordmap, prev_seq_length, meta_iter, max_meta_iter):
    '''
    Performs a single run of a single meta iteration of curriculum training.
    When FLAGS.WORDVECS is None it also performs the reallocation at the end.

    iterations = the number of outer training iterations (each runs
                 CRITIC_ITERS critic steps and GEN_ITERS generator steps)
    seq_length = the length of the sequences to create
    is_first = whether this is the first run; when true (and PRETRAIN is set)
               weights are restored from the 'pretrain' directory, when false
               they are restored from the checkpoint dir of prev_seq_length
    wordmap = NOTE: replaced inside the function by the map load_dataset returns
    inv_wordmap = NOTE: replaced inside the function by the map load_dataset returns;
                  used for decoding the samples
    prev_seq_length = for choosing the ckpt to load from
    meta_iter = the round of the seq_length training
    max_meta_iter = used to make sure the tensorboard graphs log using the correct global step
    '''
    # make sure tf does not take up all the GPU mem (model size is not that large so there is unlikely to be fragmentation problems)
    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True

    # first one so copy from the initial location file
    # NOTE(review): `and` binds tighter than `or`, so this reads as
    # (is_first and isfile(...)) or (FLAGS.WORDVECS is not None); when WORDVECS
    # is set the copy runs even if word-0.locations is missing - confirm intent.
    if is_first and os.path.isfile('locations/word-0.locations') or FLAGS.WORDVECS is not None:
        copyfile('locations/word-0.locations', 'locations/word-%d.locations' % seq_length)


    # load the lines from the dataset along with the current wordmap and inv_wordmap
    # (this overwrites the wordmap / inv_wordmap arguments)
    lines, wordmap, inv_wordmap = load_dataset(seq_length=seq_length, n_examples=FLAGS.MAX_N_EXAMPLES)

    # seed word-0.locations from the current length's file if it does not exist yet
    # (presumably load_dataset has written word-<seq_length>.locations - TODO confirm)
    if not os.path.isfile('locations/word-0.locations'):
        copyfile('locations/word-%d.locations' % seq_length, 'locations/word-0.locations')


    # placeholders for the input from the dataset (a pair of same-shaped int tensors)
    real_inputs_discrete = [tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length]), \
                              tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])]

    # global step
    # NOTE(review): initialised to `iterations` rather than 0 - confirm this is intended
    global_step = tf.Variable(iterations, trainable=False, name='global-step')

    # index of <naw> in the map (two components, presumably row and column)
    naw_r, naw_c = wordmap['<naw>'][0], wordmap['<naw>'][1]

    session = tf.Session(config=config)

    # start the generator to get minibatches from the dataset
    gen = inf_train_gen(lines, wordmap, seq_length)

    # define the network
    #disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, realloc_op, d_cost_fake, d_cost_real = define_objective(session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, gen)
    # get the summaries, optimizers, and saver

    optim_costs, discs, ops, log_costs, embeddings = define_objective(session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, gen)

    #embed_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
    #r_embedding = embed_config.embeddings.add()
    #r_embedding.tensor_name = embeddings.name
    #r_embedding.metadata_path = '../../loations/metadata.tsv'

    # unpack the grouped return values of define_objective
    disc_cost, gen_cost = optim_costs
    disc_fake, disc_real, disc_on_inference = discs
    fake_inputs, inference_op, realloc_op = ops
    d_cost_fake, d_cost_real = log_costs

    merged, train_writer = define_summaries(d_cost_real, d_cost_fake, gen_cost, disc_cost)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)
    saver = tf.train.Saver(tf.trainable_variables())


    with session.as_default():
        # write the graph and init vars
        train_writer.add_graph(session.graph)
        session.run(tf.global_variables_initializer())

        # first run with pretraining: restore from the 'pretrain' directory;
        # any later run: restore from the previous sequence length's checkpoint
        if is_first and PRETRAIN:
            optimistic_restore(session, tf.train.latest_checkpoint('pretrain', 'checkpoint'))
            restore_config.set_restore_dir(load_from_curr_session=True)
        elif not is_first:
            internal_checkpoint_dir = get_internal_checkpoint_dir(prev_seq_length)
            optimistic_restore(session, tf.train.latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
            restore_config.set_restore_dir(load_from_curr_session=True)



        # progress bar over the training iterations
        with tqdm(total=iterations, ncols=150) as pbar:

            # train loop
            for iteration in range(iterations):
                # Train critic first
                for _ in range(CRITIC_ITERS):
                    _data = next(gen)
                    _disc_cost, _, _ = session.run([disc_cost, disc_train_op, disc_real], \
                                                   feed_dict={real_inputs_discrete[0]: _data[0], \
                                                              real_inputs_discrete[1]: _data[1]})
                # Train generator
                for _ in range(GEN_ITERS):
                    _data = next(gen)
                    _, _g_cost = session.run([gen_train_op, gen_cost], feed_dict={real_inputs_discrete[0]:_data[0], \
                                                                                  real_inputs_discrete[1]: _data[1]})


                # update progress bar with the latest costs and advance its counter by 1
                pbar.set_description("disc cost %f \t gen cost %f \t"% (_disc_cost, _g_cost))
                pbar.update(1)

                # write summaries every 100 iterations
                if iteration % 100 == 99:
                    _data = next(gen)
                    summary_str = session.run(merged, feed_dict={real_inputs_discrete[0]:_data[0], \
                                                                 real_inputs_discrete[1]:_data[1]})

                    # step index that keeps increasing across meta iterations and sequence lengths
                    iii = meta_iter*iterations + iteration + ((seq_length-1)*iterations*max_meta_iter)

                    train_writer.add_summary(summary_str, \
                                             global_step=iii)

                    # generate and log output from training (ground truth fed)
                    fake_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_wordmap,
                                                                                          fake_inputs,
                                                                                          disc_fake,
                                                                                          gen,
                                                                                          real_inputs_discrete,
                                                                                          feed_gt=True, method='argmax')
                    log_samples(fake_samples, fake_scores, iii, seq_length, "gen-w-gt")

                    # generate and log output from inference (no ground truth fed)
                    test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_wordmap,
                                                                                        inference_op,
                                                                                          disc_on_inference,
                                                                                          gen,
                                                                                          real_inputs_discrete,
                                                                                          feed_gt=False, method='argmax')
                    log_samples(test_samples, fake_scores, iii, seq_length, "gen-no-gt")


        #*********************************************************************************
        #if seq_length >= 40:
        # copy current location file to next sequence_length as we are going to gen sequences of length seq_len +1 for realloc
        if FLAGS.WORDVECS is None:

            copyfile('locations/word-%d.locations' % (seq_length), 'locations/word-%d.locations' % (seq_length+1))
            # NOTE(review): always true for any non-negative seq_length; looks like
            # a leftover from the commented-out `seq_length >= 40` threshold above
            if seq_length >= 0:
                # get the lines, note that wordmap and inv_wordmap stay the same
                # NOTE(review): seq_length is hard-coded to 48 here - confirm
                lines, _, _ = load_dataset(seq_length=48, n_examples=FLAGS.MAX_N_EXAMPLES,\
                                           no_write=True)

                # start generator and perform the reallocation
                gen = inf_realloc_gen(lines, wordmap, seq_length+1)
                perf_reallocate(int(1000000), session, inv_wordmap, realloc_op, \
                                gen, seq_length, real_inputs_discrete, naw_r, naw_c)

                #  realloc creates the new location file in seq_length +1 so we move it back if we are not at the last meta_iter
                if meta_iter != max_meta_iter -1:
                    copyfile('locations/word-%d.locations' % (seq_length+1), 'locations/word-%d.locations' % (seq_length))
                    #os.remove('locations/word-%d.locations' % (seq_length+1))
                    os.remove('locations/word-%d.locations.string' % (seq_length+1))


        # save the ckpt and close the session because we need to reset the graph
        saver.save(session, get_internal_checkpoint_dir(seq_length) + "/ckp")
        session.close()
Пример #7
0
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    """Train the GAN (WGAN or Fisher GAN) for one sequence length.

    Args:
        iterations: Number of outer training iterations. Each one runs
            ``CRITIC_ITERS`` critic steps followed by ``GEN_ITERS`` generator
            steps.
        seq_length: Sequence length for the dataset, the placeholder and the
            checkpoint directory.
        is_first: When false, weights are restored from the checkpoint
            directory of ``prev_seq_length`` before training.
        charmap: Character map passed to the objective builder and the batch
            generator.
        inv_charmap: Inverse character map used to decode samples for logging.
        prev_seq_length: Sequence length whose checkpoint is restored when
            ``is_first`` is false.

    Raises:
        Exception: If ``DATA_DIR`` is empty.
        ValueError: If ``FLAGS.GAN_TYPE`` is neither "fgan" nor "wgan"
            (raised on the first critic step).
    """
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, other_ops = define_objective(charmap,real_inputs_discrete, seq_length,
        gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL)


    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(
        disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR)

    saver = tf.train.Saver(tf.trainable_variables())

    # Use JIT XLA compilation to speed up calculations
    config=tf.ConfigProto(
        log_device_placement=False, allow_soft_placement=True)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    # Normalised once; the flag does not change during training.
    gan_type = FLAGS.GAN_TYPE.lower()

    # The with-block closes the session on exit.
    with tf.Session(config=config) as session:

        session.run(tf.initialize_all_variables())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)
        # Accumulators for the averages reported at each summary step.
        _gen_cost_list = []
        _disc_cost_list = []
        _step_time_list = []

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)

                if gan_type == "fgan":
                    # Fisher GAN additionally runs the alpha optimizer op.
                    _disc_cost, _, real_scores, _ = session.run(
                    [disc_cost, disc_train_op, disc_real,
                        other_ops["alpha_optimizer_op"]],
                    feed_dict={real_inputs_discrete: _data}
                    )

                elif gan_type == "wgan":
                    _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                    )

                else:
                    raise ValueError(
                        "Appropriate gan type not selected: {}".format(FLAGS.GAN_TYPE))
                _disc_cost_list.append(_disc_cost)

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                # in Fisher GAN, paper measures convergence by gen_cost instead of disc_cost
                # To measure convergence, gen_cost should start at a positive number and decrease
                # to zero. The lower, the better.
                _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={real_inputs_discrete: _data})
                _gen_cost_list.append(_gen_cost)

            _step_time_list.append(time.time() - start_time)

            if FLAGS.PRINT_EVERY_STEP:
                print("iteration %s/%s"%(iteration, iterations))
                # Was `"disc cost {}" % _disc_cost`, which raises TypeError
                # because the string has no %-placeholder.
                print("disc cost {}".format(_disc_cost))
                print("gen cost {}".format(_gen_cost))
                print("total step time {}".format(time.time() - start_time))


            # Summaries
            if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION-1:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )

                tf.logging.warn("iteration %s/%s"%(iteration, iterations))
                tf.logging.warn("disc cost {} gen cost {} average step time {}".format(
                    np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list)))
                # Start a fresh averaging window.
                _gen_cost_list, _disc_cost_list, _step_time_list = [], [], []

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilites, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap, fake_inputs, disc_fake, gen, real_inputs_discrete,feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                # NOTE(review): _data is the summary batch fetched above while
                # real_scores comes from the last critic batch - confirm intent.
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length,
                             "gt")
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session,
                  inv_charmap,
                  inference_op,
                  disc_on_inference,
                  gen,
                  real_inputs_discrete,
                  feed_gt=False)
                # disc_on_inference, inference_op
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")


            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY-1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        # Final checkpoint after the last iteration.
        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
Пример #8
0
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    """Train critic and generator for one sequence length.

    Each of the ``iterations`` outer steps runs ``CRITIC_ITERS`` critic
    updates and ``GEN_ITERS`` generator updates. Every 100th step a summary
    is written and train / ground-truth / test samples are logged.
    Checkpoints are saved every ``FLAGS.SAVE_CHECKPOINTS_EVERY`` steps and
    once more at the end. When ``is_first`` is false, weights are first
    restored from the checkpoint directory of ``prev_seq_length``.
    """
    if len(DATA_DIR) < 1:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    serialization = model_and_data_serialization

    lines, _, _ = serialization.load_dataset(
        seq_length=seq_length,
        b_charmap=False,
        b_inv_charmap=False,
        n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])
    global_step = tf.Variable(0, trainable=False)

    objective = define_objective(charmap, real_inputs_discrete, seq_length)
    (disc_cost, gen_cost, fake_inputs, disc_fake, disc_real,
     disc_on_inference, inference_op) = objective

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    def feed(batch):
        # Every session.run in this function feeds only the real-input placeholder.
        return {real_inputs_discrete: batch}

    with tf.Session() as session:

        def save_checkpoint():
            # The directory is looked up on every save, as before.
            saver.save(session, serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        session.run(tf.initialize_all_variables())

        if not is_first:
            print("Loading previous checkpoint...")
            previous_dir = serialization.get_internal_checkpoint_dir(prev_seq_length)
            serialization.optimistic_restore(
                session, latest_checkpoint(previous_dir, "checkpoint"))
            # global param, always load from curr session after finishing the first seq
            restore_config.set_restore_dir(load_from_curr_session=True)

        gen = inf_train_gen(lines, charmap)

        for iteration in range(iterations):
            step_started = time.time()

            # Critic updates; the cost and scores of the last one are kept.
            for _critic_step in range(CRITIC_ITERS):
                batch = next(gen)
                _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict=feed(batch))

            # Generator updates.
            for _gen_step in range(GEN_ITERS):
                batch = next(gen)
                session.run(gen_train_op, feed_dict=feed(batch))

            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)

            # Summaries and sample logs on every 100th step.
            if (iteration + 1) % 100 == 0:
                batch = next(gen)
                summary_str = session.run(merged, feed_dict=feed(batch))
                train_writer.add_summary(summary_str, global_step=iteration)

                train_result = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, fake_inputs, disc_fake, gen,
                    real_inputs_discrete, feed_gt=True)
                fake_samples, _real_probabilities, fake_scores = train_result

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")

                decoded_real = decode_indices_to_string(batch, inv_charmap)
                log_samples(decoded_real, real_scores, iteration, seq_length, "gt")

                test_result = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, inference_op, disc_on_inference, gen,
                    real_inputs_discrete, feed_gt=False)
                test_samples, _, fake_scores = test_result

                log_samples(test_samples, fake_scores, iteration, seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                save_checkpoint()

        save_checkpoint()
        session.close()
Пример #9
0
def evaluate(seq_length, N, charmap, inv_charmap):
    """Compute bits-per-character (BPC) on the held-out set.

    For every full batch of held-out lines the prediction op is evaluated
    ``N`` times and the results are averaged. The BPC of the batch is the mean
    of ``-log2`` of the averaged prediction at the true character index.
    Per-batch values are saved to ``BPC_list_temp.npy`` as they are produced;
    the full list and its mean are saved to ``BPC_list.npy`` and
    ``BPC_final.npy`` at the end.

    Args:
        seq_length: Sequence length used for the dataset and the placeholder.
        N: Number of evaluations averaged per batch.
        charmap: Mapping from character to integer index.
        inv_charmap: Unused; kept so existing callers keep working.
    """
    lines, _, _ = model_and_data_serialization.load_dataset(
        seq_length=seq_length,
        b_charmap=False,
        b_inv_charmap=False,
        n_examples=FLAGS.MAX_N_EXAMPLES,
        dataset='heldout')

    real_inputs_discrete = tf.placeholder(tf.int32,
                                          shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
        charmap, real_inputs_discrete, seq_length)

    # One slot per run: [run, batch, position, last dim of train_pred].
    # The buffer is reused (fully overwritten) for every batch.
    train_pred_all = np.zeros([N, BATCH_SIZE, seq_length, train_pred.shape[2]],
                              dtype=np.float32)

    # Guards log2 against a zero input.
    epsilon = 1e-20
    BPC_list = []

    # The context manager guarantees the session is released even if the
    # restore or one of the runs raises.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # load checkpoints
        internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(
            0)
        model_and_data_serialization.optimistic_restore(
            sess, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
        restore_config.set_restore_dir(
            load_from_curr_session=True
        )  # global param, always load from curr session after finishing the first seq

        # Only full batches are evaluated; a trailing partial batch is skipped.
        for start_line in range(0, len(lines) - BATCH_SIZE + 1, BATCH_SIZE):
            t0 = time.time()
            _data = np.array([[charmap[c] for c in l]
                              for l in lines[start_line:start_line + BATCH_SIZE]])

            # Evaluate the same batch N times (the original author notes each
            # run uses a fresh random noise vector).
            for run_idx in range(N):
                train_pred_all[run_idx, :, :, :] = sess.run(
                    train_pred_for_eval,
                    feed_dict={real_inputs_discrete: _data})

            # average over the N runs (first dimension)
            train_pred_average = np.mean(train_pred_all, axis=0)

            # compute BPC (char-based perplexity): flatten batch and position,
            # then pick the averaged prediction at each true character index
            train_pred_average_2d = train_pred_average.reshape(
                [-1, train_pred_average.shape[2]])
            real_data = _data.reshape([-1])
            picked = train_pred_average_2d[np.arange(real_data.shape[0]), real_data]

            # float64 keeps the accumulation accurate for the float32 inputs
            BPC = -np.mean(np.log2(picked.astype(np.float64) + epsilon))

            print("BPC of start_line %d/%d = %.2f" % (start_line, len(lines), BPC))
            print("t_iter = %.2f" % (time.time() - t0))
            BPC_list.append(BPC)
            # Saved after every batch so partial results survive an interruption.
            np.save('BPC_list_temp.npy', BPC_list)

    BPC_final = np.mean(BPC_list)
    print("BPC_final = %.2f\n" % (BPC_final))
    np.save('BPC_list.npy', BPC_list)
    np.save('BPC_final.npy', BPC_final)