Example #1
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
                                                            n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(charmap,
                                                                                                            real_inputs_discrete,
                                                                                                            seq_length)
    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as session:

        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                                                            latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
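            # WGAN training schedule: the critic gets CRITIC_ITERS updates per
            # generator update to keep its estimate close to optimal.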
            for i in range(CRITIC_ITERS):
                _data = next(gen)
                _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                )

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                _ = session.run(gen_train_op, feed_dict={real_inputs_discrete: _data})

            print("iteration %s/%s"%(iteration, iterations))
            print("disc cost %f"%_disc_cost)

            # Summaries
            if iteration % 100 == 99:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilities, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, fake_inputs, disc_fake, gen,
                    real_inputs_discrete, feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores,
                            iteration, seq_length, "gt")
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, inference_op, disc_on_inference,
                    gen, real_inputs_discrete, feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")


            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY-1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
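
A note on usage: the is_first / prev_seq_length arguments suggest a curriculum that trains on progressively longer sequences, warm-starting each stage from the previous stage's checkpoint. A minimal, hypothetical driver sketch (the length schedule and iteration count below are assumptions, not part of this example):

# Hypothetical driver (a sketch): train on increasing sequence lengths,
# restoring each stage from the checkpoint of the previous one.
_, charmap, inv_charmap = model_and_data_serialization.load_dataset(
    seq_length=32, b_lines=False)
prev_len = 0
for i, seq_len in enumerate([8, 16, 32, 64]):
    run(iterations=10000, seq_length=seq_len, is_first=(i == 0),
        charmap=charmap, inv_charmap=inv_charmap, prev_seq_length=prev_len)
    prev_len = seq_len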
Example #2
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, other_ops = define_objective(
        charmap, real_inputs_discrete, seq_length,
        gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL)

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(
        disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR)

    saver = tf.train.Saver(tf.trainable_variables())

    # Use JIT XLA compilation to speed up calculations
    config = tf.ConfigProto(
        log_device_placement=False, allow_soft_placement=True)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    with tf.Session(config=config) as session:

        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)
        _gen_cost_list = []
        _disc_cost_list = []
        _step_time_list = []

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)

                if FLAGS.GAN_TYPE.lower() == "fgan":
                    _disc_cost, _, real_scores, _ = session.run(
                        [disc_cost, disc_train_op, disc_real,
                         other_ops["alpha_optimizer_op"]],
                        feed_dict={real_inputs_discrete: _data}
                    )
                elif FLAGS.GAN_TYPE.lower() == "wgan":
                    _disc_cost, _, real_scores = session.run(
                        [disc_cost, disc_train_op, disc_real],
                        feed_dict={real_inputs_discrete: _data}
                    )
                else:
                    raise ValueError(
                        "Unsupported GAN_TYPE: {}".format(FLAGS.GAN_TYPE))
                _disc_cost_list.append(_disc_cost)

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                # In Fisher GAN, the paper measures convergence by gen_cost rather
                # than disc_cost: gen_cost should start at a positive value and
                # decrease toward zero (lower is better).
                _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={real_inputs_discrete: _data})
                _gen_cost_list.append(_gen_cost)

            _step_time_list.append(time.time() - start_time)

            if FLAGS.PRINT_EVERY_STEP:
                print("iteration %s/%s"%(iteration, iterations))
                print("disc cost {}".format(_disc_cost))
                print("gen cost {}".format(_gen_cost))
                print("total step time {}".format(time.time() - start_time))


            # Summaries
            if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION-1:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )

                tf.logging.warn("iteration %s/%s"%(iteration, iterations))
                tf.logging.warn("disc cost {} gen cost {} average step time {}".format(
                    np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list)))
                _gen_cost_list, _disc_cost_list, _step_time_list = [], [], []

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilities, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, fake_inputs, disc_fake, gen,
                    real_inputs_discrete, feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores,
                            iteration, seq_length, "gt")
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, inference_op, disc_on_inference,
                    gen, real_inputs_discrete, feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")


            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY-1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

def evaluate(EVAL_FLAGS):

    # load dict & params
    _, charmap, inv_charmap = model_and_data_serialization.load_dataset(
        seq_length=32, b_lines=False)
    seq_length = EVAL_FLAGS.seq_len
    N = EVAL_FLAGS.num_samples

    ckp_list, config_list = get_models_list()

    print("ALL MODELS: %0s" % ckp_list)

    for ckp_path, config_path in zip(ckp_list, config_list):

        model_name = "%0s_%0s" % (ckp_path.split('/')[2],
                                  ckp_path.split('/')[4])

        #filter by model name
        if not model_name.startswith(EVAL_FLAGS.prefix_filter):
            print("NOT EVALUATING [%0s]" % model_name)
            continue

        tf.reset_default_graph()

        print("EVALUATING [%0s]" % model_name)
        print("restoring config:")
        FLAGS.DISC_STATE_SIZE = restore_param_from_config(
            config_path, 'DISC_STATE_SIZE')
        FLAGS.GEN_STATE_SIZE = restore_param_from_config(
            config_path, 'GEN_STATE_SIZE')
        print("DISC_STATE_SIZE [%0d]" % FLAGS.DISC_STATE_SIZE)
        print("GEN_STATE_SIZE [%0d]" % FLAGS.GEN_STATE_SIZE)

        lines, _, _ = model_and_data_serialization.load_dataset(
            seq_length=seq_length,
            b_charmap=False,
            b_inv_charmap=False,
            n_examples=FLAGS.MAX_N_EXAMPLES,
            dataset='heldout')

        real_inputs_discrete = tf.placeholder(tf.int32,
                                              shape=[BATCH_SIZE, seq_length])

        global_step = tf.Variable(0, trainable=False)
        disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
            charmap, real_inputs_discrete, seq_length)

        # allocate a buffer for N runs of train_pred (one per noise sample)
        train_pred_all = np.zeros(
            [N, BATCH_SIZE, seq_length, train_pred.shape[2]], dtype=np.float32)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # load checkpoints
        # internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(0)
        internal_checkpoint_dir = ckp_path
        model_and_data_serialization.optimistic_restore(
            sess, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
        restore_config.set_restore_dir(
            load_from_curr_session=True
        )  # global param, always load from curr session after finishing the first seq

        # create samples
        print('start creating samples..')
        test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
            sess,
            inv_charmap,
            inference_op,
            disc_on_inference,
            None,
            real_inputs_discrete,
            feed_gt=False)

        log_samples(test_samples, fake_scores, 666, seq_length, "test")  # 666 is a placeholder iteration id for eval logs

        print('starting BPC eval...')
        BPC_list = []

        if EVAL_FLAGS.short_run:
            end_line = BATCH_SIZE * 100  # 100 batches
        else:
            end_line = len(lines) - BATCH_SIZE + 1  # all data

        for start_line in range(0, end_line, BATCH_SIZE):
            t0 = time.time()
            _data = np.array([[charmap[c] for c in l]
                              for l in lines[start_line:start_line +
                                             BATCH_SIZE]])

            # draw N noise samples and compute train_pred for each one
            for i in range(N):
                train_pred_i = sess.run(
                    train_pred_for_eval,
                    feed_dict={real_inputs_discrete: _data})
                train_pred_all[i, :, :, :] = train_pred_i

            # average the N predictions (over the first dimension)
            train_pred_average = np.mean(train_pred_all, axis=0)

            # compute BPC (bits per character)
            train_pred_average_2d = train_pred_average.reshape([
                train_pred_average.shape[0] * train_pred_average.shape[1],
                train_pred_average.shape[2]
            ])
            real_data = _data.reshape([_data.shape[0] * _data.shape[1]])

            # BPC = -(1/T) * sum_t log2 p(c_t), where T = BATCH_SIZE * seq_length
            BPC = 0

            epsilon = 1e-20
            for i in range(real_data.shape[0]):
                BPC -= np.log2(train_pred_average_2d[i, real_data[i]] +
                               epsilon)

            BPC /= real_data.shape[0]
            print("BPC of start_line %d/%d = %.2f" %
                  (start_line, len(lines), BPC))
            print("t_iter = %.2f" % (time.time() - t0))
            BPC_list.append(BPC)
            np.save('BPC_list_temp.npy', BPC_list)

        BPC_final = np.mean(BPC_list)
        print("[%0s] BPC_final = %.2f\n" % (ckp_path, BPC_final))
        np.save("%0s_BPC_list.npy" % model_name, BPC_list)
        np.save("%0s_BPC_final.npy" % model_name, BPC_final)
        sess.close()  # release resources before the next model's graph is built
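
For reference, the per-character BPC loop above can be vectorized with NumPy advanced indexing; a minimal sketch using the same names and epsilon as the code above:

# Vectorized BPC: pick each character's predicted probability, then
# average the negative log2 over all characters.
probs = train_pred_average_2d[np.arange(real_data.shape[0]), real_data]
BPC = -np.log2(probs + epsilon).mean()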