Example #1
import os
import shutil
import time

import tensorflow as tf

import rnnlm   # project module defining the RNNLM model class
import utils   # project helpers: batch generator, timing utilities

# run_epoch and score_dataset are assumed to be defined elsewhere in this module.


def run_training(train_ids,
                 test_ids,
                 tf_savedir,
                 model_params,
                 max_time=100,
                 batch_size=256,
                 learning_rate=0.002,
                 num_epochs=20):
    # Training parameters (max_time, batch_size, learning_rate, num_epochs)
    # are taken from the keyword arguments above; add parameter sets for
    # each attack/defense configuration as needed.

    # Model parameters are passed in via model_params,
    # e.g. dict(V=len(words_to_ids.keys()), H=1024,
    #          softmax_ns=len(words_to_ids.keys()), num_layers=2)
    TF_SAVEDIR = tf_savedir
    checkpoint_filename = os.path.join(TF_SAVEDIR, "rnnlm")
    trained_filename = os.path.join(TF_SAVEDIR, "rnnlm_trained")

    # Print training status every this many seconds
    print_interval = 30

    lm = rnnlm.RNNLM(**model_params)
    lm.BuildCoreGraph()
    lm.BuildTrainGraph()

    # Explicitly add global initializer and variable saver to LM graph
    with lm.graph.as_default():
        initializer = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # Clear old log directory
    shutil.rmtree(TF_SAVEDIR, ignore_errors=True)
    if not os.path.isdir(TF_SAVEDIR):
        os.makedirs(TF_SAVEDIR)

    with tf.Session(graph=lm.graph) as session:
        # Seed the RNG for repeatability (left disabled here)
        #tf.set_random_seed(42)

        session.run(initializer)

        # Optionally inspect the trainable variables:
        #variables_names = [v.name for v in tf.trainable_variables()]
        #values = session.run(variables_names)
        #for k, v in zip(variables_names, values):
        #    print("Variable: ", k)
        #    print("Shape: ", v.shape)
        #    print(v)

        for epoch in range(1, num_epochs + 1):
            t0_epoch = time.time()
            bi = utils.rnnlm_batch_generator(train_ids, batch_size, max_time)
            print("[epoch {:d}] Starting epoch {:d}".format(epoch, epoch))
            # Run a training epoch.
            run_epoch(lm,
                      session,
                      batch_iterator=bi,
                      train=True,
                      verbose=True,
                      tick_s=10,
                      learning_rate=learning_rate)

            print("[epoch {:d}] Completed in {:s}".format(
                epoch, utils.pretty_timedelta(since=t0_epoch)))

            # Save a checkpoint
            saver.save(session, checkpoint_filename, global_step=epoch)

            ##
            # score_dataset runs a forward pass over the entire dataset and
            # reports perplexity scores. This can be slow (roughly 1/4 to 1/2
            # as long as a full epoch), so you may want to comment it out to
            # speed up training on a slow machine. Be sure to run it at the
            # end to evaluate your final score.
            print("[epoch {:d}]".format(epoch), end=" ")
            score_dataset(lm, session, train_ids, name="Train set")
            print("[epoch {:d}]".format(epoch), end=" ")
            score_dataset(lm, session, test_ids, name="Test set")
            print("")

        # Save final model
        saver.save(session, trained_filename)
        return trained_filename
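
A call to run_training might look like the sketch below. It simply reuses the commented-out defaults from the function above and assumes that words_to_ids, train_ids, and test_ids have already been built by the surrounding project.

# Hypothetical usage, mirroring the commented-out defaults above;
# words_to_ids, train_ids, and test_ids are assumed to already exist.
V = len(words_to_ids.keys())
model_params = dict(V=V, H=1024, softmax_ns=V, num_layers=2)

trained = run_training(train_ids,
                       test_ids,
                       tf_savedir="/tmp/artificial_hotel_reviews/a4_model",
                       model_params=model_params,
                       max_time=25,
                       batch_size=100,
                       learning_rate=0.01,
                       num_epochs=10)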
Example #2
def test_pretty_timedelta(self):
    since = 12345
    until = since + 3934  # 1 hr, 5 min, 34 sec
    res = utils.pretty_timedelta(since=since, until=until)
    self.assertEqual(res, "1:05:34")
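
The expected "1:05:34" output implies an H:MM:SS layout. A minimal sketch of a helper that would satisfy this test follows; it is an assumption about the behavior of utils.pretty_timedelta, not its actual implementation.

import time

# Sketch of a pretty_timedelta compatible with the test above.
def pretty_timedelta(since=None, until=None):
    """Format the elapsed time between `since` and `until` as H:MM:SS."""
    if until is None:
        until = time.time()  # default to "now", as used by the epoch timer above
    seconds = int(round(until - since))
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    return "{:d}:{:02d}:{:02d}".format(hours, minutes, secs)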