def run(iterations, seq_length, is_first, prev_seq_length):
    restore_path = os.path.join(cf.SAVE_CHECKPOINTSLOG, str(prev_seq_length))
    save_path = os.path.join(cf.SAVE_CHECKPOINTSLOG, str(seq_length))
    real_inputs_discrete = tf.placeholder(tf.float32,
                                          shape=[cf.BATCH_SIZE, seq_length, 2])
    y = tf.placeholder(tf.float32, shape=[None, cf.Y_SIZE])
    disc_cost, gen_cost = define_objective(real_inputs_discrete, y, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost)
    saver = tf.train.Saver(tf.trainable_variables())
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = restore_path
            saver.restore(
                session,
                tf.train.latest_checkpoint(internal_checkpoint_dir,
                                           "checkpoint"))

        for iteration in range(iterations):

            # Train critic
            for i in range(cf.CRITIC_ITERS):
                _data, _y = get_data(cf.BATCH_SIZE,
                                     seq_length)  # get_data still needs to be written
                _disc_cost, _ = session.run([disc_cost, disc_train_op],
                                            feed_dict={
                                                real_inputs_discrete: _data,
                                                y: _y
                                            })
            # Train G
            for i in range(cf.GEN_ITERS):
                _gen_cost, _ = session.run([gen_cost, gen_train_op],
                                           feed_dict={y: _y})

            if iteration % 100 == 0:
                print("iteration %s/%s" % (iteration, iterations))
                print("disc cost %f" % _disc_cost)
                print("gen cost %f" % _gen_cost)
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            if iteration % cf.SAVE_CHECKPOINTS_EVERY == cf.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(session,
                           os.path.join(save_path, 'my_model.ckpt'),
                           global_step=iteration)
        session.close()
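The comment in the critic loop above notes that get_data still has to be written. Below is a minimal sketch of what such a helper could look like; it is an assumption rather than part of the original project. Only the returned shapes ([batch_size, seq_length, 2] and [batch_size, cf.Y_SIZE]) are dictated by the placeholders above; the in-memory random arrays stand in for whatever data source the real code would use.

import numpy as np

# Hypothetical in-memory dataset; the real project would load these from disk.
_ALL_X = np.random.randn(10000, 128, 2).astype(np.float32)      # [num_examples, max_len, 2]
_ALL_Y = np.random.randn(10000, cf.Y_SIZE).astype(np.float32)   # [num_examples, Y_SIZE]

def get_data(batch_size, seq_length):
    """Return one random minibatch of (sequences, conditioning vectors)."""
    idx = np.random.randint(0, _ALL_X.shape[0], size=batch_size)
    x = _ALL_X[idx, :seq_length, :]   # crop to the current curriculum length
    y = _ALL_Y[idx]
    return x, y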
Example #2
def run(iterations, seq_length, is_first, charmap, inv_charmap,
        prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception(
            'Please specify path to data directory in single_length_train.py!')

    # lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
    #                                                         n_examples=FLAGS.MAX_N_EXAMPLES)

    # instance data handler
    data_handler = Runtime_data_handler(
        h5_path=FLAGS.H5_PATH,
        json_path=FLAGS.H5_PATH.replace('.h5', '.json'),
        seq_len=seq_length,
        # max_len=self.seq_len,
        # teacher_helping_mode='th_extended',
        use_var_len=False,
        batch_size=BATCH_SIZE,
        use_labels=False)

    #define placeholders
    real_inputs_discrete = tf.placeholder(tf.int32,
                                          shape=[BATCH_SIZE, seq_length])
    real_classes_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
    global_step = tf.Variable(0, trainable=False)

    # build graph according to arch
    if FLAGS.ARCH == 'default':
        disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
            charmap, real_inputs_discrete, seq_length)
        disc_fake_class = None
        disc_real_class = None
        disc_on_inference_class = None

        visualize_text = generate_argmax_samples_and_gt_samples
    elif FLAGS.ARCH == 'class_conditioned':
        disc_cost, gen_cost, fake_inputs, disc_fake, disc_fake_class, disc_real, disc_real_class,\
        disc_on_inference, disc_on_inference_class, inference_op = define_class_objective(charmap,
                                                                                            real_inputs_discrete,
                                                                                            real_classes_discrete,
                                                                                            seq_length,
                                                                                            num_classes=len(data_handler.class_dict))
        visualize_text = generate_argmax_samples_and_gt_samples_class

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost,
                                                       global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as session:

        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(
                prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True
            )  # global param, always load from curr session after finishing the first seq

        # gen = inf_train_gen(lines, charmap)
        data_handler.epoch_start(seq_len=seq_length)

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data, _labels = data_handler.get_batch()
                if FLAGS.ARCH == 'class_conditioned':
                    _disc_cost, _, real_scores, real_class_scores = session.run(
                        [disc_cost, disc_train_op, disc_real, disc_real_class],
                        feed_dict={
                            real_inputs_discrete: _data,
                            real_classes_discrete: _labels
                        })
                else:
                    _disc_cost, _, real_scores = session.run(
                        [disc_cost, disc_train_op, disc_real],
                        feed_dict={
                            real_inputs_discrete: _data,
                            real_classes_discrete: _labels
                        })
                    real_class_scores = None

            # Train G
            for i in range(GEN_ITERS):
                # _data = next(gen)
                _data, _labels = data_handler.get_batch()
                _ = session.run(gen_train_op,
                                feed_dict={
                                    real_inputs_discrete: _data,
                                    real_classes_discrete: _labels
                                })

            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)

            # Summaries
            if iteration % 1000 == 999:
                # if iteration % 100 == 99:
                _data, _labels = data_handler.get_batch()
                summary_str = session.run(merged,
                                          feed_dict={
                                              real_inputs_discrete: _data,
                                              real_classes_discrete: _labels
                                          })

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilites, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    fake_inputs,
                    disc_fake,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=True,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_fake_class)

                log_samples(fake_samples,
                            fake_scores,
                            iteration,
                            seq_length,
                            "train",
                            class_scores=fake_class_scores)
                log_samples(decode_indices_to_string(_data, inv_charmap),
                            real_scores,
                            iteration,
                            seq_length,
                            "gt",
                            class_scores=real_class_scores)

                # inference
                test_samples, _, fake_scores, fake_class_scores = visualize_text(
                    session,
                    inv_charmap,
                    inference_op,
                    disc_on_inference,
                    data_handler,
                    real_inputs_discrete,
                    real_classes_discrete,
                    feed_gt=False,
                    iteration=iteration,
                    seq_length=seq_length,
                    disc_class=disc_on_inference_class)
                # disc_on_inference, inference_op
                if FLAGS.ARCH != 'class_conditioned':
                    log_samples(test_samples, fake_scores, iteration,
                                seq_length, "test")

            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(
                    session,
                    model_and_data_serialization.get_internal_checkpoint_dir(
                        seq_length) + "/ckp")

        data_handler.epoch_end()

        saver.save(
            session,
            model_and_data_serialization.get_internal_checkpoint_dir(
                seq_length) + "/ckp")
        session.close()
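The is_first and prev_seq_length arguments suggest that run() is called once per curriculum stage with a growing sequence length, restoring the previous stage's checkpoint each time. A hypothetical driver loop is sketched below; the stage lengths and iteration count are illustrative, and charmap / inv_charmap are assumed to come from the project's dataset-loading code.

# Hypothetical curriculum driver (not part of the original code).
if __name__ == '__main__':
    stages = [8, 16, 32, 64]                 # sequence lengths to train, in order
    prev_len = 0
    for stage_idx, seq_len in enumerate(stages):
        tf.reset_default_graph()             # each stage rebuilds the graph for its length
        run(iterations=10000,
            seq_length=seq_len,
            is_first=(stage_idx == 0),       # only the first stage starts from scratch
            charmap=charmap,
            inv_charmap=inv_charmap,
            prev_seq_length=prev_len)        # later stages restore the previous checkpoint
        prev_len = seq_len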
Example #3
def test(beam_length, seq_len_to_test, batch_size, n):

    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True

    seq_length = seq_len_to_test

    _, wordmap, inv_wordmap = load_dataset(seq_length=0, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = [tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length]), \
                              tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])]

    # global step
    global_step = tf.Variable(0, trainable=False, name='global-step')

    # index of <naw> in the map
    naw_r, naw_c = wordmap['<naw>'][0], wordmap['<naw>'][1]

    session = tf.Session(config=config)

    _, _, ops, _, _ = define_objective(session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, None)

    fake, inference_op, _ = ops

    with session.as_default():

        optimistic_restore(session, tf.train.latest_checkpoint('pretrain/seq-%d' % n, 'checkpoint'))
        restore_config.set_restore_dir(load_from_curr_session=True)

        #inference_op[0] = tf.reshape(inference_op[0], [-1, len(inv_wordmap)])
        #inference_op[1] = tf.reshape(inference_op[1], [-1, len(inv_wordmap)])

        logits = []
        for b in range(BATCH_SIZE):
            buff = []
            for t in range(seq_len_to_test):

                #tmp_col = tf.reshape(tf.tile( tf.reshape(fake[1][b][t], [-1]), [len(inv_wordmap)]), \
                #                    [len(inv_wordmap), len(inv_wordmap)])
                tmp_col = tf.reshape(tf.tile( tf.reshape(inference_op[1][b][t], [-1]), [len(inv_wordmap)]), \
                                     [len(inv_wordmap), len(inv_wordmap)])
                tmp_col = tf.nn.softmax(tmp_col)

                #tmp_row = tf.reshape(tf.exp(fake[0][b][t]), [-1,1])
                tmp_row = tf.reshape(tf.nn.softmax(inference_op[0][b][t]), [-1,1])
                #tmp_row = tf.reshape(inference_op[0][b][t], [-1,1])

                tmp = tmp_col + tmp_row
                #tmp = tf.matmul(tf.reshape(inference_op[0][b][t], [-1,1]), tf.reshape(inference_op[1][b][t], [1,-1]))
                tmp = tf.reshape(tmp, [1,-1])
                #tmp = tf.concat([tmp, tf.zeros([1,1], dtype=tf.float32)], -1)

                buff.append(tmp)
            logits.append(tf.reshape(buff, [-1, len(inv_wordmap)**2]))

        logits = tf.reshape(logits, [BATCH_SIZE, seq_len_to_test, len(inv_wordmap)**2])
        logits = tf.transpose(logits, [1,0,2])
        #logits = tf.nn.softmax(logits)
        #_logits = tf.exp(logits)
        #_logits = tf.nn.softmax(logits)
        _logits = logits
        #_logits = tf.log(logits)
        #print(logits)

        length = tf.multiply(tf.ones([BATCH_SIZE], dtype=tf.int32),tf.constant(seq_len_to_test,dtype=tf.int32))

        #length = tf.multiply(tf.ones([BATCH_SIZE], dtype=tf.int32),tf.constant(10,dtype=tf.int32))
        print(session.run(length))

        #res = tf.nn.ctc_beam_search_decoder(_logits, length, beam_width=10, merge_repeated=False)
        res = tf.nn.ctc_greedy_decoder(_logits, length, merge_repeated=False)

        paths = tf.sparse_tensor_to_dense(res[0][0], default_value=-1)   # Shape: [batch_size, max_sequence_len]
        for batch in range(BATCH_SIZE):


            infer, logs, logit, i_op = session.run([paths, res[1], _logits, inference_op[0]])

            #for x in range(1):
            #   for y in range(20):
            #       for z in range(62501):
            #           assert (logit[x][y][z] != 0)

            print(logit)
            for i in range(len(infer)):
                for j in range(len(infer[0])):
                    ind = infer[i][j]
                    if infer[i][j] == -1:
                        break

                    row = ind // len(inv_wordmap)
                    col = ind % len(inv_wordmap)
                    print( inv_wordmap[row][col] , end=' ')

                print('')

            #print(infer)
            #infer_r, infer_c = infer



    session.close()
Example #4
def run(iterations, seq_length, is_first, wordmap, inv_wordmap, prev_seq_length, meta_iter, max_meta_iter):
    '''
    Performs a single run of a single meta iter of curriculum training
    also performs reallocate at the end

    iterations = the number of minibatches
    seq_length = the length of the sequences to create
    is_first = if we need to load from ckpt
    wordmap = for getting the naw from
    inv_wordmap = for decoding the samples
    prev_seq_length = for choosing the ckpt to load from
    meta_iter = the round of the seq_length training
    max_meta_iter = used to make sure the tensorboard graphs log using the correct global step
    '''
    # make sure tf does not take up all the GPU mem (model size is not that large so there is unlikely to be fragmentation problems)
    config = tf.ConfigProto(log_device_placement=False)
    config.gpu_options.allow_growth = True

    # first one so copy from the initial location file
    if (is_first and os.path.isfile('locations/word-0.locations')) or FLAGS.WORDVECS is not None:
        copyfile('locations/word-0.locations', 'locations/word-%d.locations' % seq_length)


    # load the lines from the dataset along with the current wordmap and inv_wordmap
    lines, wordmap, inv_wordmap = load_dataset(seq_length=seq_length, n_examples=FLAGS.MAX_N_EXAMPLES)

    if not os.path.isfile('locations/word-0.locations'):
        copyfile('locations/word-%d.locations' % seq_length, 'locations/word-0.locations')


    # placeholders for the input from the datset
    real_inputs_discrete = [tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length]), \
                              tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])]

    # global step
    global_step = tf.Variable(iterations, trainable=False, name='global-step')

    # index of <naw> in the map
    naw_r, naw_c = wordmap['<naw>'][0], wordmap['<naw>'][1]

    session = tf.Session(config=config)

    # start the generator to get minibatches from the dataset
    gen = inf_train_gen(lines, wordmap, seq_length)

    # define the network
    #disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, realloc_op, d_cost_fake, d_cost_real = define_objective(session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, gen)
    # get the summaries, optimizers, and saver

    optim_costs, discs, ops, log_costs, embeddings = define_objective(session, wordmap, real_inputs_discrete, seq_length, naw_r, naw_c, gen)

    #embed_config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
    #r_embedding = embed_config.embeddings.add()
    #r_embedding.tensor_name = embeddings.name
    #r_embedding.metadata_path = '../../loations/metadata.tsv'

    disc_cost, gen_cost = optim_costs
    disc_fake, disc_real, disc_on_inference = discs
    fake_inputs, inference_op, realloc_op = ops
    d_cost_fake, d_cost_real = log_costs

    merged, train_writer = define_summaries(d_cost_real, d_cost_fake, gen_cost, disc_cost)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)
    saver = tf.train.Saver(tf.trainable_variables())


    with session.as_default():
        # write the graph and init vars
        train_writer.add_graph(session.graph)
        session.run(tf.global_variables_initializer())

        # if not is_first, we need to load from ckpt
        if is_first and PRETRAIN:
            optimistic_restore(session, tf.train.latest_checkpoint('pretrain', 'checkpoint'))
            restore_config.set_restore_dir(load_from_curr_session=True)
        elif not is_first:
            internal_checkpoint_dir = get_internal_checkpoint_dir(prev_seq_length)
            optimistic_restore(session, tf.train.latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
            restore_config.set_restore_dir(load_from_curr_session=True)



        # cool progress bar
        with tqdm(total=iterations, ncols=150) as pbar:

            # train loop
            for iteration in range(iterations):
                # Train critic first
                for _ in range(CRITIC_ITERS):
                    _data = next(gen)
                    _disc_cost, _, _ = session.run([disc_cost, disc_train_op, disc_real], \
                                                   feed_dict={real_inputs_discrete[0]: _data[0], \
                                                              real_inputs_discrete[1]: _data[1]})
                # Train generator
                for _ in range(GEN_ITERS):
                    _data = next(gen)
                    _, _g_cost = session.run([gen_train_op, gen_cost], feed_dict={real_inputs_discrete[0]:_data[0], \
                                                                                  real_inputs_discrete[1]: _data[1]})


                # update progress bar with costs and inc its counter by 1
                pbar.set_description("disc cost %f \t gen cost %f \t"% (_disc_cost, _g_cost))
                pbar.update(1)

                # write summaries
                if iteration % 100 == 99:
                    _data = next(gen)
                    summary_str = session.run(merged, feed_dict={real_inputs_discrete[0]:_data[0], \
                                                                 real_inputs_discrete[1]:_data[1]})

                    iii = meta_iter*iterations + iteration + ((seq_length-1)*iterations*max_meta_iter)

                    train_writer.add_summary(summary_str, \
                                             global_step=iii)

                    # generate and log output from training
                    fake_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_wordmap,
                                                                                          fake_inputs,
                                                                                          disc_fake,
                                                                                          gen,
                                                                                          real_inputs_discrete,
                                                                                          feed_gt=True, method='argmax')
                    log_samples(fake_samples, fake_scores, iii, seq_length, "gen-w-gt")

                    # generate and log output from inference
                    test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_wordmap,
                                                                                        inference_op,
                                                                                          disc_on_inference,
                                                                                          gen,
                                                                                          real_inputs_discrete,
                                                                                          feed_gt=False, method='argmax')
                    log_samples(test_samples, fake_scores, iii, seq_length, "gen-no-gt")


        #*********************************************************************************
        #if seq_length >= 40:
        # copy current location file to next sequence_length as we are going to gen sequences of length seq_len +1 for realloc
        if FLAGS.WORDVECS is None:

            copyfile('locations/word-%d.locations' % (seq_length), 'locations/word-%d.locations' % (seq_length+1))
            if seq_length >= 0:
                # get the lines, note that wordmap and inv_wordmap stay the same
                lines, _, _ = load_dataset(seq_length=48, n_examples=FLAGS.MAX_N_EXAMPLES,\
                                           no_write=True)

                # start generator and perform the reallocation
                gen = inf_realloc_gen(lines, wordmap, seq_length+1)
                perf_reallocate(int(1000000), session, inv_wordmap, realloc_op, \
                                gen, seq_length, real_inputs_discrete, naw_r, naw_c)

                #  realloc creates the new location file in seq_length +1 so we move it back if we are not at the last meta_iter
                if meta_iter != max_meta_iter -1:
                    copyfile('locations/word-%d.locations' % (seq_length+1), 'locations/word-%d.locations' % (seq_length))
                    #os.remove('locations/word-%d.locations' % (seq_length+1))
                    os.remove('locations/word-%d.locations.string' % (seq_length+1))


        # save the ckpt and close the session because we need to reset the graph
        saver.save(session, get_internal_checkpoint_dir(seq_length) + "/ckp")
        session.close()
Example #5
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, other_ops = define_objective(charmap,real_inputs_discrete, seq_length,
        gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL)


    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(
        disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR)

    saver = tf.train.Saver(tf.trainable_variables())

    # Use JIT XLA compilation to speed up calculations
    config=tf.ConfigProto(
        log_device_placement=False, allow_soft_placement=True)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    with tf.Session(config=config) as session:

        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)
        _gen_cost_list = []
        _disc_cost_list = []
        _step_time_list = []

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)

                if FLAGS.GAN_TYPE.lower() == "fgan":
                    _disc_cost, _, real_scores, _ = session.run(
                    [disc_cost, disc_train_op, disc_real,
                        other_ops["alpha_optimizer_op"]],
                    feed_dict={real_inputs_discrete: _data}
                    )

                elif FLAGS.GAN_TYPE.lower() == "wgan":
                    _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                    )

                else:
                    raise ValueError(
                        "Appropriate gan type not selected: {}".format(FLAGS.GAN_TYPE))
                _disc_cost_list.append(_disc_cost)



            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                # in Fisher GAN, paper measures convergence by gen_cost instead of disc_cost
                # To measure convergence, gen_cost should start at a positive number and decrease
                # to zero. The lower, the better.
                _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={real_inputs_discrete: _data})
                _gen_cost_list.append(_gen_cost)

            _step_time_list.append(time.time() - start_time)

            if FLAGS.PRINT_EVERY_STEP:
                print("iteration %s/%s"%(iteration, iterations))
                print("disc cost {}"%_disc_cost)
                print("gen cost {}".format(_gen_cost))
                print("total step time {}".format(time.time() - start_time))


            # Summaries
            if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION-1:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )

                tf.logging.warn("iteration %s/%s"%(iteration, iterations))
                tf.logging.warn("disc cost {} gen cost {} average step time {}".format(
                    np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list)))
                _gen_cost_list, _disc_cost_list, _step_time_list = [], [], []

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilites, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap, fake_inputs, disc_fake, gen, real_inputs_discrete,feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length,
                             "gt")
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session,
                  inv_charmap,
                  inference_op,
                  disc_on_inference,
                  gen,
                  real_inputs_discrete,
                  feed_gt=False)
                # disc_on_inference, inference_op
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")


            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY-1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
        session.close()
Example #6
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
                                                            n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(charmap,
                                                                                                            real_inputs_discrete,
                                                                                                            seq_length)
    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)

    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as session:

        session.run(tf.global_variables_initializer())
        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(session,
                                                            latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            restore_config.set_restore_dir(
                load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

        gen = inf_train_gen(lines, charmap)

        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)
                _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                )

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                _ = session.run(gen_train_op, feed_dict={real_inputs_discrete: _data})

            print("iteration %s/%s"%(iteration, iterations))
            print("disc cost %f"%_disc_cost)

            # Summaries
            if iteration % 100 == 99:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )

                train_writer.add_summary(summary_str, global_step=iteration)
                fake_samples, samples_real_probabilites, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap,
                                                                                                              fake_inputs,
                                                                                                              disc_fake,
                                                                                                              gen,
                                                                                                              real_inputs_discrete,
                                                                                                              feed_gt=True)

                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length,
                             "gt")
                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap,
                                                                                      inference_op,
                                                                                      disc_on_inference,
                                                                                      gen,
                                                                                      real_inputs_discrete,
                                                                                      feed_gt=False)
                # disc_on_inference, inference_op
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")


            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY-1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")
        session.close()
Example #7
def evaluate(seq_length, N, charmap, inv_charmap):

    lines, _, _ = model_and_data_serialization.load_dataset(
        seq_length=seq_length,
        b_charmap=False,
        b_inv_charmap=False,
        n_examples=FLAGS.MAX_N_EXAMPLES,
        dataset='heldout')

    real_inputs_discrete = tf.placeholder(tf.int32,
                                          shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
        charmap, real_inputs_discrete, seq_length)

    # train_pred -> run session
    train_pred_all = np.zeros([N, BATCH_SIZE, seq_length, train_pred.shape[2]],
                              dtype=np.float32)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # load checkpoints
    internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(
        0)
    model_and_data_serialization.optimistic_restore(
        sess, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
    restore_config.set_restore_dir(
        load_from_curr_session=True
    )  # global param, always load from curr session after finishing the first seq

    BPC_list = []

    for start_line in range(0, len(lines) - BATCH_SIZE + 1, BATCH_SIZE):
        t0 = time.time()
        _data = np.array([[charmap[c] for c in l]
                          for l in lines[start_line:start_line + BATCH_SIZE]])

        # rand N noise vectors and for each one - calculate train_pred.
        for i in range(N):
            train_pred_i = sess.run(train_pred_for_eval,
                                    feed_dict={real_inputs_discrete: _data})
            train_pred_all[i, :, :, :] = train_pred_i

        # take average on each time step (first dimension)
        train_pred_average = np.mean(train_pred_all, axis=0)

        # compute BPC (bits per character)
        train_pred_average_2d = train_pred_average.reshape([
            train_pred_average.shape[0] * train_pred_average.shape[1],
            train_pred_average.shape[2]
        ])
        real_data = _data.reshape([_data.shape[0] * _data.shape[1]])

        BPC = 0

        epsilon = 1e-20
        for i in range(real_data.shape[0]):
            BPC -= np.log2(train_pred_average_2d[i, real_data[i]] + epsilon)

        BPC /= real_data.shape[0]
        print("BPC of start_line %d/%d = %.2f" % (start_line, len(lines), BPC))
        print("t_iter = %.2f" % (time.time() - t0))
        BPC_list.append(BPC)
        np.save('BPC_list_temp.npy', BPC_list)

    BPC_final = np.mean(BPC_list)
    print("BPC_final = %.2f\n" % (BPC_final))
    np.save('BPC_list.npy', BPC_list)
    np.save('BPC_final.npy', BPC_final)
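The loop above accumulates BPC one character position at a time. For reference, an equivalent vectorized form is sketched below; it assumes the same train_pred_average_2d and real_data arrays and is a rewrite for clarity, not code from the original project.

# Vectorized equivalent of the per-character BPC loop above (sketch).
epsilon = 1e-20
true_char_probs = train_pred_average_2d[np.arange(real_data.shape[0]), real_data]
BPC = -np.mean(np.log2(true_char_probs + epsilon))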
Example #8
def evaluate(seq_length, N, charmap):

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False,
                                                            n_examples=FLAGS.MAX_N_EXAMPLES, dataset='heldout')

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])

    global_step = tf.Variable(0, trainable=False)
    disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(charmap,
                                                                                                            real_inputs_discrete,
                                                                                                            seq_length)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # load checkpoints
    internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(0)
    model_and_data_serialization.optimistic_restore(sess,
                                                    latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
    restore_config.set_restore_dir(
        load_from_curr_session=True)  # global param, always load from curr session after finishing the first seq

    cnt = 0
    Linf_norm_values = []
    N_graph_values = []

    print("N = %d\n" % (N))
    # train_pred -> run session
    train_pred_all = np.zeros([N, BATCH_SIZE, seq_length, train_pred.shape[2]],
                            dtype=np.float32)

    for start_line in range(1):
        _data = np.array([[charmap[c] for c in l] for l in lines[start_line:start_line + BATCH_SIZE]])

        # rand N noise vectors and for each one - calculate train_pred.
        for i in range(N):
            train_pred_i = sess.run(train_pred_for_eval, feed_dict={real_inputs_discrete: _data})
            train_pred_all[i,:,:,:] = train_pred_i

    N_run = range(10, N+10, 10)
    train_pred_average_prev = np.mean(train_pred_all[0:N_run[0]], axis=0)

    for N_curr in N_run[1:]:
        N_graph_values.append(N_curr)
        train_pred_average = np.mean(train_pred_all[0:N_curr], axis=0)
        train_pred_average_diff = train_pred_average - train_pred_average_prev
        Linf_norm_diff = np.linalg.norm(train_pred_average_diff, axis=2, ord=np.inf)
        Linf_norm_values.append(np.mean(Linf_norm_diff))
        cnt += 1
        train_pred_average_prev = train_pred_average

    plt.figure()
    plt.plot(N_graph_values, Linf_norm_values)
    plt.title('Approximated error between empiric and real mean (Linf norm)', size=30)
    plt.xlabel('Number of random noise vectors (N)', size=16)
    plt.ylabel('Error (L_inf norm)', size=16)
Example #9
def evaluate(EVAL_FLAGS):

    # load dict & params
    _, charmap, inv_charmap = model_and_data_serialization.load_dataset(
        seq_length=32, b_lines=False)
    seq_length = EVAL_FLAGS.seq_len
    N = EVAL_FLAGS.num_samples

    ckp_list, config_list = get_models_list()

    print("ALL MODELS: %0s" % ckp_list)

    for ckp_path, config_path in zip(ckp_list, config_list):

        model_name = "%0s_%0s" % (ckp_path.split('/')[2],
                                  ckp_path.split('/')[4])

        #filter by model name
        if not model_name.startswith(EVAL_FLAGS.prefix_filter):
            print("NOT EVALUATING [%0s]" % model_name)
            continue

        tf.reset_default_graph()

        print("EVALUATING [%0s]" % model_name)
        print("restoring config:")
        FLAGS.DISC_STATE_SIZE = restore_param_from_config(
            config_path, 'DISC_STATE_SIZE')
        FLAGS.GEN_STATE_SIZE = restore_param_from_config(
            config_path, 'GEN_STATE_SIZE')
        print("DISC_STATE_SIZE [%0d]" % FLAGS.DISC_STATE_SIZE)
        print("GEN_STATE_SIZE [%0d]" % FLAGS.GEN_STATE_SIZE)

        lines, _, _ = model_and_data_serialization.load_dataset(
            seq_length=seq_length,
            b_charmap=False,
            b_inv_charmap=False,
            n_examples=FLAGS.MAX_N_EXAMPLES,
            dataset='heldout')

        real_inputs_discrete = tf.placeholder(tf.int32,
                                              shape=[BATCH_SIZE, seq_length])

        global_step = tf.Variable(0, trainable=False)
        disc_cost, gen_cost, train_pred, train_pred_for_eval, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
            charmap, real_inputs_discrete, seq_length)

        # train_pred -> run session
        train_pred_all = np.zeros(
            [N, BATCH_SIZE, seq_length, train_pred.shape[2]], dtype=np.float32)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        # load checkpoints
        # internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(0)
        internal_checkpoint_dir = ckp_path
        model_and_data_serialization.optimistic_restore(
            sess, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))
        restore_config.set_restore_dir(
            load_from_curr_session=True
        )  # global param, always load from curr session after finishing the first seq

        # create samples
        print('start creating samples..')
        test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
            sess,
            inv_charmap,
            inference_op,
            disc_on_inference,
            None,
            real_inputs_discrete,
            feed_gt=False)

        log_samples(test_samples, fake_scores, 666, seq_length, "test")

        print('starting BPC eval...')
        BPC_list = []

        if EVAL_FLAGS.short_run:
            end_line = BATCH_SIZE * 100  # 100 batches
        else:
            end_line = len(lines) - BATCH_SIZE + 1  # all data

        for start_line in range(0, end_line, BATCH_SIZE):
            t0 = time.time()
            _data = np.array([[charmap[c] for c in l]
                              for l in lines[start_line:start_line +
                                             BATCH_SIZE]])

            # rand N noise vectors and for each one - calculate train_pred.
            for i in range(N):
                train_pred_i = sess.run(
                    train_pred_for_eval,
                    feed_dict={real_inputs_discrete: _data})
                train_pred_all[i, :, :, :] = train_pred_i

            # take average on each time step (first dimension)
            train_pred_average = np.mean(train_pred_all, axis=0)

            # compute BPC (bits per character)
            train_pred_average_2d = train_pred_average.reshape([
                train_pred_average.shape[0] * train_pred_average.shape[1],
                train_pred_average.shape[2]
            ])
            real_data = _data.reshape([_data.shape[0] * _data.shape[1]])

            BPC = 0

            epsilon = 1e-20
            for i in range(real_data.shape[0]):
                BPC -= np.log2(train_pred_average_2d[i, real_data[i]] +
                               epsilon)

            BPC /= real_data.shape[0]
            print("BPC of start_line %d/%d = %.2f" %
                  (start_line, len(lines), BPC))
            print("t_iter = %.2f" % (time.time() - t0))
            BPC_list.append(BPC)
            np.save('BPC_list_temp.npy', BPC_list)

        BPC_final = np.mean(BPC_list)
        print("[%0s]BPC_final = %.2f\n" % (ckp_path, BPC_final))
        np.save("%0s_BPC_list.npy" % model_name, BPC_list)
        np.save("%0s_BPC_final.npy" % model_name, BPC_final)