def run_training(train_ids, test_ids, tf_savedir, model_params, max_time=100, batch_size=256, learning_rate=0.002, num_epochs=20):
    """Train an RNN language model and checkpoint it to disk.

    Builds the model graph from `model_params`, clears and recreates
    `tf_savedir`, then runs `num_epochs` training passes over `train_ids`,
    saving a checkpoint and scoring perplexity on both datasets after each
    epoch.

    Args:
        train_ids: token-id sequence used to generate training batches.
        test_ids: token-id sequence used for held-out scoring.
        tf_savedir: checkpoint directory; removed and recreated on entry.
        model_params: keyword arguments forwarded to rnnlm.RNNLM
            (e.g. V, H, softmax_ns, num_layers).
        max_time: maximum unrolled timesteps per batch.
        batch_size: number of sequences per batch.
        learning_rate: learning rate forwarded to run_epoch.
        num_epochs: number of full passes over the training data.

    Returns:
        Path prefix of the final saved model checkpoint.
    """
    checkpoint_filename = os.path.join(tf_savedir, "rnnlm")
    trained_filename = os.path.join(tf_savedir, "rnnlm_trained")

    # Build the core model graph plus the training ops.
    lm = rnnlm.RNNLM(**model_params)
    lm.BuildCoreGraph()
    lm.BuildTrainGraph()

    # Explicitly add global initializer and variable saver to the LM graph.
    with lm.graph.as_default():
        initializer = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # Clear the old log directory, then recreate it for fresh checkpoints.
    shutil.rmtree(tf_savedir, ignore_errors=True)
    if not os.path.isdir(tf_savedir):
        os.makedirs(tf_savedir)

    with tf.Session(graph=lm.graph) as session:
        session.run(initializer)

        for epoch in range(1, num_epochs + 1):
            t0_epoch = time.time()
            bi = utils.rnnlm_batch_generator(train_ids, batch_size, max_time)
            print("[epoch {:d}] Starting epoch {:d}".format(epoch, epoch))
            # Run a training epoch.
            run_epoch(lm, session, batch_iterator=bi, train=True, verbose=True, tick_s=10, learning_rate=learning_rate)
            print("[epoch {:d}] Completed in {:s}".format(
                epoch, utils.pretty_timedelta(since=t0_epoch)))
            # Save a per-epoch checkpoint.
            saver.save(session, checkpoint_filename, global_step=epoch)

            # score_dataset runs a forward pass over the entire dataset and
            # reports perplexity. This can be slow (around 1/2 to 1/4 as long
            # as a full epoch), so it may be commented out to speed up
            # training on a slow machine; be sure to run it at the end to
            # evaluate the final score.
            print("[epoch {:d}]".format(epoch), end=" ")
            score_dataset(lm, session, train_ids, name="Train set")
            print("[epoch {:d}]".format(epoch), end=" ")
            score_dataset(lm, session, test_ids, name="Test set")
            print("")

        # Save the final trained model and return its path prefix.
        saver.save(session, trained_filename)
        return trained_filename
def test_pretty_timedelta(self):
    """pretty_timedelta renders a 3934-second span as 'H:MM:SS'."""
    start = 12345
    end = start + 3934  # 1 hr, 5 min, 34 sec
    formatted = utils.pretty_timedelta(since=start, until=end)
    self.assertEqual(formatted, "1:05:34")