Ejemplo n.º 1
0
def run(iterations, seq_length, is_first, charmap, inv_charmap,
        prev_seq_length):
    """Train the GAN for one curriculum stage at a fixed sequence length.

    Builds the graph selected by ``FLAGS.ARCH`` ('default' or
    'class_conditioned'), optionally restores weights saved for
    ``prev_seq_length``, runs ``iterations`` rounds of critic and generator
    updates on batches drawn from a ``Runtime_data_handler``, writes
    summaries / sample logs every 1000 iterations and saves checkpoints for
    ``seq_length``.

    Args:
        iterations: number of outer training iterations.
        seq_length: sequence length of the placeholders and batches.
        is_first: when False, weights are restored from the checkpoint
            directory of ``prev_seq_length`` before training.
        charmap: passed to the objective builder.
        inv_charmap: used to decode index sequences into strings for logging.
        prev_seq_length: sequence length whose checkpoint is restored when
            ``is_first`` is False.

    Raises:
        Exception: if ``DATA_DIR`` is empty.
        ValueError: if ``FLAGS.ARCH`` is not 'default' or 'class_conditioned'.
    """
    if len(DATA_DIR) == 0:
        raise Exception(
            'Please specify path to data directory in single_length_train.py!')
    # Reject unknown architectures up front; otherwise the graph tensors below
    # stay unbound and the function dies later with a confusing NameError.
    if FLAGS.ARCH not in ('default', 'class_conditioned'):
        raise ValueError('Unsupported FLAGS.ARCH: {!r}'.format(FLAGS.ARCH))

    # instance data handler
    data_handler = Runtime_data_handler(
        h5_path=FLAGS.H5_PATH,
        json_path=FLAGS.H5_PATH.replace('.h5', '.json'),
        seq_len=seq_length,
        use_var_len=False,
        batch_size=BATCH_SIZE,
        use_labels=False)

    # define placeholders (the class placeholder is fed on every run, but only
    # the 'class_conditioned' objective receives it)
    real_inputs_discrete = tf.placeholder(tf.int32,
                                          shape=[BATCH_SIZE, seq_length])
    real_classes_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
    global_step = tf.Variable(0, trainable=False)

    # build graph according to arch
    if FLAGS.ARCH == 'default':
        disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
            charmap, real_inputs_discrete, seq_length)
        # the default architecture has no class head
        disc_fake_class = None
        disc_real_class = None
        disc_on_inference_class = None

        visualize_text = generate_argmax_samples_and_gt_samples
    else:  # 'class_conditioned' (validated above)
        (disc_cost, gen_cost, fake_inputs, disc_fake, disc_fake_class,
         disc_real, disc_real_class, disc_on_inference,
         disc_on_inference_class, inference_op) = define_class_objective(
             charmap,
             real_inputs_discrete,
             real_classes_discrete,
             seq_length,
             num_classes=len(data_handler.class_dict))
        visualize_text = generate_argmax_samples_and_gt_samples_class

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost,
                                                       global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as session:

        # global_variables_initializer is the non-deprecated replacement for
        # initialize_all_variables
        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(
                prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True
            )  # global param, always load from curr session after finishing the first seq

        data_handler.epoch_start(seq_len=seq_length)

        for iteration in range(iterations):
            # Train critic
            for i in range(CRITIC_ITERS):
                _data, _labels = data_handler.get_batch()
                _disc_cost, _ = session.run(
                    [disc_cost, disc_train_op],
                    feed_dict={
                        real_inputs_discrete: _data,
                        real_classes_discrete: _labels
                    })

            # Train G
            for i in range(GEN_ITERS):
                _data, _labels = data_handler.get_batch()
                session.run(gen_train_op,
                            feed_dict={
                                real_inputs_discrete: _data,
                                real_classes_discrete: _labels
                            })

            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)

            # Summaries
            if iteration % 1000 == 999:
                _data, _labels = data_handler.get_batch()
                feed = {
                    real_inputs_discrete: _data,
                    real_classes_discrete: _labels
                }
                # Score this same batch so the "gt" scores logged below belong
                # to the strings decoded from _data.
                if FLAGS.ARCH == 'class_conditioned':
                    summary_str, real_scores, real_class_scores = session.run(
                        [merged, disc_real, disc_real_class], feed_dict=feed)
                else:
                    summary_str, real_scores = session.run(
                        [merged, disc_real], feed_dict=feed)
                    real_class_scores = None

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, _, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    fake_inputs,
                    disc_fake,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=True,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_fake_class)

                log_samples(fake_samples,
                            fake_scores,
                            iteration,
                            seq_length,
                            "train",
                            class_scores=fake_class_scores)
                log_samples(decode_indices_to_string(_data, inv_charmap),
                            real_scores,
                            iteration,
                            seq_length,
                            "gt",
                            class_scores=real_class_scores)

                # inference samples (visualize_text with feed_gt=False)
                test_samples, _, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    inference_op,
                    disc_on_inference,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=False,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_on_inference_class)
                if FLAGS.ARCH != 'class_conditioned':
                    log_samples(test_samples, fake_scores, iteration,
                                seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(
                    session,
                    model_and_data_serialization.get_internal_checkpoint_dir(
                        seq_length) + "/ckp")

        data_handler.epoch_end()

        saver.save(
            session,
            model_and_data_serialization.get_internal_checkpoint_dir(
                seq_length) + "/ckp")
        # the `with` statement closes the session on exit
Ejemplo n.º 2
0
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    """Train the GAN (Fisher GAN or WGAN) for one curriculum stage.

    Loads the dataset for ``seq_length``, builds the objective for
    ``FLAGS.GAN_TYPE``, optionally restores weights saved for
    ``prev_seq_length``, runs ``iterations`` rounds of critic/generator
    updates, logs averaged costs, summaries and samples every
    ``FLAGS.PRINT_ITERATION`` iterations and saves checkpoints for
    ``seq_length``.

    Args:
        iterations: number of outer training iterations.
        seq_length: sequence length of the placeholder and batches.
        is_first: when False, weights are restored from the checkpoint
            directory of ``prev_seq_length`` before training.
        charmap: used to build the objective and the batch generator.
        inv_charmap: used to decode index sequences into strings for logging.
        prev_seq_length: sequence length whose checkpoint is restored when
            ``is_first`` is False.

    Raises:
        Exception: if ``DATA_DIR`` is empty.
        ValueError: on the first critic step if ``FLAGS.GAN_TYPE`` is neither
            "fgan" nor "wgan" (case-insensitive).
    """
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    # other_ops holds extra ops; "alpha_optimizer_op" is run for fgan below
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, other_ops = define_objective(
        charmap, real_inputs_discrete, seq_length,
        gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL)

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(
        disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR)

    saver = tf.train.Saver(tf.trainable_variables())

    # Use JIT XLA compilation to speed up calculations
    config = tf.ConfigProto(
        log_device_placement=False, allow_soft_placement=True)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    with tf.Session(config=config) as session:

        # global_variables_initializer is the non-deprecated replacement for
        # initialize_all_variables
        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)
        # per-step values, averaged and reset every FLAGS.PRINT_ITERATION iterations
        _gen_cost_list = []
        _disc_cost_list = []
        _step_time_list = []

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)

                if FLAGS.GAN_TYPE.lower() == "fgan":
                    _disc_cost, _, _ = session.run(
                        [disc_cost, disc_train_op,
                         other_ops["alpha_optimizer_op"]],
                        feed_dict={real_inputs_discrete: _data})

                elif FLAGS.GAN_TYPE.lower() == "wgan":
                    _disc_cost, _ = session.run(
                        [disc_cost, disc_train_op],
                        feed_dict={real_inputs_discrete: _data})

                else:
                    raise ValueError(
                        "Appropriate gan type not selected: {}".format(FLAGS.GAN_TYPE))
                _disc_cost_list.append(_disc_cost)

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                # in Fisher GAN, paper measures convergence by gen_cost instead of disc_cost
                # To measure convergence, gen_cost should start at a positive number and decrease
                # to zero. The lower, the better.
                _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={real_inputs_discrete: _data})
                _gen_cost_list.append(_gen_cost)

            _step_time_list.append(time.time() - start_time)

            if FLAGS.PRINT_EVERY_STEP:
                print("iteration %s/%s" % (iteration, iterations))
                # "{}" is a str.format placeholder; using `%` here raised TypeError
                print("disc cost {}".format(_disc_cost))
                print("gen cost {}".format(_gen_cost))
                print("total step time {}".format(time.time() - start_time))

            # Summaries
            if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION - 1:
                _data = next(gen)
                # Score this same batch so the "gt" scores logged below belong
                # to the strings decoded from _data.
                summary_str, real_scores = session.run(
                    [merged, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                )

                tf.logging.warn("iteration %s/%s" % (iteration, iterations))
                tf.logging.warn("disc cost {} gen cost {} average step time {}".format(
                    np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list)))
                _gen_cost_list, _disc_cost_list, _step_time_list = [], [], []

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, fake_inputs, disc_fake, gen, real_inputs_discrete, feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length,
                            "gt")
                # inference samples (feed_gt=False)
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session,
                    inv_charmap,
                    inference_op,
                    disc_on_inference,
                    gen,
                    real_inputs_discrete,
                    feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
        # the `with` statement closes the session on exit
Ejemplo n.º 3
0
def run(iterations, seq_length, is_first, wordmap, inv_wordmap, prev_seq_length, meta_iter, max_meta_iter):
    '''
    Performs a single run of a single meta iter of curriculum training and,
    when FLAGS.WORDVECS is None, a reallocation pass at the end.

    iterations = the number of minibatches (outer training iterations)
    seq_length = the length of the sequences to create
    is_first = True for the first stage of the curriculum. When False the
               weights are restored from the ckpt of prev_seq_length; when
               True they are restored from 'pretrain' only if PRETRAIN is set
    wordmap = overwritten by the wordmap returned from load_dataset before use
    inv_wordmap = overwritten by the inv_wordmap returned from load_dataset;
                  used for decoding the samples
    prev_seq_length = for choosing the ckpt to load from
    meta_iter = the round of the seq_length training
    max_meta_iter = used to make sure the tensorboard graphs log using the correct global step

    The TF session is always closed before returning (also when an error is
    raised), so the caller can reset the graph for the next stage.
    '''
    # make sure tf does not take up all the GPU mem (model size is not that large so there is unlikely to be fragmentation problems)
    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True

    # first one so copy from the initial location file.
    # NOTE(review): `and` binds tighter than `or`, so this copy also runs
    # whenever FLAGS.WORDVECS is set, even if word-0.locations does not exist
    # (copyfile would then raise) -- confirm this is intended.
    if (is_first and os.path.isfile('locations/word-0.locations')) or FLAGS.WORDVECS is not None:
        copyfile('locations/word-0.locations', 'locations/word-%d.locations' % seq_length)

    # load the lines from the dataset along with the current wordmap and inv_wordmap
    lines, wordmap, inv_wordmap = load_dataset(seq_length=seq_length, n_examples=FLAGS.MAX_N_EXAMPLES)

    # seed word-0.locations from this stage's file if it does not exist yet
    # (presumably written by load_dataset -- TODO confirm)
    if not os.path.isfile('locations/word-0.locations'):
        copyfile('locations/word-%d.locations' % seq_length, 'locations/word-0.locations')

    # placeholders for the input from the dataset; fed from _data[0] and _data[1]
    real_inputs_discrete = [tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length]),
                            tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])]

    # global step (note: initialised to `iterations`, not 0)
    global_step = tf.Variable(iterations, trainable=False, name='global-step')

    # the two components of the '<naw>' entry in the map
    naw_r, naw_c = wordmap['<naw>'][0], wordmap['<naw>'][1]

    session = tf.Session(config=config)
    # Close the session even if graph construction, training or reallocation
    # raises, so a failed stage does not leak it (and its GPU memory).
    try:
        # start the generator to get minibatches from the dataset
        gen = inf_train_gen(lines, wordmap, seq_length)

        # define the network
        optim_costs, discs, ops, log_costs, _ = define_objective(
            session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, gen)

        disc_cost, gen_cost = optim_costs
        disc_fake, disc_real, disc_on_inference = discs
        fake_inputs, inference_op, realloc_op = ops
        d_cost_fake, d_cost_real = log_costs

        # get the summaries, optimizers, and saver
        merged, train_writer = define_summaries(d_cost_real, d_cost_fake, gen_cost, disc_cost)
        disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)
        saver = tf.train.Saver(tf.trainable_variables())

        with session.as_default():
            # write the graph and init vars
            train_writer.add_graph(session.graph)
            session.run(tf.global_variables_initializer())

            # restore weights: from 'pretrain' on the first stage (if enabled),
            # otherwise from the previous sequence length's ckpt
            if is_first and PRETRAIN:
                optimistic_restore(session, tf.train.latest_checkpoint('pretrain', 'checkpoint'))
                restore_config.set_restore_dir(load_from_curr_session=True)
            elif not is_first:
                internal_checkpoint_dir = get_internal_checkpoint_dir(prev_seq_length)
                optimistic_restore(session, tf.train.latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
                restore_config.set_restore_dir(load_from_curr_session=True)

            # progress bar
            with tqdm(total=iterations, ncols=150) as pbar:

                # train loop
                for iteration in range(iterations):
                    # Train critic first
                    for _ in range(CRITIC_ITERS):
                        _data = next(gen)
                        _disc_cost, _, _ = session.run(
                            [disc_cost, disc_train_op, disc_real],
                            feed_dict={real_inputs_discrete[0]: _data[0],
                                       real_inputs_discrete[1]: _data[1]})
                    # Train generator
                    for _ in range(GEN_ITERS):
                        _data = next(gen)
                        _, _g_cost = session.run(
                            [gen_train_op, gen_cost],
                            feed_dict={real_inputs_discrete[0]: _data[0],
                                       real_inputs_discrete[1]: _data[1]})

                    # update progress bar with costs and inc its counter by 1
                    pbar.set_description("disc cost %f \t gen cost %f \t" % (_disc_cost, _g_cost))
                    pbar.update(1)

                    # write summaries
                    if iteration % 100 == 99:
                        _data = next(gen)
                        summary_str = session.run(
                            merged,
                            feed_dict={real_inputs_discrete[0]: _data[0],
                                       real_inputs_discrete[1]: _data[1]})

                        # step offset combining meta_iter, iteration and seq_length
                        # so successive stages do not overwrite each other's points
                        log_step = meta_iter * iterations + iteration + ((seq_length - 1) * iterations * max_meta_iter)

                        train_writer.add_summary(summary_str, global_step=log_step)

                        # generate and log output from training
                        fake_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                            session, inv_wordmap,
                            fake_inputs,
                            disc_fake,
                            gen,
                            real_inputs_discrete,
                            feed_gt=True, method='argmax')
                        log_samples(fake_samples, fake_scores, log_step, seq_length, "gen-w-gt")

                        # generate and log output from inference
                        test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                            session, inv_wordmap,
                            inference_op,
                            disc_on_inference,
                            gen,
                            real_inputs_discrete,
                            feed_gt=False, method='argmax')
                        log_samples(test_samples, fake_scores, log_step, seq_length, "gen-no-gt")

            if FLAGS.WORDVECS is None:
                # copy current location file to next sequence_length as we are going to gen sequences of length seq_len +1 for realloc
                copyfile('locations/word-%d.locations' % seq_length, 'locations/word-%d.locations' % (seq_length + 1))
                # NOTE(review): always true for non-negative seq_length
                if seq_length >= 0:
                    # get the lines, note that wordmap and inv_wordmap stay the same.
                    # NOTE(review): seq_length=48 is hard-coded here -- confirm intended
                    lines, _, _ = load_dataset(seq_length=48, n_examples=FLAGS.MAX_N_EXAMPLES,
                                               no_write=True)

                    # start generator and perform the reallocation
                    gen = inf_realloc_gen(lines, wordmap, seq_length + 1)
                    perf_reallocate(1000000, session, inv_wordmap, realloc_op,
                                    gen, seq_length, real_inputs_discrete, naw_r, naw_c)

                    # realloc creates the new location file in seq_length +1 so we move it back if we are not at the last meta_iter
                    if meta_iter != max_meta_iter - 1:
                        copyfile('locations/word-%d.locations' % (seq_length + 1), 'locations/word-%d.locations' % seq_length)
                        os.remove('locations/word-%d.locations.string' % (seq_length + 1))

            # save the ckpt; the session is closed in `finally` because the graph needs to be reset
            saver.save(session, get_internal_checkpoint_dir(seq_length) + "/ckp")
    finally:
        session.close()
Ejemplo n.º 4
0
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    """Train the GAN for one curriculum stage at a fixed sequence length.

    Loads the dataset for ``seq_length``, builds the objective, optionally
    restores weights saved for ``prev_seq_length``, runs ``iterations`` rounds
    of critic/generator updates, writes summaries and sample logs every 100
    iterations and saves checkpoints for ``seq_length``.

    Args:
        iterations: number of outer training iterations.
        seq_length: sequence length of the placeholder and batches.
        is_first: when False, weights are restored from the checkpoint
            directory of ``prev_seq_length`` before training.
        charmap: used to build the objective and the batch generator.
        inv_charmap: used to decode index sequences into strings for logging.
        prev_seq_length: sequence length whose checkpoint is restored when
            ``is_first`` is False.

    Raises:
        Exception: if ``DATA_DIR`` is empty.
    """
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
                                                            n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
        charmap, real_inputs_discrete, seq_length)
    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as session:

        # global_variables_initializer is the non-deprecated replacement for
        # initialize_all_variables
        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                                                            latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)

        for iteration in range(iterations):
            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)
                _disc_cost, _ = session.run(
                    [disc_cost, disc_train_op],
                    feed_dict={real_inputs_discrete: _data}
                )

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                session.run(gen_train_op, feed_dict={real_inputs_discrete: _data})

            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)

            # Summaries
            if iteration % 100 == 99:
                _data = next(gen)
                # Score this same batch so the "gt" scores logged below belong
                # to the strings decoded from _data.
                summary_str, real_scores = session.run(
                    [merged, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                )

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap,
                                                                                      fake_inputs,
                                                                                      disc_fake,
                                                                                      gen,
                                                                                      real_inputs_discrete,
                                                                                      feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length,
                            "gt")
                # inference samples (feed_gt=False)
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap,
                                                                                      inference_op,
                                                                                      disc_on_inference,
                                                                                      gen,
                                                                                      real_inputs_discrete,
                                                                                      feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
        # the `with` statement closes the session on exit