Example #1
0
    def save_session(self):
        """Checkpoint the full training session and report per-stage timings.

        Writes wall-clock timing for each stage (JSON parameters,
        TensorFlow weights, plots, replay memory) to stdout, then prints
        an ETA extrapolated from the average time per training step.
        """
        time_at_start_save = time.time()
        sys.stdout.write(
            '{}: [Step {}k  --  Took {:3.2f} s] '.format(
                datetime.now(),
                self.step // 1000,
                time_at_start_save - self.last_time))
        self.last_time = time_at_start_save

        # Persist the current step so a restored session resumes correctly.
        Parameters.add_attr("CURRENT_STEP", self.step)
        a = time.time()
        Parameters.update()
        b = time.time()
        sys.stdout.write('[{:3.2f}s for json] '.format(b - a))

        save_file = path.join(
            Parameters.SESSION_SAVE_DIRECTORY,
            Parameters.SESSION_SAVE_FILENAME)
        if not path.exists(Parameters.SESSION_SAVE_DIRECTORY):
            makedirs(Parameters.SESSION_SAVE_DIRECTORY)
        a = time.time()
        self.tf_saver.save(self.tf_session, save_file)
        b = time.time()
        sys.stdout.write('[{:3.2f}s for tf] '.format(b - a))

        a = time.time()
        Plotter.save("out")
        b = time.time()
        sys.stdout.write('[{:3.2f}s for Plotter] '.format(b - a))
        a = time.time()
        self.memory.save_memory()
        b = time.time()
        sys.stdout.write('[{:3.2f}s for memory] '.format(b - a))
        post_save_time = time.time()
        sys.stdout.write(
            '[Required {:3.2f}s to save all] '.format(
                post_save_time -
                time_at_start_save))
        self.last_time = post_save_time

        # ETA: extrapolate remaining time from steps completed so far.
        # BUG FIX: guard the division -- saving before any step completed
        # (step == initial_step) previously raised ZeroDivisionError.
        elapsed_time = time.time() - self.initial_time
        steps_done = self.step - self.initial_step
        if steps_done > 0:
            remaining_seconds = (
                elapsed_time * (Parameters.MAX_STEPS - self.step) / steps_done)
            print("eta: {}s".format(timedelta(seconds=remaining_seconds)))
Example #2
0
def main():
    """Train the sentence GAN on the Gutenberg corpus.

    Loads GloVe vectors, encodes the corpus as sequences of L2-normalized
    word vectors, then trains discriminator and generator together,
    periodically logging losses, decoding a few generated sentences via
    nearest-neighbor lookup, and checkpointing each epoch.
    """
    print("Loading wordvecs...")
    # Load the cached GloVe subset if present; otherwise filter the full
    # embedding file down to the Gutenberg vocabulary (slow) and cache it.
    if utils.exists("glove", "glove.840B.300d.txt", "gutenberg"):
        words, wordvecs = utils.load_glove("glove", "glove.840B.300d.txt", "gutenberg")
    else:
        words, wordvecs = utils.load_glove("glove", "glove.840B.300d.txt", "gutenberg",
                                           set(map(clean_word, gutenberg.words())))

    # Row-wise L2 normalization so dot products are cosine similarities.
    wordvecs_norm = wordvecs / np.linalg.norm(wordvecs, axis=1).reshape(-1, 1)

    print("Loading corpus...")
    # Convert corpus into normed wordvecs, replacing any word not in the
    # vocabulary with the zero vector.
    sentences = [[wordvecs_norm[words[clean_word(word)]] if clean_word(word) in words else np.zeros(WORD_DIM)
                  for word in sentence]
                 for sentence in gutenberg.sents()]

    print("Processing corpus...")
    # Pad sentences shorter than SEQUENCE_LENGTH with zero vectors and
    # truncate sentences longer than SEQUENCE_LENGTH.
    s_train = list(map(pad_or_truncate, sentences))

    np.random.shuffle(s_train)

    # Truncate to a whole number of batches.
    batches_per_epoch = len(s_train) // BATCH_SIZE
    s_train = s_train[:batches_per_epoch * BATCH_SIZE]

    s_train_idxs = np.arange(len(s_train))

    print("Generating graph...")
    network = NlpGan(learning_rate=LEARNING_RATE, d_dim_state=D_DIM_STATE, g_dim_state=G_DIM_STATE,
                     dim_in=WORD_DIM, sequence_length=SEQUENCE_LENGTH)

    # Seed each curve (d_loss, g_loss, accuracy) with an initial point.
    plotter = Plotter([2, 1], "Loss", "Accuracy")
    plotter.plot(0, 0, 0, 0)
    plotter.plot(0, 0, 0, 1)
    plotter.plot(0, 0, 1, 0)
    plotter.plot(0, 1, 1, 0)

    #d_vars = [var for var in tf.trainable_variables() if 'discriminator' in var.name]
    saver = tf.train.Saver()

    with tf.Session() as sess:
        #eval(sess, network, words, wordvecs_norm, saver)

        sess.run(tf.global_variables_initializer())
        #resume(sess, saver, plotter, "GAN_9_SEQUENCELENGTH_10", 59)

        d_loss, g_loss = 0.0, 0.0
        for epoch in range(0, 10000000):
            print("Epoch %d" % epoch)

            np.random.shuffle(s_train_idxs)
            for batch in range(batches_per_epoch):
                # BUG FIX: take a disjoint slice of BATCH_SIZE indices per
                # batch.  The original sliced [batch:batch + BATCH_SIZE],
                # so consecutive batches overlapped in all but one element
                # and most of the shuffled training set was never visited.
                batch_idxs = s_train_idxs[batch * BATCH_SIZE:(batch + 1) * BATCH_SIZE]
                # shape (BATCH_SIZE, SEQUENCE_LENGTH, WORD_DIM)
                s_batch_real = [s_train[x] for x in batch_idxs]

                # reshape to (SEQUENCE_LENGTH, BATCH_SIZE, WORD_DIM) while
                # preserving sentence order
                s_batch_real = np.array(s_batch_real).swapaxes(0, 1)

                # NOTE: the two unbalanced-loss branches are deliberately
                # disabled ("and False"); both nets always train together.
                if d_loss - g_loss > MAX_LOSS_DIFF and False:
                    output_dict = sess.run(
                        network.get_fetch_dict('d_loss', 'd_train', 'g_loss'),
                        network.get_feed_dict(inputs=s_batch_real, input_dropout=D_KEEP_PROB)
                    )
                elif g_loss - d_loss > MAX_LOSS_DIFF and False:
                    output_dict = sess.run(
                        network.get_fetch_dict('d_loss', 'g_loss', 'g_train'),
                        network.get_feed_dict(inputs=s_batch_real, input_dropout=D_KEEP_PROB)
                    )
                else:
                    output_dict = sess.run(
                        network.get_fetch_dict('d_loss', 'd_train', 'g_loss', 'g_train'),
                        network.get_feed_dict(inputs=s_batch_real, input_dropout=D_KEEP_PROB,
                                              instance_variance=INSTANCE_VARIANCE)
                    )

                d_loss, g_loss = output_dict['d_loss'], output_dict['g_loss']

                if batch % 10 == 0:
                    print("Finished training batch %d / %d" % (batch, batches_per_epoch))
                    print("Discriminator Loss: %f" % output_dict['d_loss'])
                    print("Generator Loss: %f" % output_dict['g_loss'])
                    # x-axis is the fractional epoch.
                    progress = epoch + (batch / batches_per_epoch)
                    plotter.plot(progress, d_loss, 0, 0)
                    plotter.plot(progress, g_loss, 0, 1)

                if batch % 100 == 0:
                    # Renamed from `eval`, which shadowed the builtin.
                    eval_out = sess.run(
                        network.get_fetch_dict('g_outputs', 'd_accuracy'),
                        network.get_feed_dict(inputs=s_batch_real, input_dropout=1.0,
                                              instance_variance=INSTANCE_VARIANCE)
                    )
                    # reshape g_outputs to (BATCH_SIZE, SEQUENCE_LENGTH,
                    # WORD_DIM) while preserving sentence order
                    generated = eval_out['g_outputs'].swapaxes(0, 1)
                    # Decode a few generated sentences back to words.
                    for sentence in generated[:3]:
                        for wordvec in sentence:
                            norm = np.linalg.norm(wordvec)
                            word, similarity = nearest_neighbor(words, wordvecs_norm, wordvec / norm)
                            print("{}({:4.2f})".format(word, similarity))
                        print('\n---------')
                    print("Total Accuracy: %f" % eval_out['d_accuracy'])
                    plotter.plot(epoch + (batch / batches_per_epoch), eval_out['d_accuracy'], 1, 0)

            saver.save(sess, './checkpoints/{}.ckpt'.format(SAVE_NAME),
                       global_step=epoch)
            plotter.save(SAVE_NAME)
Example #3
0
    # NOTE(review): this fragment starts mid-function; the enclosing def
    # (and sys.argv[0..1] / nc_dir setup) is outside this view.
    beg_yr = int(sys.argv[2])     # first year (inclusive)
    end_yr = int(sys.argv[3])     # last year (exclusive, per xrange below)
    nc_var = sys.argv[4]          # model variable name in the netCDF files
    obs_pattern = sys.argv[5]     # path/glob for the observation file
    obs_nc_var = sys.argv[6]      # observation variable name

    # Python 2 idiom (xrange); builds a per-year glob of model files.
    for yr in xrange(beg_yr, end_yr):
        pattern = os.path.join(nc_dir, '*' + str(yr) + '*.nc')
        r = RegCMReader(pattern)
    # NOTE(review): everything below sits OUTSIDE the loop, so only the
    # reader for the final year (end_yr - 1) is ever used.  This looks like
    # an indentation bug, but the intended per-year behavior cannot be
    # confirmed from this fragment -- verify before changing.
    value = r.get_value(nc_var).mean()
    time_limits = value.get_limits('time')
    crd_limits = value.get_latlonlimits()

    # Load observations restricted to the model's time window and
    # lat/lon box, then time-average.
    obs_r = CRUReader(obs_pattern)
    obs_value = obs_r.get_value(obs_nc_var,
                                imposed_limits={
                                    'time': time_limits
                                },
                                latlon_limits=crd_limits).mean()
    if obs_nc_var == "TMP":
        # Presumably a unit conversion to Kelvin to match the model;
        # to_K semantics are defined elsewhere -- confirm.
        obs_value.to_K()

    # Regrid the model field onto the observation grid, plot the
    # obs-minus-model difference, and save it as a PNG.
    value.regrid(obs_value.latlon)
    diff = obs_value - value
    plt = Plotter(diff)
    plt.plot(levels=(-5, 5))
    plt.show()
    plt.save('image', format='png')
    plt.close()
Example #4
0
Example: 
./app.py 'data/Africa_SRF.1970*.nc' t2m 'obs/CRUTMP.CDF' TMP
"""

if __name__ == "__main__":
    # Expect exactly four arguments: model glob, model var, obs file, obs var.
    if len(sys.argv) != 5:
        # FIX: use the print() function (original used the Python 2
        # `print usage` statement, a SyntaxError under Python 3; with a
        # single argument this form also behaves identically on Python 2).
        print(usage)
        sys.exit(1)
    pattern = sys.argv[1]
    nc_var = sys.argv[2]
    obs_pattern = sys.argv[3]
    obs_nc_var = sys.argv[4]

    # Time-mean of the model field over all files matching the pattern.
    r = RegCMReader(pattern)
    value = r.get_value(nc_var).mean()
    time_limits = value.get_limits('time')
    crd_limits = value.get_latlonlimits()

    # Load observations restricted to the model's time window and
    # lat/lon box, then time-average.
    obs_r = CRUReader(obs_pattern)
    obs_value = obs_r.get_value(obs_nc_var, imposed_limits={'time': time_limits}, latlon_limits=crd_limits).mean()
    if obs_nc_var == "TMP":
        # Presumably a unit conversion to Kelvin to match the model;
        # to_K semantics are defined elsewhere -- confirm.
        obs_value.to_K()

    # Regrid the model field onto the observation grid, plot the
    # obs-minus-model difference, and save it as a PNG.
    value.regrid(obs_value.latlon)
    diff = obs_value - value
    plt = Plotter(diff)
    plt.plot(levels=(-5, 5))
    plt.show()
    plt.save('image', format='png')
    plt.close()