Example #1
  def decode(self):
    """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
    t0 = time.time()
    counter = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = data.show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      # Run beam search to get best Hypothesis
      if FLAGS.ac_training:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      # Remove the [STOP] token from decoded_words, if necessary
      try:
        fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:  # no [STOP] token produced; keep the full output
        pass
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        counter += 1 # this is how many examples we've decoded
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time()
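The decode() loop above writes its single-pass output through a write_for_rouge helper that is not shown here; Example #13 below uses a free-function variant that takes the two output directories explicitly. A minimal sketch along those lines, modeled on the original pointer-generator code; the period-based sentence splitting and the make_html_safe escaping are assumptions:

import os

def make_html_safe(s):
  """Escape angle brackets: pyrouge pipes the text through an HTML-based perl script."""
  return s.replace("<", "&lt;").replace(">", "&gt;")

def write_for_rouge(reference_sents, decoded_words, ex_index, ref_dir, dec_dir):
  """Write one reference summary and one decoded summary to numbered files for pyrouge."""
  decoded_sents = []
  while len(decoded_words) > 0:
    try:
      fst_period_idx = decoded_words.index(".")  # split into sentences at each period token
    except ValueError:  # no period left: treat the remainder as one sentence
      fst_period_idx = len(decoded_words)
    decoded_sents.append(' '.join(decoded_words[:fst_period_idx + 1]))
    decoded_words = decoded_words[fst_period_idx + 1:]
  decoded_sents = [make_html_safe(s) for s in decoded_sents]
  reference_sents = [make_html_safe(s) for s in reference_sents]
  # pyrouge expects one file per example, one sentence per line
  with open(os.path.join(ref_dir, "%06d_reference.txt" % ex_index), "w") as f:
    f.write("\n".join(reference_sents))
  with open(os.path.join(dec_dir, "%06d_decoded.txt" % ex_index), "w") as f:
    f.write("\n".join(decoded_sents))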
def run_eval(model, batcher, vocab):
  """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
  model.build_graph() # build the graph
  saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
  sess = tf.Session(config=util.get_config())
  eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
  bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
  summary_writer = tf.summary.FileWriter(eval_dir)
  running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
  best_loss = None  # will hold the best loss achieved so far

  while True:
    _ = util.load_ckpt(saver, sess) # load a new checkpoint
    batch = batcher.next_batch() # get the next batch

    # run eval on the batch
    t0=time.time()
    results = model.run_eval_step(sess, batch)
    t1=time.time()
    tf.logging.info('seconds for batch: %.2f', t1-t0)

    # print the loss and coverage loss to screen
    loss = results['loss']
    tf.logging.info('loss: %f', loss)
    if FLAGS.coverage:
      coverage_loss = results['coverage_loss']
      tf.logging.info("coverage_loss: %f", coverage_loss)

    # add summaries
    summaries = results['summaries']
    train_step = results['global_step']
    summary_writer.add_summary(summaries, train_step)

    # calculate running avg loss
    running_avg_loss = calc_running_avg_loss(np.asscalar(loss), running_avg_loss, summary_writer, train_step)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
      tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
      saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
      best_loss = running_avg_loss

    # flush the summary writer every so often
    if train_step % 100 == 0:
      summary_writer.flush()
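run_eval above (and the eval loops in Examples #11 and #15) relies on calc_running_avg_loss, which is not defined in these snippets. A sketch modeled on the original pointer-generator helper, assuming an exponential moving average with decay 0.99 written to TensorBoard; some variants also take an extra summary-tag argument:

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
  """Exponentially-decayed moving average of the loss, also logged to TensorBoard."""
  if running_avg_loss == 0:  # on the first iteration just take the loss
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  running_avg_loss = min(running_avg_loss, 12)  # clip large values to keep the plot readable
  loss_sum = tf.Summary()
  loss_sum.value.add(tag='running_avg_loss/decay=%f' % decay, simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  tf.logging.info('running_avg_loss: %f', running_avg_loss)
  return running_avg_loss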
Example #3
  def __init__(self, model, batcher, vocab, dqn=None):
    """Initialize decoder.

    Args:
      model: a Seq2SeqAttentionModel object.
      batcher: a Batcher object.
      vocab: Vocabulary object
      dqn: a DQN object; required when FLAGS.ac_training is set
    """
    self._model = model
    self._model.build_graph()
    self._batcher = batcher
    self._vocab = vocab
    self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=util.get_config())

    if FLAGS.ac_training:
      self._dqn = dqn
      self._dqn_graph = tf.Graph()
      with self._dqn_graph.as_default():
        self._dqn.build_graph()
        self._dqn_saver = tf.train.Saver() # we use this to load checkpoints for decoding
        self._dqn_sess = tf.Session(config=util.get_config())
        _ = util.load_dqn_ckpt(self._dqn_saver, self._dqn_sess)

    # Load an initial checkpoint to use for decoding
    ckpt_path = util.load_ckpt(self._saver, self._sess, FLAGS.decode_from)

    if FLAGS.single_pass:
      # Make a descriptive decode directory name
      ckpt_name = "{}-ckpt-".format(FLAGS.decode_from) + ckpt_path.split('-')[
        -1]  # this is something of the form "ckpt-123456"
      self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
    else: # Generic decode dir name
      self._decode_dir = os.path.join(FLAGS.log_root, "decode")

    # Make the decode dir if necessary
    if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

    if FLAGS.single_pass:
      # Make the dirs to contain output written in the correct format for pyrouge
      self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
      if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
      self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
      if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
def convert_to_coverage_model():
  """Load non-coverage checkpoint, add initialized extra variables for coverage, and save as new checkpoint"""
  tf.logging.info("converting non-coverage model to coverage model..")

  # initialize an entire coverage model from scratch
  sess = tf.Session(config=util.get_config())
  print "initializing everything..."
  sess.run(tf.global_variables_initializer())

  # load all non-coverage weights from checkpoint
  saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
  print "restoring non-coverage variables..."
  curr_ckpt = util.load_ckpt(saver, sess)
  print "restored."

  # save this model and quit
  new_fname = curr_ckpt + '_cov_init'
  print "saving model to %s..." % (new_fname)
  new_saver = tf.train.Saver() # this one will save all variables that now exist
  new_saver.save(sess, new_fname)
  print "saved."
  exit()
def restore_best_model():
  """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
  tf.logging.info("Restoring bestmodel for training...")

  # Initialize all vars in the model
  sess = tf.Session(config=util.get_config())
  print "Initializing all variables..."
  sess.run(tf.initialize_all_variables())

  # Restore the best model from eval dir
  saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
  print "Restoring all non-adagrad variables from best model in eval dir..."
  curr_ckpt = util.load_ckpt(saver, sess, "eval")
  print "Restored %s." % curr_ckpt

  # Save this model to train dir and quit
  new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
  new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
  print "Saving model to %s..." % (new_fname)
  new_saver = tf.train.Saver() # this saver saves all variables that now exist, including Adagrad variables
  new_saver.save(sess, new_fname)
  print "Saved."
  exit()
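Nearly every example here restores weights through util.load_ckpt, which is external to these snippets. A sketch following the original repository's helper; the retry-with-sleep loop and the 'checkpoint_best' latest_filename for the eval dir are assumptions:

import os
import time
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

def load_ckpt(saver, sess, ckpt_dir="train"):
  """Restore the latest checkpoint under FLAGS.log_root/ckpt_dir, retrying until one exists."""
  while True:
    full_dir = os.path.join(FLAGS.log_root, ckpt_dir)
    try:
      # the eval job saves its best model under a separate latest_filename (see run_eval)
      latest_filename = "checkpoint_best" if ckpt_dir == "eval" else None
      ckpt_state = tf.train.get_checkpoint_state(full_dir, latest_filename=latest_filename)
      tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
      saver.restore(sess, ckpt_state.model_checkpoint_path)
      return ckpt_state.model_checkpoint_path
    except Exception:
      tf.logging.info("Failed to load checkpoint from %s. Sleeping for 10 secs...", full_dir)
      time.sleep(10)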
Example #6
	def __init__(self, model, batcher, vocab):
		"""Initialize decoder.

		Args:
			model: a Seq2SeqAttentionModel object.
			batcher: a Batcher object.
			vocab: Vocabulary object
		"""
		self._model = model
		self._model.build_graph()
		self._batcher = batcher
		self._vocab = vocab
		self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
		self._sess = tf.Session(config=util.get_config())

		# Load an initial checkpoint to use for decoding
		ckpt_path = util.load_ckpt(self._saver, self._sess)

		if FLAGS.single_pass:
			# Make a descriptive decode directory name
			ckpt_name = "ckpt-" + ckpt_path.split('-')[-1] # this is something of the form "ckpt-123456"
			self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
			if os.path.exists(self._decode_dir):
				raise Exception("single_pass decode directory %s should not already exist" % self._decode_dir)

		else: # Generic decode dir name
			self._decode_dir = os.path.join(FLAGS.log_root, "decode")

		# Make the decode dir if necessary
		if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

		if FLAGS.single_pass:
			# Make the dirs to contain output written in the correct format for pyrouge
			self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
			if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
			self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
			if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)
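get_decode_dir_name, used by Examples #3 and #6, is also not shown. A sketch in the spirit of the original code, assuming the data split is inferred from FLAGS.data_path and the beam-search flags are encoded into the directory name:

def get_decode_dir_name(ckpt_name):
  """Build a descriptive decode-dir name from the data split and the decoding flags."""
  if "train" in FLAGS.data_path: dataset = "train"
  elif "val" in FLAGS.data_path: dataset = "val"
  elif "test" in FLAGS.data_path: dataset = "test"
  else:
    raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % FLAGS.data_path)
  dirname = "decode_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (
      dataset, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
  if ckpt_name is not None:
    dirname += "_%s" % ckpt_name
  return dirname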
Example #7
  def __init__(self, model, vocab, single_pass, hps, pointer_gen, log_root):
    """Initialize decoder.

    Args:
      model: a Seq2SeqAttentionModel object.
      vocab: Vocabulary object
      single_pass: whether we are decoding the dataset exactly once
      hps: hyperparameters namedtuple (provides min/max decode steps and beam size)
      pointer_gen: whether the model uses the pointer-generator mechanism
      log_root: directory from which the initial checkpoint is loaded
    """
    self._model = model
    self._model.build_graph()
    # self._batcher = batcher
    self._vocab = vocab
    self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=util.get_config())
    self.single_pass = single_pass
    self.max_dec_steps = hps.max_dec_steps
    self.min_dec_steps = hps.min_dec_steps
    self.beam_size = hps.beam_size
    self.pointer_gen = pointer_gen

    # Load an initial checkpoint to use for decoding
    ckpt_path = util.load_ckpt(self._saver, self._sess,log_root)
    print(ckpt_path)
Example #8
def restore_best_model():
    """Load bestmodel file from eval directory, add variables for adagrad, and save to train directory"""
    tf.logging.info("Restoring bestmodel for training...")
    # Initialize all vars in the model
    sess = tf.Session(config=util.get_config())
    print("Initializing all variables...")
    sess.run(tf.global_variables_initializer())

    # Restore the best model from eval dir
    saver = tf.train.Saver(
        [v for v in tf.global_variables() if "Adagrad" not in v.name])
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = util.load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    # Save this model to train dir and quit
    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
    print("Saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver(
    )  # this saver saves all variables that now exist, including Adagrad variables
    new_saver.save(sess, new_fname)
    print("Saved.")
    exit()
Example #9
def run_decoding(model, batcher):

    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    decode_dir = os.path.join(
        FLAGS.exp_name,
        "decode")  # make a subdir of the root dir for eval data
    batches = batcher.getBatches()

    _ = util.load_ckpt(saver, sess)  # load a new checkpoint

    epoch_avg_loss = 0.
    epoch_decode_steps = 0

    for batch in batches:
        results = model.run_decode_step(sess, batch)
        loss = results['loss']
        epoch_decode_steps += 1
        epoch_avg_loss = (epoch_avg_loss * (epoch_decode_steps - 1.) +
                          loss) / epoch_decode_steps

    print("Average loss %f" % (epoch_avg_loss))
Example #10
def main(unused_argv):
    if len(unused_argv
           ) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting to run in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if "train" in FLAGS.mode:
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # print('FLAGS.flag_values_dict() ->', FLAGS.flag_values_dict())
    flags_dict = FLAGS.flag_values_dict()

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps'
    ]

    hps_dict = {}
    for key, val in flags_dict.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    print('hps_dict ->', json.dumps(hps_dict, ensure_ascii=False))
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'max_enc_sen_num', 'max_enc_seq_len'
    ]
    hps_dict = {}
    for key, val in flags_dict.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # # test
    # model_dis = Discriminator(hps_discriminator, vocab)
    # model_dis.build_graph()
    # sys.exit(0)
    # # test

    print('before load batcher...')
    # Create a batcher object that will create minibatches of data
    batcher = GenBatcher(vocab, hps_generator)
    print('after load batcher...')

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'adversarial_train':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, batcher, sess_ge)

        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)

        util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        if not os.path.exists("MLE"): os.mkdir("MLE")

        print("evaluate the diversity of MLE (decode based on sampling)")
        generated.generator_test_sample_example("MLE/" + "MLE_sample_positive",
                                                "MLE/" + "MLE_sample_negative",
                                                200)

        print(
            "evaluate the diversity of MLE (decode based on max probability)")
        generated.generator_test_max_example("MLE/" + "MLE_max_temp_positive",
                                             "MLE/" + "MLE_max_temp_negative",
                                             200)

        print("Start adversarial  training......")
        if not os.path.exists("train_sample_generated"):
            os.mkdir("train_sample_generated")
        if not os.path.exists("test_max_generated"):
            os.mkdir("test_max_generated")
        if not os.path.exists("test_sample_generated"):
            os.mkdir("test_sample_generated")

        whole_decay = False
        for epoch in range(10):
            batches = batcher.get_batches(mode='train')
            for step in range(int(len(batches) / 1000)):

                run_train_generator(
                    model, model_dis, sess_dis, batcher, dis_batcher,
                    batches[step * 1000:(step + 1) * 1000], sess_ge, saver_ge,
                    train_dir_ge, generated
                )  # (model, discirminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated):
                generated.generator_sample_example(
                    "train_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "train_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    1000)
                # generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))
                print("evaluate the diversity of DP-GAN (decode based on sampling)")
                generated.generator_test_sample_example(
                    "test_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "test_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    200)
                print("evaluate the diversity of DP-GAN (decode based on max probability)")
                generated.generator_test_max_example(
                    "test_max_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "test_max_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    200)

                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_negative/*")
                dis_batcher.train_batch = dis_batcher.create_batches(
                    mode="train", shuffleis=True)

                # dis_batcher.valid_batch = dis_batcher.train_batch
                whole_decay = run_train_discriminator(
                    model_dis, 5, dis_batcher,
                    dis_batcher.get_batches(mode="train"), sess_dis, saver_dis,
                    train_dir_dis, whole_decay)

    elif hps_generator.mode == 'train_generator':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        # pre-train the generator for a fixed number of epochs
        run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge,
                                train_dir_ge, generated)

        print("Generating negative examples......")
        generated.generator_train_negative_example()
        generated.generator_test_negative_example()
    elif hps_generator.mode == 'train_discriminator':
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")

        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)
        print("Start pre-training discriminator......")
        # run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
        if not os.path.exists("discriminator_result"):
            os.mkdir("discriminator_result")
        run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis,
                                    saver_dis, train_dir_dis)
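The flag-handling pattern above (full flags dict, filtered down to hparam_list, frozen into a namedtuple) is used twice, once per model. In miniature, with illustrative values standing in for FLAGS.flag_values_dict():

from collections import namedtuple

flags_dict = {'mode': 'train_generator', 'lr': 0.15, 'batch_size': 16, 'unrelated_flag': True}
hparam_list = ['mode', 'lr', 'batch_size']
hps_dict = {k: v for k, v in flags_dict.items() if k in hparam_list}  # keep only model hparams
hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)  # immutable, attribute-style access
print(hps.lr, hps.batch_size)  # 0.15 16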
Example #11
def run_eval_pred(model, batcher, vocab):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(
        FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping

    best_loss = None  # will hold the best loss achieved so far

    while True:
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch
        while np.sum(batch.rst_batch) == FLAGS.batch_size or np.sum(
                batch.rst_batch) == 0:
            batch = batcher.next_batch()
        # run eval on the batch
        t0 = time.time()
        results = model.run_predict_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['pred_loss']
        tf.logging.info('loss: %f', loss)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']

        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_loss = calc_running_avg_loss(np.asscalar(loss),
                                                 running_avg_loss,
                                                 summary_writer, train_step)

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss
        # y_true = np.array(results['rst'])
        # y_pred = np.array(results['pred'])
        #
        # print(sklearn.metrics.classification_report(y_true, y_pred))
        # print(confusion_matrix(y_true, y_pred))

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
def run_eval(model, batcher):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(
        FLAGS.log_root,
        "eval_loss")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far
    train_dir = os.path.join(FLAGS.log_root, "train")

    while True:
        ckpt_state = tf.train.get_checkpoint_state(train_dir)
        tf.logging.info('max_enc_steps: %d, max_dec_steps: %d',
                        FLAGS.max_enc_steps, FLAGS.max_dec_steps)
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        train_step = results['global_step']

        tf.logging.info("pgen_avg: %f", results['p_gen_avg'])

        if FLAGS.coverage:
            tf.logging.info("coverage_loss: %f", results['coverage_loss'])

        # add summaries
        summaries = results['summaries']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_loss = util.calc_running_avg_loss(np.asscalar(loss),
                                                      running_avg_loss,
                                                      summary_writer,
                                                      train_step,
                                                      'running_avg_loss')

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
Example #13
def run_training(model, batcher, sess_context_manager, sv, summary_writer, word_vector, \
                 selector_saver=None, rewriter_saver=None, all_saver=None):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    tf.logging.info("starting run_training")
    train_step = 0
    ckpt_path = os.path.join(FLAGS.log_root, "train", "model.ckpt_cov")

    with sess_context_manager as sess:
        if FLAGS.pretrained_selector_path:  # Load the pre-trained model
            tf.logging.info('Loading selector model')
            _ = util.load_ckpt(selector_saver,
                               sess,
                               ckpt_path=FLAGS.pretrained_selector_path)
        if FLAGS.pretrained_rewriter_path:
            tf.logging.info('Loading rewriter model')
            _ = util.load_ckpt(rewriter_saver,
                               sess,
                               ckpt_path=FLAGS.pretrained_rewriter_path)

        for _ in range(FLAGS.max_train_iter):  # repeats until interrupted
            batch = batcher.next_batch()

            tf.logging.info('running training step...')
            t0 = time.time()
            results = model.run_train_step(sess, batch, word_vector,
                                           train_step)
            t1 = time.time()
            tf.logging.info('seconds for training step: %.3f', t1 - t0)

            loss = results['loss']
            tf.logging.info('rl_loss: %f', loss)  # print the loss to screen
            train_step = results['global_step']

            if not np.isfinite(loss):
                raise Exception("Loss is not finite. Stopping.")

            tf.logging.info("reinforce_avg_logprobs: %f",
                            results['reinforce_avg_logprobs'])

            if FLAGS.coverage:
                tf.logging.info("coverage_loss: %f", results['coverage_loss']
                                )  # print the coverage loss to screen
                tf.logging.info("reinforce_coverage_loss: %f",
                                results['reinforce_coverage_loss'])

            if FLAGS.inconsistent_loss:
                tf.logging.info('inconsistent_loss: %f',
                                results['inconsist_loss'])

            tf.logging.info("selector_loss: %f", results['selector_loss'])
            recall, ratio, _ = util.get_batch_ratio(
                batch.original_articles_sents, batch.original_extracts_ids,
                results['probs'])
            write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9',
                             train_step, summary_writer)

            # get the summaries and iteration number so we can write summaries to tensorboard
            summaries = results[
                'summaries']  # we will write these summaries to tensorboard using summary_writer
            summary_writer.add_summary(summaries,
                                       train_step)  # write the summaries
            if train_step % 100 == 0:  # flush the summary writer every so often
                summary_writer.flush()

            if train_step % FLAGS.save_model_every == 0:
                if FLAGS.pretrained_selector_path and FLAGS.pretrained_rewriter_path:
                    all_saver.save(sess, ckpt_path, global_step=train_step)
                else:
                    sv.saver.save(sess, ckpt_path, global_step=train_step)

            print('Step: ', train_step)
def run_decode(model, batcher, vocab):
    print "build graph..."
    model.build_graph()
    sess = tf.Session(config=util.get_config())
    saver = tf.train.Saver()  # a single saver is enough here
    ckpt_path = util.load_ckpt(saver, sess)
    if FLAGS.single_pass:
        ckpt_name = "ckpt-" + ckpt_path.split('-')[-1]
        dirname = "decode_maxenc_%ibeam_%imindec_%imaxdec_%i" % (FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
        decode_dir = os.path.join(FLAGS.log_root, dirname + ckpt_name)
        if os.path.exists(decode_dir):
            raise Exception('single_pass decode directory %s should not already exist' % decode_dir)
    else:
        decode_dir = os.path.join(FLAGS.log_root, 'decode')
    if not os.path.exists(decode_dir): os.mkdir(decode_dir)
    if FLAGS.single_pass:
        rouge_ref_dir = os.path.join(decode_dir, "reference")
        if not os.path.exists(rouge_ref_dir): os.mkdir(rouge_ref_dir)
        rouge_dec_dir = os.path.join(decode_dir, "decoded")
        if not os.path.exists(rouge_dec_dir): os.mkdir(rouge_dec_dir)
    counter = 0
    t0 = time.time()
    while True:
        batch = batcher.next_batch()  # 1 example repeated across batch
        if batch is None:  # finished decoding dataset in single_pass mode
            assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
            print("Decoder has finished reading dataset for single_pass.")
            print("Output has been saved in %s and %s. Now starting ROUGE eval...", rouge_ref_dir,
                            rouge_dec_dir)
            results_dict = rouge_eval(rouge_ref_dir, rouge_dec_dir)
            rouge_log(results_dict, decode_dir)
            return

        original_article = batch.original_articles[0]  # string
        original_abstract = batch.original_abstracts[0]  # string
        original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

        article_withunks = data.show_art_oovs(original_article, vocab)  # string
        abstract_withunks = data.show_abs_oovs(original_abstract, vocab, None)  # string

        # Run beam search to get best Hypothesis
        output = model.run_beam_decode_step(sess, batch, vocab)
        output_ids = [int(t) for t in output]
        decoded_words = data.outputids2words(output_ids, vocab, None)

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)  # index of the (first) [STOP] symbol
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:  # no [STOP] token produced; keep the full output
            pass
        decoded_output = ' '.join(decoded_words)  # single string

        if FLAGS.single_pass:
            write_for_rouge(original_abstract_sents, decoded_words, counter, rouge_ref_dir, rouge_dec_dir)  # write ref summary and decoded summary to file, to eval with pyrouge later
            counter += 1  # this is how many examples we've decoded
        else:
            print_results(article_withunks, abstract_withunks, decoded_output)  # log output to screen
            # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
            t1 = time.time()
            if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                                t1 - t0)
                _ = util.load_ckpt(saver, sess)
                t0 = time.time()
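print_results is called in every decode loop here but never defined. A minimal sketch that simply logs the three strings (Example #17 also passes an example counter in single-pass mode, which a variant could accept as an optional argument):

def print_results(article, abstract, decoded_output):
  """Log the article, the reference summary and the generated summary to screen."""
  print("---------------------------------------------------------------------------")
  tf.logging.info('ARTICLE:  %s', article)
  tf.logging.info('REFERENCE SUMMARY: %s', abstract)
  tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
  print("---------------------------------------------------------------------------")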
Example #15
def run_eval(model, batcher, word_vector):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    if FLAGS.embedding:
        sess.run(tf.global_variables_initializer()
                 )  #, feed_dict={model.embedding_place: word_vector}
    eval_dir = os.path.join(
        FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)  # file used to save the graph and summaries

    running_avg_ratio = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_ratio = None  # will hold the best loss achieved so far
    train_dir = os.path.join(FLAGS.log_root, "train")

    while True:
        ckpt_state = tf.train.get_checkpoint_state(
            train_dir)  # get the info of checkpoint file

        #tf.logging.info('max_enc_steps: %d, max_dec_steps: %d', FLAGS.max_enc_steps, FLAGS.max_dec_steps)
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        train_step = results['global_step']

        recall, ratio, _ = util.get_batch_ratio(batch.original_articles_sents, \
                                                batch.original_extracts_ids, results['probs'])
        write_to_summary(ratio, 'SentSelector/select_ratio/recall=0.9',
                         train_step, summary_writer)

        # add summaries
        summaries = results['summaries']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_ratio = util.calc_running_avg_loss(ratio,
                                                       running_avg_ratio,
                                                       summary_writer,
                                                       train_step,
                                                       'running_avg_ratio')
        print("run_avg_ratio: ", running_avg_ratio)
        tf.log("run_avg_ratio: ", running_avg_ratio)
        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_ratio is None or running_avg_ratio < best_ratio:
            tf.logging.info(
                'Found new best model with %.3f running_avg_ratio. Saving to %s',
                running_avg_ratio, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_ratio = running_avg_ratio

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
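write_to_summary, used here and in Example #13 to record the sentence-selection ratio, is a thin wrapper around a manually-built TensorBoard summary; a sketch:

def write_to_summary(value, tag_name, step, summary_writer):
    """Write a single scalar to TensorBoard under tag_name at the given step."""
    summary = tf.Summary()
    summary.value.add(tag=tag_name, simple_value=value)
    summary_writer.add_summary(summary, step)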
Example #16
def main(unused_argv):
    if len(unused_argv
           ) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting to run in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    tf.set_random_seed(111)  # a seed value for randomness

    if FLAGS.mode == "train-classifier":

        #print("Start pre-training......")
        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)
        print("Start pre-training classification......")
        run_pre_train_classification(model_class, cla_batcher, 1, sess_cls,
                                     saver_cls, train_dir_cls)  #10
        generated = Generate_training_sample(model_class, vocab, cla_batcher,
                                             sess_cls)

        print("Generating training examples......")
        generated.generate_training_example("train")
        generated.generate_test_example("test")

    elif FLAGS.mode == "train-sentimentor":

        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)

        print("Start pre_train_sentimentor......")
        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)
        util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification")
        run_pre_train_sentimentor(model_sentiment, sentiment_batcher, 1,
                                  sess_sen, saver_sen, train_dir_sen)  #1

    elif FLAGS.mode == "test":

        config = {
            'n_epochs': 5,
            'kernel_sizes': [3, 4, 5],
            'dropout_rate': 0.5,
            'val_split': 0.4,
            'edim': 300,
            'n_words': None,  # Leave as none
            'std_dev': 0.05,
            'sentence_len': 50,
            'n_filters': 100,
            'batch_size': 50
        }
        config['n_words'] = 50000

        cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab)
        cnn_classifier = CNN(config)
        sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier(
            cnn_classifier)
        #util.load_ckpt(saver_cnn_cls, sess_cnn_cls, ckpt_dir="train-cnnclassification")
        run_train_cnn_classifier(cnn_classifier, cla_cnn_batcher, 1,
                                 sess_cnn_cls, saver_cnn_cls,
                                 train_dir_cnn_cls)  #1

        files = os.listdir("test-generate-transfer/")
        for file_ in files:
            run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
                                "test-generate-transfer/" + file_ + "/*")

    #elif FLAGS.mode == "test":

    elif FLAGS.mode == "train-generator":

        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)

        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)

        config = {
            'n_epochs': 5,
            'kernel_sizes': [3, 4, 5],
            'dropout_rate': 0.5,
            'val_split': 0.4,
            'edim': 300,
            'n_words': None,  # Leave as none
            'std_dev': 0.05,
            'sentence_len': 50,
            'n_filters': 100,
            'batch_size': 50
        }
        config['n_words'] = 50000

        cla_cnn_batcher = CNN_ClaBatcher(hps_discriminator, vocab)
        cnn_classifier = CNN(config)
        sess_cnn_cls, saver_cnn_cls, train_dir_cnn_cls = setup_training_cnnclassifier(
            cnn_classifier)

        model = Generator(hps_generator, vocab)
        batcher = GenBatcher(vocab, hps_generator)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        util.load_ckpt(saver_cnn_cls,
                       sess_cnn_cls,
                       ckpt_dir="train-cnnclassification")
        util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor")

        generated = Generated_sample(model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 1, sess_ge, saver_ge,
                                train_dir_ge, generated, cla_cnn_batcher,
                                cnn_classifier, sess_cnn_cls)  # 4

        generated.generate_test_negetive_example(
            "temp_negetive",
            batcher)  # batcher, model_class, sess_cls, cla_batcher
        generated.generate_test_positive_example("temp_positive", batcher)

        #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
        #                    "temp_negetive" + "/*")

        loss_window = 0
        t0 = time.time()
        print("begin reinforcement learning:")
        for epoch in range(30):
            batches = batcher.get_batches(mode='train')
            for i in range(len(batches)):
                current_batch = copy.deepcopy(batches[i])
                sentiment_batch = batch_sentiment_batch(
                    current_batch, sentiment_batcher)
                result = model_sentiment.max_generator(sess_sen,
                                                       sentiment_batch)
                weight = result['generated']
                current_batch.weight = weight
                sentiment_batch.weight = weight

                cla_batch = batch_classification_batch(current_batch, batcher,
                                                       cla_batcher)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)

                cc = SmoothingFunction()

                reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc'])
                reward_BLEU = []
                for k in range(FLAGS.batch_size):
                    reward_BLEU.append(
                        sentence_bleu(
                            [current_batch.original_reviews[k].split()],
                            cla_batch.original_reviews[k].split(),
                            smoothing_function=cc.method1))

                reward_BLEU = np.array(reward_BLEU)

                reward_de = (2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 /
                                  (1e-6 + reward_BLEU)))

                result = model.run_train_step(sess_ge, current_batch)
                train_step = result[
                    'global_step']  # we need this to update our running average loss
                loss = result['loss']
                loss_window += loss
                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 100)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0
                if train_step % 10000 == 0:

                    generated.generate_test_negetive_example(
                        "test-generate-transfer/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher)
                    generated.generate_test_positive_example(
                        "test-generate/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher)
                    #saver_ge.save(sess, train_dir + "/model", global_step=train_step)
                    #run_test_our_method(cla_cnn_batcher, cnn_classifier, sess_cnn_cls,
                    #                    "test-generate-transfer/" + str(epoch) + "epoch_step" + str(
                    #                        train_step) + "_temp_positive" + "/*")

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_sentiment = result['y_pred_auc']
                reward_result_bleu = np.array(bleu)

                reward_result = (2 / (1.0 /
                                      (1e-6 + reward_result_sentiment) + 1.0 /
                                      (1e-6 + reward_result_bleu)))

                current_batch.score = 1 - current_batch.score

                result = model.max_generator(sess_ge, current_batch)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_transfer_sentiment = result['y_pred_auc']
                reward_result_transfer_bleu = np.array(bleu)

                reward_result_transfer = (
                    2 / (1.0 /
                         (1e-6 + reward_result_transfer_sentiment) + 1.0 /
                         (1e-6 + reward_result_transfer_bleu)))

                #tf.logging.info("reward_nonsentiment: "+str(reward_sentiment) +" output_original_sentiment: "+str(reward_result_sentiment)+" output_original_bleu: "+str(reward_result_bleu))

                reward = reward_result_transfer  #reward_de + reward_result_sentiment +
                #tf.logging.info("reward_de: "+str(reward_de))

                model_sentiment.run_train_step(sess_sen, sentiment_batch,
                                               reward)
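All three reward expressions above compute the same quantity: a numerically guarded harmonic mean of a sentiment confidence and a BLEU score, so a near-zero value on either side drags the whole reward down. Restated as a helper for clarity:

def harmonic_reward(sentiment_score, bleu_score, eps=1e-6):
    """Harmonic mean of two reward components; eps guards against division by zero."""
    return 2.0 / (1.0 / (eps + sentiment_score) + 1.0 / (eps + bleu_score))

# A perfect sentiment score cannot compensate for zero content preservation:
print(harmonic_reward(1.0, 0.0))  # ~2e-6
print(harmonic_reward(0.8, 0.6))  # ~0.686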
Example #17
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely,
        loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        if FLAGS.decode_bleu:
            ref_file = os.path.join(self._bleu_dec_dir, "reference.txt")
            decoded_file = os.path.join(self._bleu_dec_dir, "decoded.txt")
            if os.path.exists(decoded_file):
                tf.logging.info('Deleting %s', decoded_file)
                os.remove(decoded_file)
            if os.path.exists(ref_file):
                tf.logging.info('Deleting %s', ref_file)
                os.remove(ref_file)
        while True:
            batch = self._batcher.next_batch(
            )  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                if FLAGS.decode_rouge:
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    try:
                        t0 = time.time()
                        results_dict = rouge_eval(self._rouge_ref_dir,
                                                  self._rouge_dec_dir)
                        rouge_log(results_dict, self._decode_dir)
                        t1 = time.time()
                        tf.logging.info(
                            'calculate Rouge score cost %d seconds', t1 - t0)
                    except Exception as e:
                        tf.logging.error('ROUGE eval failed: %s', e)
                if FLAGS.decode_bleu:
                    ref_file = os.path.join(self._bleu_dec_dir,
                                            "reference.txt")
                    decoded_file = os.path.join(self._bleu_dec_dir,
                                                "decoded.txt")

                    t0 = time.time()
                    bleu, bleu1, bleu2, bleu3, bleu4 = calcu_bleu(
                        decoded_file, ref_file)
                    sys_bleu = sys_bleu_file(decoded_file, ref_file)
                    sys_bleu_perl = sys_bleu_perl_file(decoded_file, ref_file)
                    t1 = time.time()

                    tf.logging.info(bcolors.HEADER +
                                    '-----------BLEU SCORE-----------' +
                                    bcolors.ENDC)
                    tf.logging.info(
                        bcolors.OKGREEN + '%f \t %f \t %f \t %f \t %f' +
                        bcolors.ENDC, bleu, bleu1, bleu2, bleu3, bleu4)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu %f' + bcolors.ENDC,
                        sys_bleu)
                    tf.logging.info(
                        bcolors.OKGREEN + 'sys_bleu_perl %s' + bcolors.ENDC,
                        sys_bleu_perl)
                    tf.logging.info(bcolors.HEADER +
                                    '-----------BLEU SCORE-----------' +
                                    bcolors.ENDC)
                    tf.logging.info('calculate BLEU score cost %d seconds',
                                    t1 - t0)
                break

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:  # no [STOP] token produced; keep the full output
                pass
            decoded_output = ''.join(decoded_words)  # single string

            if FLAGS.single_pass:
                print_results(article_withunks, abstract_withunks,
                              decoded_output, counter)  # log output to screen
                if FLAGS.decode_rouge:
                    self.write_for_rouge(
                        original_abstract_sents, decoded_words, counter
                    )  # write ref summary and decoded summary to file, to eval with pyrouge later
                if FLAGS.decode_bleu:
                    self.write_for_bleu(original_abstract_sents, decoded_words)
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens
                )  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
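
The colorized BLEU logging above relies on a `bcolors` helper that is not defined in this example. A minimal sketch of the usual convention (plain ANSI escape codes; the class body here is an assumption, not the project's actual definition):

class bcolors:
    HEADER = '\033[95m'   # bright magenta, used for the banner lines
    OKGREEN = '\033[92m'  # bright green, used for metric values
    ENDC = '\033[0m'      # reset the terminal color

# usage: colorize the format string, not the arguments, e.g.
# tf.logging.info(bcolors.OKGREEN + 'sys_bleu %f' + bcolors.ENDC, sys_bleu)
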
    def decode(self, output_dir=None):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        idx = 0

        # used to accumulate decoding output as strings
        output_str = ""
        beam_search_str = ""
        metadata = []

        # evaluate over a fixed number of test examples
        while True:  # optionally limit with: idx <= 100

            print("[%d]" % idx)

            batch = self._batcher.next_batch()  # 1 example repeated across batch

            #      if idx < 11000:
            #          idx += 1
            #          continue

            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                break  # break (not return) so the final save below is reachable

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get all the hypotheses
            all_hyp = beam_search.run_beam_search(
                self._sess, self._model, self._vocab, batch, counter,
                self._lm_model, self._lm_word2idx, self._lm_idx2word
            )  #TODO changed the method signature just to look at the outputs of beam search

            if FLAGS.save_values:
                for h in all_hyp:
                    output_ids = [int(t) for t in h.tokens[1:]]
                    search_str = str(
                        data.outputids2words(output_ids, self._vocab,
                                             (batch.art_oovs[0]
                                              if FLAGS.pointer_gen else None)))
                    beam_search_str += search_str
                beam_search_str += "\n"

            # Extract the best Hypothesis
            best_hyp = all_hyp[0]

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))
            metadata.append(decoded_words)

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated; keep all words
            decoded_output = ' '.join(decoded_words)  # single string

            ###########################
            # print best-hypothesis statistics

            hyp_stat = ""

            # log prob
            hyp_stat += "\navg log prob: %s.\n" % best_hyp.avg_log_prob

            # words overlap with article: this is buggy
            #      tri, bi, uni = word_overlap.gram_search(ngrams(nltk.pos_tag(article_withunks.strip().split()), 3), ngrams(nltk.pos_tag(decoded_output.strip().split()), 3))
            #      hyp_stat += "trigram overlap: %s. bigram overlap: %s. unigram overlap: %s.\n"%(uni, bi, tri)

            print_statistics.get_overlap(article_withunks.strip(),
                                         decoded_output.strip(),
                                         match_count=self.overlap_dict)
            hyp_stat += "word overlap: "
            for key, value in self.overlap_dict.items():
                hyp_stat += "\n%d-gram avg overlap: %d" % (key, value /
                                                           (counter + 1))

            # num sentences and avg length
            self.total_nsentence += len(decoded_output.strip().split(
                "."))  # sentences are separated by "."
            self.total_length += len(decoded_output.strip().split())
            avg_nsentence, avg_length = self.total_nsentence / (
                counter + 1), self.total_length / (counter + 1)

            hyp_stat += "\nnum sentences: %s. avg len: %s.\n" % (avg_nsentence,
                                                                 avg_length)

            # entropy??
            if FLAGS.print_info:
                print(hyp_stat)
            ###########################

            # saves data into numpy files for analysis
            if FLAGS.save_values:
                save_decode_data.save_data_iteration(self._decode_dir, counter,
                                                     best_hyp)

            if FLAGS.single_pass:  #change to counter later
                self.write_for_rouge(
                    original_abstract_sents, decoded_words, counter
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                # writing all the output combined to a file
                if FLAGS.print_info:
                    output = '\nARTICLE:  %s\n REFERENCE SUMMARY: %s\n' 'GENERATED SUMMARY: %s\n' % (
                        article_withunks, abstract_withunks, decoded_output)
                    print(output)
                    output_str += output
                # Leena: modifying this to save more stuff
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens, counter,
                    best_hyp.log_prob, best_hyp.avg_log_prob,
                    best_hyp.average_pgen)  #change to counter later
                counter += 1  # this is how many examples we've decoded

            else:  # Leena: I mainly use the branch above, so this branch may not be up to date
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens,
                    counter)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()

            idx += 1

        #Leena: saving entire output and beam output as a string to write to a file
        if FLAGS.save_values:
            save_decode_data.save_data_once(self._decode_dir,
                                            FLAGS.result_path, output_str,
                                            beam_search_str, metadata)
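
The [STOP]-token trimming idiom recurs in every decode loop in these examples. As a sketch, it could be factored into a small helper (the name `trim_at_stop` is hypothetical, not from the original code):

def trim_at_stop(decoded_words, stop_token):
    """Return the tokens before the first stop_token; all tokens if absent."""
    try:
        return decoded_words[:decoded_words.index(stop_token)]
    except ValueError:  # the decoder never emitted the stop token
        return decoded_words

# e.g. decoded_words = trim_at_stop(decoded_words, data.STOP_DECODING)
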
Example #19
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        if not FLAGS.generate:
            t0 = time.time()
            counter = 0
            while True:
                batch = self._batcher.next_batch()  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    tf.logging.info(
                        "Output has been saved in %s and %s. Now starting ROUGE eval...",
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    results_dict = rouge_eval(self._rouge_ref_dir,
                                              self._rouge_dec_dir)
                    rouge_log(results_dict, self._decode_dir)
                    return

                original_article = batch.original_articles[0]  # string
                original_abstract = batch.original_abstracts[0]  # string
                original_abstract_sents = batch.original_abstracts_sents[
                    0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                abstract_withunks = data.show_abs_oovs(
                    original_abstract, self._vocab,
                    (batch.art_oovs[0]
                     if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Extract the output ids from the hypothesis and convert back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass  # no [STOP] token was generated; keep all words
                decoded_output = ' '.join(decoded_words)  # single string

                if FLAGS.single_pass:
                    # write ref summary and decoded summary to file, to eval with pyrouge later
                    self.write_for_rouge(original_abstract_sents,
                                         decoded_words, counter)
                    counter += 1  # this is how many examples we've decoded
                else:
                    print_results(article_withunks, abstract_withunks,
                                  decoded_output)  # log output to screen
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens
                    )  # write info to .json file for visualization tool

                    # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                    t1 = time.time()
                    if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                        tf.logging.info(
                            'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                            t1 - t0)
                        _ = util.load_ckpt(self._saver, self._sess)
                        t0 = time.time()
        # when generate=True
        else:
            counter = 0
            while True:
                batch = self._batcher.next_batch()  # 1 example repeated across batch
                if batch is None:  # finished decoding dataset in single_pass mode
                    assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                    tf.logging.info(
                        "Decoder has finished reading dataset for single_pass."
                    )
                    return

                original_article = batch.original_articles[0]  # string
                # original_abstract = batch.original_abstracts[0]  # string
                # original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

                article_withunks = data.show_art_oovs(original_article,
                                                      self._vocab)  # string
                # abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

                # Run beam search to get best Hypothesis
                best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                       self._vocab, batch)

                # Extract the output ids from the hypothesis and convert back to words
                output_ids = [int(t) for t in best_hyp.tokens[1:]]
                decoded_words = data.outputids2words(
                    output_ids, self._vocab,
                    (batch.art_oovs[0] if FLAGS.pointer_gen else None))

                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    pass  # no [STOP] token was generated; keep all words
                decoded_output = ' '.join(decoded_words)  # single string

                counter += 1
                # log output to screen
                print(
                    "---------------------------------------------------------------------------"
                )
                tf.logging.info('ARTICLE:  %s', article_withunks)
                tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
                print(
                    "---------------------------------------------------------------------------"
                )

                # self.write_for_rouge(original_abstract_sents, decoded_words, counter)
                # Write to file
                decoded_sents = []
                while len(decoded_words) > 0:
                    try:
                        fst_period_idx = decoded_words.index(".")
                    except ValueError:  # there is text remaining that doesn't end in "."
                        fst_period_idx = len(decoded_words)
                    sent = decoded_words[:fst_period_idx +
                                         1]  # sentence up to and including the period
                    decoded_words = decoded_words[fst_period_idx +
                                                  1:]  # everything else
                    decoded_sents.append(' '.join(sent))

                # pyrouge calls a perl script that puts the data into HTML files.
                # Therefore we need to make our output HTML safe.
                decoded_sents = [make_html_safe(w) for w in decoded_sents]

                # Write to file
                result_file = os.path.join(self._result_dir,
                                           "%06d_summary.txt" % counter)

                with open(result_file, "w") as f:
                    for idx, sent in enumerate(decoded_sents):
                        if idx == len(decoded_sents) - 1:
                            f.write(sent)  # no newline after the last sentence
                        else:
                            f.write(sent + "\n")
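
The write-to-file branch above splits the flat token list into sentences at each period. Factored out as a standalone sketch (the helper name is illustrative):

def split_on_periods(decoded_words):
    """Group a flat token list into sentence strings, splitting after '.'."""
    sents = []
    while len(decoded_words) > 0:
        try:
            fst_period_idx = decoded_words.index(".")
        except ValueError:  # trailing text with no closing period
            fst_period_idx = len(decoded_words)
        sents.append(' '.join(decoded_words[:fst_period_idx + 1]))
        decoded_words = decoded_words[fst_period_idx + 1:]
    return sents
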
def main(config, resume, phase):
    # Dataset
    fine_dataset = fine_clustering_dataset(config)
    # Dataloaders
    train_loader = DataLoader(fine_dataset,
                              shuffle=True,
                              batch_size=config['batch_size'],
                              num_workers=32)
    val_loader = DataLoader(fine_dataset,
                            shuffle=False,
                            batch_size=config['batch_size'],
                            num_workers=32)
    test_loader = DataLoader(fine_dataset,
                             shuffle=False,
                             batch_size=config['batch_size'],
                             num_workers=32)
    # Model
    start_epoch = 0
    if config['model_name'].startswith('resnet'):
        model = ResNet(config)
    elif config['model_name'].startswith('densenet'):
        model = DenseNet(config)
    elif config['model_name'].startswith('deeplab'):
        cluster_vector_dim = config['cluster_vector_dim']
        model = DeepLabv3_plus(nInputChannels=3,
                               n_classes=3,
                               os=16,
                               cluster_vector_dim=cluster_vector_dim,
                               pretrained=True,
                               _print=True)
    elif config['model_name'].startswith('bagnet'):
        model = BagNet(config=config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if resume:
        filepath = config['pretrain_path']
        start_epoch, learning_rate, optimizer, M, s = load_ckpt(
            model, filepath)
        start_epoch += 1
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    # Optimizer (created fresh here; note this replaces the optimizer
    # returned by load_ckpt when resuming)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['learning_rate'],
                                 weight_decay=1e-5)
    # resume or not
    if start_epoch == 0:
        print("Brand new training")
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=config['switch_learning_rate_interval'])
    # log_dir = config['log_dir']+"/{}_{}_".format(config['date'],config['model_name'])+"ep_{}-{}_lr_{}".format(start_epoch,start_epoch+config['num_epoch'],config['learning_rate'])
    # best loss
    if not resume:
        learning_rate = config['learning_rate']
        M, s = cluster_initialization(train_loader, model, config, phase)
    print(start_epoch)
    if config['if_train']:
        for epoch in range(start_epoch + 1,
                           start_epoch + config['num_epoch'] + 1):
            loss_tr = train(
                train_loader, model, optimizer, epoch, config, M,
                s)  #if training, delete learning rate and add optimizer
            if config['if_valid'] and epoch % config[
                    'valid_epoch_interval'] == 0:
                with torch.no_grad():
                    loss_val, M, s = valid(val_loader, model, epoch, config,
                                           learning_rate, M, s, phase)
                    scheduler.step(loss_val)
                save_ckpt(model, optimizer, epoch, loss_tr, loss_val, config,
                          M, s)
            else:
                val_log = open("../log/val_" + config['date'] + ".txt", "a")
                val_log.write('epoch ' + str(epoch) + '\n')
                val_log.close()
    test(test_loader, model, config, M, phase)
    store_config(config, phase)
    print("Training finished ...")
Example #21
  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      # evaluate for 100 * batch_size before comparing the loss
      # we do this due to memory constraint, best to run eval on different machines with large batch size
      while processed_batch < 100*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch() # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            # if using true Q-value to train DQN network,
            # we do this as the pre-training for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # we need to expand the q_estimates to match the input batch max_art_oov
            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update Q-values in our transitions based on this q_estimates we collected from DQN current network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
          loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
          printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
          printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
          printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
        if FLAGS.ac_training:
          printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

        for (k,v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
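
The early-stopping bookkeeping at the end of run_eval reduces to one pattern: save a "bestmodel" checkpoint only when the smoothed loss improves. A condensed sketch (the function name is illustrative; saver, sess, and the paths are as in the example):

def maybe_save_best(saver, sess, best_loss, running_avg_loss,
                    bestmodel_save_path, train_step):
    """Save a best-model checkpoint iff the running average loss improved."""
    if best_loss is None or running_avg_loss < best_loss:
        saver.save(sess, bestmodel_save_path, global_step=train_step,
                   latest_filename='checkpoint_best')
        return running_avg_loss  # new best loss
    return best_loss  # unchanged
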
Example #22
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':

        print("Start pre-training......")
        model_class = Classification(hps_discriminator, vocab)
        cla_batcher = ClaBatcher(hps_discriminator, vocab)
        sess_cls, saver_cls, train_dir_cls = setup_training_classification(
            model_class)
        print("Start pre-training classification......")
        #run_pre_train_classification(model_class, cla_batcher, 10, sess_cls, saver_cls, train_dir_cls)
        #generated = Generate_training_sample(model_class, vocab, cla_batcher, sess_cls)

        #print("Generating training examples......")
        #generated.generate_training_example("train")
        #generated.generator_validation_example("valid")

        model_sentiment = Sentimentor(hps_generator, vocab)
        sentiment_batcher = SenBatcher(hps_generator, vocab)
        sess_sen, saver_sen, train_dir_sen = setup_training_sentimentor(
            model_sentiment)
        #run_pre_train_sentimentor(model_sentiment,sentiment_batcher,1,sess_sen,saver_sen,train_dir_sen)
        sentiment_generated = Generate_non_sentiment_weight(
            model_sentiment, vocab, sentiment_batcher, sess_sen)
        #sentiment_generated.generate_training_example("train_sentiment")
        #sentiment_generated.generator_validation_example("valid_sentiment")

        model = Generator(hps_generator, vocab)
        # Create a batcher object that will create minibatches of data
        batcher = GenBatcher(vocab, hps_generator)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)

        util.load_ckpt(saver_sen, sess_sen, ckpt_dir="train-sentimentor")

        util.load_ckpt(saver_cls, sess_cls, ckpt_dir="train-classification")

        generated = Generated_sample(model, vocab, batcher, sess_ge)
        #print("Start pre-training generator......")
        run_pre_train_generator(
            model, batcher, 4, sess_ge, saver_ge, train_dir_ge, generated,
            model_class, sess_cls,
            cla_batcher)  # this is an infinite loop until interrupted

        #generated.generator_validation_negetive_example("temp_negetive", batcher, model_class,sess_cls,cla_batcher) # batcher, model_class, sess_cls, cla_batcher
        #generated.generator_validation_positive_example(
        #    "temp_positive", batcher, model_class,sess_cls,cla_batcher)

        loss_window = 0
        t0 = time.time()
        print("begin dual learning:")
        for epoch in range(30):
            batches = batcher.get_batches(mode='train')
            for i in range(len(batches)):
                current_batch = copy.deepcopy(batches[i])
                sentiment_batch = batch_sentiment_batch(
                    current_batch, sentiment_batcher)
                result = model_sentiment.max_generator(sess_sen,
                                                       sentiment_batch)
                weight = result['generated']
                current_batch.weight = weight
                sentiment_batch.weight = weight

                cla_batch = batch_classification_batch(current_batch, batcher,
                                                       cla_batcher)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)

                cc = SmoothingFunction()

                reward_sentiment = 1 - np.abs(0.5 - result['y_pred_auc'])
                reward_BLEU = []
                for k in range(FLAGS.batch_size):
                    reward_BLEU.append(
                        sentence_bleu(
                            [current_batch.original_reviews[k].split()],
                            cla_batch.original_reviews[k].split(),
                            smoothing_function=cc.method1))

                reward_BLEU = np.array(reward_BLEU)

                reward_de = (2 / (1.0 / (1e-6 + reward_sentiment) + 1.0 /
                                  (1e-6 + reward_BLEU)))

                result = model.run_train_step(sess_ge, current_batch)
                train_step = result[
                    'global_step']  # we need this to update our running average loss
                loss = result['loss']
                loss_window += loss
                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 100)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0
                if train_step % 10000 == 0:
                    #bleu_score = generatored.compute_BLEU(str(train_step))
                    #tf.logging.info('bleu: %f', bleu_score)  # print the loss to screen
                    generated.generator_validation_negetive_example(
                        "valid-generated-transfer/" + str(epoch) +
                        "epoch_step" + str(train_step) + "_temp_positive",
                        batcher, model_class, sess_cls, cla_batcher)
                    generated.generator_validation_positive_example(
                        "valid-generated/" + str(epoch) + "epoch_step" +
                        str(train_step) + "_temp_positive", batcher,
                        model_class, sess_cls, cla_batcher)
                    #saver_ge.save(sess, train_dir + "/model", global_step=train_step)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_sentiment = result['y_pred_auc']
                reward_result_bleu = np.array(bleu)

                reward_result = (2 / (1.0 /
                                      (1e-6 + reward_result_sentiment) + 1.0 /
                                      (1e-6 + reward_result_bleu)))

                current_batch.score = 1 - current_batch.score

                result = model.max_generator(sess_ge, current_batch)

                cla_batch, bleu = output_to_classification_batch(
                    result['generated'], current_batch, batcher, cla_batcher,
                    cc)
                result = model_class.run_ypred_auc(sess_cls, cla_batch)
                reward_result_transfer_sentiment = result['y_pred_auc']
                reward_result_transfer_bleu = np.array(bleu)

                reward_result_transfer = (
                    2 / (1.0 /
                         (1e-6 + reward_result_transfer_sentiment) + 1.0 /
                         (1e-6 + reward_result_transfer_bleu)))

                #tf.logging.info("reward_nonsentiment: "+str(reward_sentiment) +" output_original_sentiment: "+str(reward_result_sentiment)+" output_original_bleu: "+str(reward_result_bleu))

                reward = reward_result_transfer  #reward_de + reward_result_sentiment +
                #tf.logging.info("reward_de: "+str(reward_de))

                model_sentiment.run_train_step(sess_sen, sentiment_batch,
                                               reward)

    elif hps_generator.mode == 'decode':
        decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
        #model = Generator(decode_model_hps, vocab)
        #generated = Generated_sample(model, vocab, batcher)
        #bleu_score = generated.compute_BLEU()
        #tf.logging.info('bleu: %f', bleu_score)  # print the loss to screen

    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
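
The reward combinations above, of the form 2 / (1/(1e-6 + a) + 1/(1e-6 + b)), are smoothed harmonic means: the combined reward stays low unless both signals are high. As a sketch, with a hypothetical helper name:

def harmonic_reward(a, b, eps=1e-6):
    """Smoothed harmonic mean of two reward signals (scalars or arrays)."""
    return 2.0 / (1.0 / (eps + a) + 1.0 / (eps + b))

# e.g. reward_de = harmonic_reward(reward_sentiment, reward_BLEU)
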
Example #23
def run_eval(model, batcher):

    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(
        FLAGS.exp_name, "val")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far
    batches = batcher.getBatches()

    while True:

        _ = util.load_ckpt(saver, sess)  # load a new checkpoint

        epoch_avg_loss = 0.
        epoch_train_steps = 0

        for batch in batches:

            # run eval on the batch
            t0 = time.time()

            results = model.run_eval_step(sess, batch)

            t1 = time.time()
            #tf.logging.info('seconds for batch: %.2f', t1-t0)

            # print the loss and coverage loss to screen
            loss = results['loss']
            epoch_train_steps += 1
            #tf.logging.info('loss: %f', loss)

            # add summaries
            summaries = results['summaries']
            train_step = results['global_step']
            summary_writer.add_summary(summaries, train_step)

            epoch_avg_loss = (epoch_avg_loss * (epoch_train_steps - 1.) +
                              loss) / epoch_train_steps

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        print("Average loss for Epoch %f" % (epoch_avg_loss))
        if best_loss is None or epoch_avg_loss < best_loss:
            #tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
            print('Found new best model with %f epoch_avg_loss. Saving to %s' %
                  (epoch_avg_loss, bestmodel_save_path))
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = epoch_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()

        if (FLAGS.dataset == 'timeseries'):
            batcher.getData(updateData=True)
            batches = batcher.getBatches()
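
The epoch average above is the standard incremental mean update, which avoids storing every loss. Written equivalently as a sketch (`update_mean` is an illustrative name):

def update_mean(mean, n, x):
    """Mean of n samples, given the mean of the first n-1 samples and x.

    (mean * (n - 1) + x) / n  ==  mean + (x - mean) / n
    """
    return mean + (x - mean) / n
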
Example #24
  def run_eval(self):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    self.model.build_graph() # build the graph
    saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    #eval_dir = os.path.join(FLAGS.log_root, "eval") # make a subdir of the root dir for eval data
    eval_dir = os.path.join(FLAGS.log_root, "eval_{}".format(
      "rouge" if FLAGS.rouge_based_eval else "loss"))  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') # this is where checkpoints of best models are saved
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() # build dqn graph
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() # build dqn target graph
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        dqn_sess = tf.Session(config=util.get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = self.restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0
    decay = 0.99

    while True:
      _ = util.load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = util.load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      greedy_rouges = []
      sampled_rouges = []
      # evaluate for 100 * batch_size before comparing the loss
      # we do this due to memory constraint, best to run eval on different machines with large batch size

      while processed_batch < FLAGS.eval_interval*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = full_batch = self.full_batcher.next_batch()
        if FLAGS.rl_training:
          partial_batch = self.partial_batcher.next_batch()
          batch = Batcher.merge_batches(full_batch, partial_batch)
        if batch.is_any_null():
          print(partial_batch.original_abstracts_sents)
          print(full_batch.original_abstracts_sents)
          import ipdb
          ipdb.set_trace()
          print(np.concatenate((full_batch.original_abstracts_sents, partial_batch.original_abstracts_sents), axis=0))
          raise Exception
        else:
          partial_batch = None
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            # if using true Q-value to train DQN network,
            # we do this as the pre-training for the DQN network to get better estimates
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            # we need to expand the q_estimates to match the input batch max_art_oov
            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              # when we are not training DQN based on true Q-values
              # we need to update Q-values in our transitions based on this q_estimates we collected from DQN current network.
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_steps(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        printer_helper['rl_full_reward_sampled'] = np.mean(results['full_ssr'])
        printer_helper['rl_full_reward_greedy'] = np.mean(results['full_gsr'])
        printer_helper['rl_full_reward_diff'] = results['full_reward_diff']
        if FLAGS.rl_training:
          loss = np.mean([results['full_rl_avg_logprobs'], results['partial_rl_avg_logprobs']])
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_full_loss'] = results['full_rl_loss']
          printer_helper['rl_full_avg_logprobs'] = results['full_rl_avg_logprobs']
          printer_helper['rl_partial_loss'] = results['partial_rl_loss']
          printer_helper['rl_partial_avg_logprobs'] = results['partial_rl_avg_logprobs']
          printer_helper['rl_partial_reward_sampled'] = results['partial_ssr']
          printer_helper['rl_partial_reward_greedy'] = results['partial_gsr']
          printer_helper['rl_partial_reward_diff'] = results['partial_reward_diff']
        if FLAGS.coverage:
          loss = printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            loss = printer_helper['rl_cov_total_loss'] = results['reinforce_cov_total_loss']
          elif FLAGS.pointer_gen:
            loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training:
          greedy_rouges.append(np.mean([results['full_gsr'],results['partial_gsr']]))
          sampled_rouges.append(np.mean([results['full_ssr'],results['partial_ssr']]))
        else:
          greedy_rouges.append(np.mean(results['full_gsr']))
          sampled_rouges.append(np.mean(results['full_ssr']))

        for (k,v) in sorted(printer_helper.items(), key=lambda x: x[0]):
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss))
        tf.logging.info('-------------------------------------------')
        time.sleep(2)

      running_avg_loss = np.mean(avg_losses)
      running_greedy_rouge = np.mean(greedy_rouges)
      running_sampled_rouge = np.mean(sampled_rouges)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_greedy_rouge", simple_value=running_greedy_rouge), ]),
        train_step)
      self.summary_writer.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="running_sampled_rouge", simple_value=running_sampled_rouge), ]),
        train_step)

      self.summary_writer.add_summary(tf.Summary(
          value=[tf.Summary.Value(tag='running_avg_loss/decay=%f' % (decay), simple_value=running_avg_loss), ]),
                                      train_step)

      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('greedy rouges: {}\tsampled rouges: {}\t'.format(running_greedy_rouge, running_sampled_rouge))
      tf.logging.info('==========================================')

      # If running_avg_loss is best so far, save this checkpoint (early stopping).
      # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
      if (best_loss is None) or (FLAGS.rouge_based_eval and running_greedy_rouge > best_loss) or (
              not FLAGS.rouge_based_eval and running_avg_loss < best_loss):
        tf.logging.info('Found new best model with %.3f %s. Saving to %s',
                        running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss,
                        "running_greedy_rouge" if FLAGS.rouge_based_eval else "running_avg_loss",
                        bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_greedy_rouge if FLAGS.rouge_based_eval else running_avg_loss
        time.sleep(15)

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
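
Writing scalars computed outside the graph, as done above for the running ROUGE values, uses a manually constructed tf.Summary proto (TF1 API). Condensed into a sketch:

def write_scalar(summary_writer, tag, value, step):
    """Log a Python scalar to TensorBoard without a graph op (TF1 style)."""
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    summary_writer.add_summary(summary, step)
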
Example #25
    def __init__(self, model, batcher, vocab):
        """Initialize decoder.

    Args:
      model: a SentSelector object.
      batcher: a Batcher object.
      vocab: Vocabulary object
    """
        # get the data split set
        if "train" in FLAGS.data_path: self._dataset = "train"
        elif "val" in FLAGS.data_path: self._dataset = "val"
        elif "test" in FLAGS.data_path: self._dataset = "test"
        else:
            raise ValueError(
                "FLAGS.data_path %s should contain one of train, val or test" %
                (FLAGS.data_path))

        # create the data loader
        self._batcher = batcher

        if FLAGS.eval_gt_rouge:  # no need to load model; default is False
            # Make a descriptive decode directory name
            self._decode_dir = os.path.join(FLAGS.log_root,
                                            'select_gt' + self._dataset)
            tf.logging.info('Save evaluation results to ' + self._decode_dir)
            if os.path.exists(self._decode_dir):
                raise Exception(
                    "single_pass decode directory %s should not already exist"
                    % self._decode_dir)

            # Make the decode dir
            os.makedirs(self._decode_dir)

            # Make the dirs to contain output written in the correct format for pyrouge
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_gt_dir = os.path.join(self._decode_dir, "gt_selected")
            if not os.path.exists(self._rouge_gt_dir):
                os.mkdir(self._rouge_gt_dir)
        else:  # eval_gt_rouge is False (the default)
            self._model = model
            self._model.build_graph()
            self._vocab = vocab
            self._saver = tf.train.Saver(
            )  # we use this to load checkpoints for decoding
            self._sess = tf.Session(config=util.get_config())

            # Load an initial checkpoint to use for decoding
            print(" eval_ckpt_path ", FLAGS.eval_ckpt_path)
            if FLAGS.load_best_eval_model:
                tf.logging.info('Loading best eval checkpoint')
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_dir='eval')
            elif FLAGS.eval_ckpt_path:
                ckpt_path = util.load_ckpt(self._saver,
                                           self._sess,
                                           ckpt_path=FLAGS.eval_ckpt_path,
                                           ckpt_dir='eval')
            else:
                tf.logging.info('Loading best train checkpoint')
                ckpt_path = util.load_ckpt(self._saver, self._sess)

            if FLAGS.single_pass:
                # Make a descriptive decode directory name
                ckpt_name = "ckpt-" + ckpt_path.split('-')[
                    -1]  # this is something of the form "ckpt-123456"
                decode_root_dir, decode_dir = get_decode_dir_name(
                    ckpt_name, self._dataset)
                self._decode_root_dir = os.path.join(FLAGS.log_root,
                                                     decode_root_dir)
                self._decode_dir = os.path.join(FLAGS.log_root,
                                                decode_root_dir, decode_dir)
                tf.logging.info('Save evaluation results to ' +
                                self._decode_dir)
                if os.path.exists(self._decode_dir):
                    raise Exception(
                        "single_pass decode directory %s should not already exist"
                        % self._decode_dir)
            else:  # Generic decode dir name
                self._decode_dir = os.path.join(FLAGS.log_root, "select")

            # Make the decode dir if necessary
            if not os.path.exists(self._decode_dir):
                os.makedirs(self._decode_dir)

            if FLAGS.single_pass:
                # Make the dirs to contain output written in the correct format for pyrouge
                self._rouge_ref_dir = os.path.join(self._decode_dir,
                                                   "reference")
                if not os.path.exists(self._rouge_ref_dir):
                    os.mkdir(self._rouge_ref_dir)
                self._rouge_dec_dir = os.path.join(self._decode_dir,
                                                   "selected")
                if not os.path.exists(self._rouge_dec_dir):
                    os.mkdir(self._rouge_dec_dir)
                if FLAGS.save_pkl:
                    self._result_dir = os.path.join(self._decode_dir,
                                                    "select_result")
                    if not os.path.exists(self._result_dir):
                        os.mkdir(self._result_dir)

                self._probs_pkl_path = os.path.join(self._decode_root_dir,
                                                    "probs.pkl")
                if not os.path.exists(self._probs_pkl_path):
                    self._make_probs_pkl = True
                else:
                    self._make_probs_pkl = False
                self._precision = []
                self._recall = []
                self._accuracy = []
                self._ratio = []
                self._select_sent_num = []
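
The directory setup in this initializer repeats one idempotent pattern. A hypothetical condensed helper (the name is illustrative, not from the original code):

import os

def make_subdir(root, name):
    """Create root/name if missing (root must already exist); return the path."""
    path = os.path.join(root, name)
    if not os.path.exists(path):
        os.mkdir(path)
    return path

# e.g. self._rouge_ref_dir = make_subdir(self._decode_dir, "reference")
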
Example #26
def main(unused_argv):
  if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
    raise Exception("Problem with flags: %s" % unused_argv)

  tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
  tf.logging.info('Starting running in %s mode...', (FLAGS.mode))

  # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
  FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
  if not os.path.exists(FLAGS.log_root):
    if FLAGS.mode=="train":
      os.makedirs(FLAGS.log_root)
    else:
      raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

  vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary


  # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
  hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps']
  hps_dict = {}
  for key,val in FLAGS.__flags.items(): # for each flag
    if key in hparam_list: # if it's in the list
      hps_dict[key] = val # add it to the dict
  hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len']
  hps_dict = {}
  for key, val in FLAGS.__flags.items():  # for each flag
      if key in hparam_list:  # if it's in the list
          hps_dict[key] = val  # add it to the dict
  hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  # Create a batcher object that will create minibatches of data
  batcher = GenBatcher(vocab, hps_generator)

  tf.set_random_seed(111) # a seed value for randomness

  if hps_generator.mode == 'train':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)
    print("Start pre-training generator......")
    run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated) # pre-train the generator (10 epochs here)

    print("Generating negetive examples......")
    generated.generator_whole_negetive_example()
    generated.generator_test_negetive_example()

    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "train/generated_samples_positive/*", "train/generated_samples_negetive/*", "test/generated_samples_positive/*", "test/generated_samples_negetive/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    print("Start pre-training discriminator......")
    #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
    run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)

    util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
    
    generated.generator_sample_example("sample_temp_positive", "sample_temp_negetive", 1000)

    generated.generator_test_sample_example("test_sample_temp_positive",
                                       "test_sample_temp_negetive",
                                       200)
    generated.generator_test_max_example("test_max_temp_positive",
                                       "test_max_temp_negetive",
                                       200)
    tf.logging.info("true data diversity: ")
    eva = Evaluate()
    eva.diversity_evaluate("test_sample_temp_positive" + "/*")



    print("Start adversial training......")
    whole_decay = False
    for epoch in range(1):
        batches = batcher.get_batches(mode='train')
        for step in range(int(len(batches)/1000)):

            run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher, batches[step*1000:(step+1)*1000], sess_ge, saver_ge, train_dir_ge, generated) # (model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated)
            generated.generator_sample_example("sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 1000)
            #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

            tf.logging.info("test performance: ")
            tf.logging.info("epoch: "+str(epoch)+" step: "+str(step))
            generated.generator_test_sample_example(
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive", 200)
            generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                                            "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negetive",
                                            200)

            dis_batcher.train_queue = []
            for i in range(epoch+1):
              for j in range(step+1):
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*")
                dis_batcher.train_queue += dis_batcher.fill_example_queue("sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negetive/*")
            dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)

            #dis_batcher.valid_batch = dis_batcher.train_batch
            whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"),
                                                  sess_dis, saver_dis, train_dir_dis, whole_decay)

  '''elif hps_generator.mode == 'decode':
    decode_model_hps = hps_generator  # This will be the hyperparameters for the decoder model
    model = Generator(decode_model_hps, vocab)
    generated = Generated_sample(model, vocab, batcher)
    bleu_score = generated.compute_BLEU()
    tf.logging.info('bleu: %f', bleu_score)  # print the BLEU score to screen'''

  else:
    raise ValueError("The 'mode' flag must be one of train/eval/decode")
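The FLAGS-to-namedtuple extraction above is repeated for every model in these examples, and it breaks across TF 1.x releases: older versions expose FLAGS.__flags as name -> value, while later ones map name -> Flag object (hence the val.value in Example #28). A small helper, a sketch rather than anything from the source, that handles both:

from collections import namedtuple

def make_hps(flags, names):
    # Collect the requested flag values into an immutable HParams namedtuple.
    # getattr(val, 'value', val) works whether __flags holds raw values
    # (older TF 1.x) or Flag objects (later TF 1.x).
    hps_dict = {}
    for key, val in flags.__flags.items():
        if key in names:
            hps_dict[key] = getattr(val, 'value', val)
    return namedtuple("HParams", hps_dict.keys())(**hps_dict)

Usage: hps_generator = make_hps(FLAGS, hparam_list).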
Example #27
if opt.finetune:
    print('finetune')
    lr = opt.lr_finetune
    net.freeze_enc_bn = True
else:
    lr = opt.lr

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    net.parameters()),
                             lr=lr,
                             betas=(0.5, 0.999))

criterion = TotalLoss(VGG16FeatureExtractor()).to(device)

start_iter = 0
if opt.resume:
    start_iter = load_ckpt(opt.resume, [('model', net)],
                           [('optimizer', optimizer)])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print('Starting from iter ', start_iter)

# --------
# Training
# --------
for epoch in range(start_iter, opt.n_epochs):
    net.train()
    for i, (image, mask, gt) in enumerate(iterator_train):

        image, mask, gt = Variable(image).to(device), Variable(
            mask, requires_grad=False).to(device), Variable(gt).to(device)

        output, _ = net(image, mask)
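Example #27 cuts off right after the forward pass. A minimal sketch of how such an inpainting training step typically continues, assuming criterion(image, mask, output, gt) returns a scalar total loss (the actual TotalLoss interface is not shown in this excerpt):

        # Hypothetical continuation, not from the source: score the inpainted
        # output against the ground truth, then take an optimizer step.
        loss = criterion(image, mask, output, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()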
Example #28
def main(unused_argv):
    # %%
    # choose what level of logging you want
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Starting to run in %s mode...', FLAGS.mode)
    # create the vocabulary
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)

    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'max_enc_sen_num', 'max_enc_seq_len'
    ]
    hps_dict = {}

    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # fetch batches of data at the minimum batch size
    batcher = GenBatcher(vocab, hps_generator)
    # print(batcher.train_batch[0].original_review_inputs)
    # print(len(batcher.train_batch[0].original_review_inputs))
    tf.set_random_seed(123)
    # %%
    if FLAGS.mode == 'train_generator':

        # print("Start pre-training ......")
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)

        generated = Generated_sample(ge_model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(ge_model, batcher, 300, sess_ge, saver_ge,
                                train_dir_ge)
        # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finish load train-generator")

        print("Generating negative examples......")
        generator_graph = tf.Graph()
        with generator_graph.as_default():
            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")

        generated.generator_train_negative_example()
        generated.generator_test_negative_example()

        print("finish write")
    elif FLAGS.mode == 'train_discriminator':
        # print("Start pre-training ......")
        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)

        print("Start pre-training discriminator......")
        if not os.path.exists("discriminator_result"):
            os.mkdir("discriminator_result")
        run_pre_train_discriminator(model_dis, dis_batcher, 1000, sess_dis,
                                    saver_dis, train_dir_dis)

    elif FLAGS.mode == "adversarial_train":

        generator_graph = tf.Graph()
        discriminator_graph = tf.Graph()

        print("Start adversarial-training......")
        # tf.reset_default_graph()

        with generator_graph.as_default():
            model = Generator(hps_generator, vocab)
            sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
            generated = Generated_sample(model, vocab, batcher, sess_ge)

            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")
        with discriminator_graph.as_default():
            model_dis = Discriminator(hps_discriminator, vocab)
            dis_batcher = DisBatcher(hps_discriminator, vocab,
                                     "discriminator_train/positive/*",
                                     "discriminator_train/negative/*",
                                     "discriminator_test/positive/*",
                                     "discriminator_test/negative/*")
            sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
                model_dis)

            util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
            print("finish load train-discriminator")

        print("Start adversarial  training......")
        if not os.path.exists("train_sample_generated"):
            os.mkdir("train_sample_generated")
        if not os.path.exists("test_max_generated"):
            os.mkdir("test_max_generated")
        if not os.path.exists("test_sample_generated"):
            os.mkdir("test_sample_generated")

        whole_decay = False

        for epoch in range(100):
            print('Start training')
            batches = batcher.get_batches(mode='train')
            for step in range(int(len(batches) / 14)):

                run_train_generator(model, model_dis, sess_dis, batcher,
                                    dis_batcher,
                                    batches[step * 14:(step + 1) * 14],
                                    sess_ge, saver_ge, train_dir_ge)
                generated.generator_sample_example(
                    "train_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "train_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    14)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))

                #                print("evaluate the diversity of DP-GAN (decode based on  max probability)")
                #                generated.generator_test_sample_example(
                #                    "test_sample_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)
                #
                #                print("evaluate the diversity of DP-GAN (decode based on sampling)")
                #                generated.generator_test_max_example(
                #                    "test_max_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)

                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_negative/*")
                dis_batcher.train_batch = dis_batcher.create_batches(
                    mode="train", shuffleis=True)
                whole_decay = run_train_discriminator(
                    model_dis, 5, dis_batcher,
                    dis_batcher.get_batches(mode="train"), sess_dis, saver_dis,
                    train_dir_dis, whole_decay)
    elif FLAGS.mode == "test_language_model":
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finish load train-generator")

        #        generator_graph = tf.Graph()
        #        with generator_graph.as_default():
        #            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        #            print("finish load train-generator")

        # jieba.load_userdict('dir.txt')
        inputs = ''
        while inputs != "close":

            inputs = input("Enter your question: ")
            sentence = segmentor.segment(t2s.convert(inputs))
            #            sentence = jieba.cut(inputs)
            sentence = (" ".join(sentence))
            sentence = s2t.convert(sentence)
            print(sentence)
            sentence = sentence.split()

            enc_input = [vocab.word2id(w) for w in sentence]
            enc_lens = np.array([len(enc_input)])
            enc_input = np.array([enc_input])
            out_sentence = ('[START]').split()
            dec_batch = [vocab.word2id(w) for w in out_sentence]
            #dec_batch = [2] + dec_batch
            # dec_batch.append(3)
            while len(dec_batch) < 40:
                dec_batch.append(1)

            dec_batch = np.array([dec_batch])
            dec_batch = np.resize(dec_batch, (1, 1, 40))
            dec_lens = np.array([len(dec_batch)])
            if FLAGS.beamsearch == 'beamsearch_train':
                result = ge_model.run_test_language_model(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                #                print(result['generated'])
                #                print(result['generated'].shape)
                output_ids = result['generated'][0]
                decoded_words = data.outputids2words(output_ids, vocab, None)
                print("decoded_words :", decoded_words)
            else:
                results = ge_model.run_test_beamsearch_example(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                beamsearch_outputs = results['beamsearch_outputs']
                for i in range(5):
                    predict_list = np.ndarray.tolist(beamsearch_outputs[:, :, i])
                    predict_list = predict_list[0]
                    predict_seq = [vocab.id2word(idx) for idx in predict_list]
                    decoded_words = " ".join(predict_seq).split()
                    #                    decoded_words = decoded_words

                    try:
                        if decoded_words[0] == '[STOPDOC]':
                            decoded_words = decoded_words[1:]
                        # index of the (first) [STOP] symbol
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        decoded_words = decoded_words

                    if decoded_words[-1] not in ('.', '!', '?'):
                        decoded_words.append('.')
                    decoded_words_all = []
                    decoded_output = ' '.join(
                        decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)
                    decoded_words_all = ' '.join(decoded_words_all).strip()
                    decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                    decoded_words_all = decoded_words_all.replace("[UNK]", "")
                    decoded_words_all = decoded_words_all.replace(" ", "")
                    decoded_words_all, _ = re.subn(r"(! ){2,}", "",
                                                   decoded_words_all)
                    decoded_words_all, _ = re.subn(r"(\. ){2,}", "",
                                                   decoded_words_all)
                    if decoded_words_all.startswith(','):
                        decoded_words_all = decoded_words_all[1:]
                    print("The resonse   : {}".format(decoded_words_all))
Example #29
def main(unused_argv):
  if len(unused_argv) != 1: # prints a message if you've entered flags incorrectly
    raise Exception("Problem with flags: %s" % unused_argv)

  tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
  tf.logging.info('Starting to run in %s mode...', FLAGS.mode)

  # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
  FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
  if not os.path.exists(FLAGS.log_root):
    if "train" in FLAGS.mode:
      os.makedirs(FLAGS.log_root)
    else:
      raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

  vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary


  # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
  hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num','max_dec_steps', 'max_enc_steps']
  hps_dict = {}
  for key,val in FLAGS.__flags.items(): # for each flag
    if key in hparam_list: # if it's in the list
      hps_dict[key] = val # add it to the dict
  hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  hparam_list = ['lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
                 'hidden_dim', 'emb_dim', 'batch_size', 'max_enc_sen_num', 'max_enc_seq_len']
  hps_dict = {}
  for key, val in FLAGS.__flags.items():  # for each flag
      if key in hparam_list:  # if it's in the list
          hps_dict[key] = val  # add it to the dict
  hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

  # Create a batcher object that will create minibatches of data
  batcher = GenBatcher(vocab, hps_generator)

  tf.set_random_seed(111) # a seed value for randomness

  if hps_generator.mode == 'adversarial_train':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)


    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)

    util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
    util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")

    if not os.path.exists("MLE"): os.mkdir("MLE")

    print("evaluate the diversity of MLE (decode based on sampling)")
    generated.generator_test_sample_example("MLE/"+"MLE_sample_positive",
                                       "MLE/"+"MLE_sample_negative",
                                       200)
                                       
    print("evaluate the diversity of MLE (decode based on max probability)")
    generated.generator_test_max_example("MLE/"+"MLE_max_temp_positive",
                                       "MLE/"+"MLE_max_temp_negative",
                                       200)
  

    print("Start adversarial  training......")
    if not os.path.exists("train_sample_generated"): os.mkdir("train_sample_generated")
    if not os.path.exists("test_max_generated"): os.mkdir("test_max_generated")
    if not os.path.exists("test_sample_generated"): os.mkdir("test_sample_generated")
    
    
    
    whole_decay = False
    for epoch in range(10):
        batches = batcher.get_batches(mode='train')
        for step in range(int(len(batches)/1000)):

            run_train_generator(model, model_dis, sess_dis, batcher, dis_batcher, batches[step*1000:(step+1)*1000], sess_ge, saver_ge, train_dir_ge, generated) # (model, discriminator_model, discriminator_sess, batcher, dis_batcher, batches, sess, saver, train_dir, generated)
            generated.generator_sample_example("train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "train_sample_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negative", 1000)
            #generated.generator_max_example("max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_positive", "max_generated/"+str(epoch)+"epoch_step"+str(step)+"_temp_negetive", 200)

            tf.logging.info("test performance: ")
            tf.logging.info("epoch: "+str(epoch)+" step: "+str(step))
            print("evaluate the diversity of DP-GAN (decode based on  max probability)")
            generated.generator_test_sample_example(
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 200)
            print("evaluate the diversity of DP-GAN (decode based on sampling)")
            generated.generator_test_max_example("test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                                            "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                                            200)

            dis_batcher.train_queue = []
            for i in range(epoch+1):
              for j in range(step+1):
                dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_positive/*")
                dis_batcher.train_queue += dis_batcher.fill_example_queue("train_sample_generated/"+str(i)+"epoch_step"+str(j)+"_temp_negative/*")
            dis_batcher.train_batch = dis_batcher.create_batches(mode="train", shuffleis=True)

            #dis_batcher.valid_batch = dis_batcher.train_batch
            whole_decay = run_train_discriminator(model_dis, 5, dis_batcher, dis_batcher.get_batches(mode="train"),
                                                  sess_dis, saver_dis, train_dir_dis, whole_decay)

  elif hps_generator.mode == 'train_generator':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    generated = Generated_sample(model, vocab, batcher, sess_ge)
    print("Start pre-training generator......")
    run_pre_train_generator(model, batcher, 10, sess_ge, saver_ge, train_dir_ge, generated) # pre-train the generator (10 epochs here)

    print("Generating negative examples......")
    generated.generator_train_negative_example()
    generated.generator_test_negative_example()
  elif hps_generator.mode == 'train_discriminator':
    print("Start pre-training......")
    model = Generator(hps_generator, vocab)

    sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
    
    #util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")

    model_dis = Discriminator(hps_discriminator, vocab)
    dis_batcher = DisBatcher(hps_discriminator, vocab, "discriminator_train/positive/*", "discriminator_train/negative/*", "discriminator_test/positive/*", "discriminator_test/negative/*")
    sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(model_dis)
    print("Start pre-training discriminator......")
    #run_test_discriminator(model_dis, dis_batcher, sess_dis, saver_dis, "test")
    if not os.path.exists("discriminator_result"): os.mkdir("discriminator_result")
    run_pre_train_discriminator(model_dis, dis_batcher, 25, sess_dis, saver_dis, train_dir_dis)
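The nested refill loop in the adversarial branch above re-reads every earlier epoch's and step's generated samples on each step, so disk I/O grows quadratically over training. A lighter-weight alternative, a sketch rather than the source's method, appends only the newest step's samples and keeps the existing queue:

def refill_incremental(dis_batcher, epoch, step):
    # Append just the samples generated at (epoch, step) instead of
    # re-reading everything generated so far.
    for tag in ("positive", "negative"):
        pattern = "train_sample_generated/%depoch_step%d_temp_%s/*" % (epoch, step, tag)
        dis_batcher.train_queue += dis_batcher.fill_example_queue(pattern)
    return dis_batcher.create_batches(mode="train", shuffleis=True)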
    def __init__(self,
                 model,
                 batcher,
                 vocab,
                 lm_model=None,
                 lm_word2idx=None,
                 lm_idx2word=None):
        """Initialize decoder.

    Args:
      model: a Seq2SeqAttentionModel object.
      batcher: a Batcher object.
      vocab: Vocabulary object
    """
        self._model = model
        self._model.build_graph()
        self._batcher = batcher
        self._vocab = vocab
        self._saver = tf.train.Saver()  # we use this to load checkpoints for decoding
        self._sess = tf.Session(config=util.get_config())

        # these keep running stats during decoding
        self.total_nsentence = 0
        self.total_length = 0
        self.overlap_dict = dict()

        # these are for external lm
        self._lm_model = lm_model
        self._lm_word2idx = lm_word2idx
        self._lm_idx2word = lm_idx2word

        # Load an initial checkpoint to use for decoding
        ckpt_path = util.load_ckpt(self._saver, self._sess)

        if FLAGS.single_pass:
            # Make a descriptive decode directory name
            if FLAGS.pointer_gen_only_vocab:
                ckpt_name = "vocab-ckpt-" + ckpt_path.split('-')[
                    -1]  # this is something of the form "ckpt-123456"
            elif FLAGS.pointer_gen_only_attn:
                ckpt_name = "attn-ckpt-" + ckpt_path.split('-')[-1]

            else:
                ckpt_name = "pre_cov_val_ckpt-" + ckpt_path.split('-')[-1]

            self._decode_dir = os.path.join(FLAGS.log_root,
                                            get_decode_dir_name(ckpt_name))

            if os.path.exists(self._decode_dir):
                raise Exception(
                    "single_pass decode directory %s should not already exist"
                    % self._decode_dir)

        else:  # Generic decode dir name
            self._decode_dir = os.path.join(FLAGS.log_root, "decode")

        # Make the decode dir if necessary
        if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

        if FLAGS.single_pass:
            # Make the dirs to contain output written in the correct format for pyrouge
            self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
            if not os.path.exists(self._rouge_ref_dir):
                os.mkdir(self._rouge_ref_dir)
            self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
            if not os.path.exists(self._rouge_dec_dir):
                os.mkdir(self._rouge_dec_dir)
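write_for_rouge is not shown here. In See et al.'s pointer-generator it writes one reference file and one decoded file per example, named by a shared index so pyrouge can pair them. A simplified sketch of that convention (os is assumed imported as in the surrounding module):

def write_for_rouge(reference_sents, decoded_sents, ex_index,
                    rouge_ref_dir, rouge_dec_dir):
    # pyrouge pairs files via the %06d index embedded in their names.
    ref_file = os.path.join(rouge_ref_dir, "%06d_reference.txt" % ex_index)
    dec_file = os.path.join(rouge_dec_dir, "%06d_decoded.txt" % ex_index)
    with open(ref_file, "w") as f:
        f.write('\n'.join(reference_sents))
    with open(dec_file, "w") as f:
        f.write('\n'.join(decoded_sents))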
def run_eval(model, batcher, vocab):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(
        FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far

    while True:
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch
        extra_input = {}
        if 'key_phrases' in model._extra_info:
            # TODO: calculate the key phrase input part for this specific batch
            raise NotImplementedError(
                'Key phrases part has not been implemented here!')

        # run eval on the batch
        t0 = time.time()
        results = model.run_eval_step(sess, batch, extra_input=extra_input)
        t1 = time.time()
        tf.logging.info('seconds for batch: %.2f', t1 - t0)

        # print the loss and coverage loss to screen
        loss = results['loss']
        tf.logging.info('loss: %f', loss)
        if FLAGS.coverage:
            coverage_loss = results['coverage_loss']
            tf.logging.info("coverage_loss: %f", coverage_loss)

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        running_avg_loss = calc_running_avg_loss(np.asscalar(loss),
                                                 running_avg_loss,
                                                 summary_writer, train_step)

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
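calc_running_avg_loss keeps an exponentially decayed average of the eval loss, so early stopping reacts to the trend rather than to single noisy batches. In the pointer-generator codebase it is implemented roughly as follows (argument order varies across the forks shown in these examples):

def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    # Exponential moving average, clipped so TensorBoard plots stay readable.
    if running_avg_loss == 0:  # on the first iteration, take the loss as-is
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    loss_sum = tf.Summary()
    loss_sum.value.add(tag='running_avg_loss/decay=%f' % decay,
                       simple_value=running_avg_loss)
    summary_writer.add_summary(loss_sum, step)
    return running_avg_loss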
Example #32
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    tf.logging.set_verbosity(
        tf.logging.INFO)  # choose what level of logging you want
    tf.logging.info('Starting to run in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if not os.path.exists(FLAGS.log_root):
        if FLAGS.mode == "train":
            os.makedirs(FLAGS.log_root)
        else:
            raise Exception(
                "Logdir %s doesn't exist. Run in train mode to create it." %
                (FLAGS.log_root))

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'vocab_size', 'dataset', 'mode', 'lr', 'adagrad_init_acc',
        'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm',
        'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_sen_num',
        'max_enc_num', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'srl_max_dec_seq_len', 'srl_max_dec_sen_num', 'srl_max_enc_seq_len',
        'srl_max_enc_sen_num'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_srl_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'sc_max_dec_seq_len', 'sc_max_enc_seq_len'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val  # add it to the dict
    hps_sc_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # Create a batcher object that will create minibatches of data

    sc_batcher = Sc_GenBatcher(vocab, hps_sc_generator)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps_generator.mode == 'train':

        print("Start pre-training......")
        sc_model = Sc_Generator(hps_sc_generator, vocab)

        sess_sc, saver_sc, train_dir_sc = setup_training_sc_generator(sc_model)
        sc_generated = Generated_sc_sample(sc_model, vocab, sess_sc)
        print("Start pre-training generator......")
        run_pre_train_sc_generator(sc_model, sc_batcher, 40, sess_sc, saver_sc,
                                   train_dir_sc, sc_generated)

        if not os.path.exists("data/" + str(0) + "/"):
            os.mkdir("data/" + str(0) + "/")
        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-train"),
            "data/" + str(0) + "/train_skeleton.txt")

        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-valid"),
            "data/" + str(0) + "/valid_skeleton.txt")

        sc_generated.generator_max_example_test(
            sc_batcher.get_batches("pre-test"),
            "data/" + str(0) + "/test_skeleton.txt")

        merge("data/story/train_process.txt", "data/0/train_skeleton.txt",
              "data/0/train.txt")
        merge("data/story/validation_process.txt", "data/0/valid_skeleton.txt",
              "data/0/valid.txt")
        merge("data/story/test_process.txt", "data/0/test_skeleton.txt",
              "data/0/test.txt")

        #################################################################################################
        batcher = GenBatcher(vocab, hps_generator)
        srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator)
        print("Start pre-training......")
        model = Generator(hps_generator, vocab)

        sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
        generated = Generated_sample(model, vocab, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(model, batcher, 30, sess_ge, saver_ge,
                                train_dir_ge, generated)
        ##################################################################################################
        srl_generator_model = Srl_Generator(hps_srl_generator, vocab)

        sess_srl_ge, saver_srl_ge, train_dir_srl_ge = setup_training_srl_generator(
            srl_generator_model)
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        util.load_ckpt(saver_sc, sess_sc, ckpt_dir="train-sc-generator")
        srl_generated = Generated_srl_sample(srl_generator_model, vocab,
                                             sess_srl_ge)
        whole_generated = Generated_whole_sample(model, srl_generator_model,
                                                 vocab, sess_ge, sess_srl_ge,
                                                 batcher, srl_batcher)
        print("Start pre-training srl_generator......")
        run_pre_train_srl_generator(srl_generator_model, batcher, srl_batcher,
                                    20, sess_srl_ge, saver_srl_ge,
                                    train_dir_srl_ge, srl_generated,
                                    whole_generated)

        loss_window = 0
        t0 = time.time()
        print("begin reinforcement learning:")
        for epoch in range(10):

            loss_window = 0.0

            batcher = GenBatcher(vocab, hps_generator)
            srl_batcher = Srl_GenBatcher(vocab, hps_srl_generator)

            batches = batcher.get_batches(mode='train')
            srl_batches = srl_batcher.get_batches(mode='train')
            sc_batches = sc_batcher.get_batches(mode='train')
            len_sc = len(sc_batches)

            for i in range(min(len(batches), len(srl_batches))):
                current_batch = batches[i]
                current_srl_batch = srl_batches[i]
                current_sc_batch = sc_batches[i % (len_sc - 1)]

                results = model.run_pre_train_step(sess_ge, current_batch)
                loss_list = results['without_average_loss']

                example_skeleton_list = current_batch.original_review_outputs
                example_text_list = current_batch.original_target_sentences

                new_batch = sc_batcher.get_text_queue(example_skeleton_list,
                                                      example_text_list,
                                                      loss_list)
                results_sc = sc_model.run_rl_train_step(sess_sc, new_batch)
                loss = results_sc['loss']
                loss_window += loss

                results_srl = srl_generator_model.run_pre_train_step(
                    sess_srl_ge, current_srl_batch)
                loss_list_srl = results_srl['without_average_loss']

                example_srl_text_list = current_srl_batch.orig_outputs
                example_skeleton_srl_list = current_srl_batch.orig_inputs

                new_batch = sc_batcher.get_text_queue(
                    example_skeleton_srl_list, example_srl_text_list,
                    loss_list_srl)
                results_sc = sc_model.run_rl_train_step(sess_sc, new_batch)
                loss = results_sc['loss']
                loss_window += loss

                results_sc = sc_model.run_rl_train_step(
                    sess_sc, current_sc_batch)
                loss = results_sc['loss']
                loss_window += loss

                train_step = results['global_step']

                if train_step % 100 == 0:
                    t1 = time.time()
                    tf.logging.info(
                        'seconds for %d training generator step: %.3f ',
                        train_step, (t1 - t0) / 300)
                    t0 = time.time()
                    tf.logging.info('loss: %f', loss_window /
                                    100)  # print the loss to screen
                    loss_window = 0.0

                train_srl_step = results_srl['global_step']

                if train_srl_step % 10000 == 0:
                    saver_sc.save(sess_sc,
                                  train_dir_sc + "/model",
                                  global_step=results_sc['global_step'])
                    saver_ge.save(sess_ge,
                                  train_dir_ge + "/model",
                                  global_step=train_step)
                    saver_srl_ge.save(sess_srl_ge,
                                      train_dir_srl_ge + "/model",
                                      global_step=train_srl_step)

                    srl_generated.generator_max_example(
                        srl_batcher.get_batches("validation"),
                        "to_seq_max_generated/valid/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "to_seq_max_generated/valid/" +
                        str(int(train_srl_step / 30000)) + "_negative")
                    srl_generated.generator_max_example(
                        srl_batcher.get_batches("test"),
                        "to_seq_max_generated/test/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "to_seq_max_generated/test/" +
                        str(int(train_srl_step / 30000)) + "_negative")

                    whole_generated.generator_max_example(
                        batcher.get_batches("test-validation"),
                        "max_generated_final/valid/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "max_generated_final/valid/" +
                        str(int(train_srl_step / 30000)) + "_negative")
                    whole_generated.generator_max_example(
                        batcher.get_batches("test-test"),
                        "max_generated_final/test/" +
                        str(int(train_srl_step / 30000)) + "_positive",
                        "max_generated_final/test/" +
                        str(int(train_srl_step / 30000)) + "_negative")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-train"),
                "data/" + str(0) + "/train_skeleton.txt")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-valid"),
                "data/" + str(0) + "/valid_skeleton.txt")

            sc_generated.generator_max_example_test(
                sc_batcher.get_batches("pre-test"),
                "data/" + str(0) + "/test_skeleton.txt")

            merge("data/story/train_process.txt", "data/0/train_skeleton.txt",
                  "data/0/train.txt")
            merge("data/story/validation_process.txt",
                  "data/0/valid_skeleton.txt", "data/0/valid.txt")
            merge("data/story/test_process.txt", "data/0/test_skeleton.txt",
                  "data/0/test.txt")

    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
Example #33
def run_eval(model, batcher, word_vector):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph()  # build the graph
    saver = tf.train.Saver(
        max_to_keep=3)  # we will keep 3 best checkpoints at a time
    sess = tf.Session(config=util.get_config())

    if FLAGS.embedding:
        sess.run(tf.global_variables_initializer(),
                 feed_dict={model.embedding_place: word_vector})
    eval_dir = os.path.join(
        FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = restore_best_eval_model()  # will hold the best loss achieved so far
    train_step = 0

    while True:
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        processed_batch = 0
        avg_losses = []
        # evaluate for 100 * batch_size before comparing the loss
        # we do this due to memory constraint, best to run eval on different machines with large batch size
        # while processed_batch < 100 * FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = batcher.next_batch()  # get the next batch
        tf.logging.info('run eval step on seq2seq model.')
        t0 = time.time()
        results = model.run_eval_step(sess, batch, train_step)
        t1 = time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(
            processed_batch, t1 - t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss'] = results['pgen_loss']
        if FLAGS.coverage:
            printer_helper['coverage_loss'] = results['coverage_loss']
            if FLAGS.rl_training or FLAGS.ac_training:
                printer_helper['rl_cov_total_loss'] = results[
                    'reinforce_cov_total_loss']
            loss = printer_helper['pointer_cov_total_loss'] = results[
                'pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
            printer_helper['shared_loss'] = results['shared_loss']
            printer_helper['rl_loss'] = results['rl_loss']
            printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
            printer_helper['sampled_r'] = np.mean(
                results['sampled_sentence_r_values'])
            printer_helper['greedy_r'] = np.mean(
                results['greedy_sentence_r_values'])
            printer_helper['r_diff'] = printer_helper[
                'greedy_r'] - printer_helper['sampled_r']

        for (k, v) in printer_helper.items():
            if not np.isfinite(v):
                raise Exception("{} is not finite. Stopping.".format(k))
            tf.logging.info('{}: {}\t'.format(k, v))

            # add summaries
            summaries = results['summaries']
            train_step = results['global_step']
            print(train_step)
            summary_writer.add_summary(summaries, train_step)

            # calculate running avg loss
            avg_losses.append(
                calc_running_avg_loss(np.asscalar(loss), running_avg_loss,
                                      train_step, summary_writer))
            tf.logging.info('-------------------------------------------')

        running_avg_loss = np.mean(avg_losses)
        tf.logging.info('==========================================')
        tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(
            best_loss, running_avg_loss))
        tf.logging.info('==========================================')

        # If running_avg_loss is best so far, save this checkpoint (early stopping).
        # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
        if best_loss is None or running_avg_loss < best_loss:
            print(train_step)
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')
            best_loss = running_avg_loss

        # flush the summary writer every so often
        if train_step % 100 == 0:
            summary_writer.flush()
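restore_best_eval_model is not shown; from its use it should return the best running-average loss recorded by a previous eval run, or None on a fresh start, so a restarted eval job does not clobber a better checkpoint. A hypothetical sketch, assuming the value is persisted in a small pickle next to the bestmodel checkpoint (both pickle and the file name are our assumptions):

import pickle

def restore_best_eval_model():
    # Hypothetical: recover the best eval loss persisted by an earlier run.
    best_loss_path = os.path.join(FLAGS.log_root, "eval", "best_loss.pkl")
    if os.path.exists(best_loss_path):
        with open(best_loss_path, "rb") as f:
            return pickle.load(f)
    return None  # fresh start: no best loss recorded yet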
Example #34
def run_eval(model, batcher, re_vocab, embed):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    model.build_graph(embed)  # build the graph
    saver = tf.train.Saver(max_to_keep=10)  # we will keep 10 checkpoints at a time
    sess = tf.Session(config=util.get_config())
    eval_dir = os.path.join(
        FLAGS.log_root, "eval")  # make a subdir of the root dir for eval data
    bestmodel_save_path = os.path.join(
        eval_dir,
        'bestmodel')  # this is where checkpoints of best models are saved
    summary_writer = tf.summary.FileWriter(eval_dir)
    running_avg_loss = 0  # the eval job keeps a smoother, running average loss to tell it when to implement early stopping
    best_loss = None  # will hold the best loss achieved so far
    count = 0

    while True:
        _ = util.load_ckpt(saver, sess)  # load a new checkpoint
        batch = batcher.next_batch()  # get the next batch

        if not os.path.exists(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)
            os.makedirs(FLAGS.system_dir)

        # run eval on the batch
        t0 = time.time()
        #    results = model.run_eval_step(sess, batch)
        step_output = model.run_eval_step(sess, batch)
        t1 = time.time()

        #tf.logging.info('seconds for batch: %.2f', t1-t0)
        (summaries, loss, train_step) = step_output[0]
        (out_decoder_outputs, out_sent_decoder_outputs,
         final_dists) = step_output[1]
        (step_loss, word_loss, sent_loss, word_loss_null, sent_loss_null,
         switch_loss) = step_output[2]
        coverage_loss = 0.0

        running_avg_loss = calc_running_avg_loss(np.asscalar(loss),
                                                 running_avg_loss,
                                                 summary_writer, train_step)

        if best_loss is None or running_avg_loss < best_loss:
            if best_loss is None:
                best_loss = 0.0
            tf.logging.info(
                'Found new best model with %.3f running_avg_loss. Saving to %s',
                running_avg_loss, bestmodel_save_path)
            saver.save(sess,
                       bestmodel_save_path,
                       global_step=train_step,
                       latest_filename='checkpoint_best')

            best_loss = running_avg_loss
            last_step = train_step

        tf.logging.info('loss: %f', loss)  # print the loss to screen
        tf.logging.info(
            'step_loss: %f word_loss: %f ,sent_loss: %f ,word_loss_null: %f,sent_loss_null: %f ,switch_loss: %f,cover_loss: %f',
            step_loss, word_loss, sent_loss, word_loss_null, sent_loss_null,
            switch_loss, coverage_loss)

        os.system("rm -rf " + FLAGS.model_dir + ' ' + FLAGS.system_dir)

        count = count + 1

        if train_step % 100 == 0:
            summary_writer.flush()
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        total = len(glob.glob(self._batcher._data_path)) * 1000  # assumes ~1000 examples per data chunk
        pbar = tqdm(total=total)
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                logging.info("Output has been saved in %s and %s.",
                             self._rouge_ref_dir, self._rouge_dec_dir)
                if len(os.listdir(self._rouge_ref_dir)) != 0:
                    logging.info("Now starting ROUGE eval...")
                    results_dict = rouge_functions.rouge_eval(
                        self._rouge_ref_dir, self._rouge_dec_dir)
                    rouge_functions.rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            all_original_abstract_sents = batch.all_original_abstracts_sents[0]
            raw_article_sents = batch.raw_article_sents[0]

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            decoded_words, decoded_output, best_hyp = decode_example(
                self._sess, self._model, self._vocab, batch, counter,
                self._batcher._hps)

            if FLAGS.single_pass:
                if counter < 1000:
                    self.write_for_human(raw_article_sents,
                                         all_original_abstract_sents,
                                         decoded_words, counter)
                rouge_functions.write_for_rouge(
                    all_original_abstract_sents,
                    None,
                    counter,
                    self._rouge_ref_dir,
                    self._rouge_dec_dir,
                    decoded_words=decoded_words
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                if FLAGS.attn_vis:
                    self.write_for_attnvis(
                        article_withunks, abstract_withunks, decoded_words,
                        best_hyp.attn_dists, best_hyp.p_gens, counter
                    )  # write info to .json file for visualization tool

                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens,
                    counter)  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    logging.info(
                        'We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint',
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
            pbar.update(1)
        pbar.close()
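The progress-bar total above assumes each chunked data file holds about 1000 examples, so the bar is only an estimate. An exact count is possible with one pass over the files, assuming the standard pointer-generator binary format (an 8-byte length prefix before each serialized tf.Example):

import glob
import struct

def count_examples(data_path):
    # Count records exactly by walking the length-prefixed binary chunks.
    total = 0
    for fname in glob.glob(data_path):
        with open(fname, 'rb') as f:
            while True:
                len_bytes = f.read(8)
                if not len_bytes:
                    break
                example_len = struct.unpack('q', len_bytes)[0]
                f.seek(example_len, 1)  # skip the serialized example
                total += 1
    return total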
Example #36
    def decode(self):
        """Decode examples until data is exhausted (if FLAGS.single_pass) and return, or decode indefinitely, loading latest checkpoint at regular intervals"""
        t0 = time.time()
        counter = 0
        while True:
            batch = self._batcher.next_batch()  # 1 example repeated across batch
            if batch is None:  # finished decoding dataset in single_pass mode
                assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
                tf.logging.info(
                    "Decoder has finished reading dataset for single_pass.")
                tf.logging.info(
                    "Output has been saved in %s and %s. Now starting ROUGE eval...",
                    self._rouge_ref_dir, self._rouge_dec_dir)
                results_dict = rouge_eval(self._rouge_ref_dir,
                                          self._rouge_dec_dir)
                rouge_log(results_dict, self._decode_dir)
                return

            original_article = batch.original_articles[0]  # string
            original_abstract = batch.original_abstracts[0]  # string
            original_abstract_sents = batch.original_abstracts_sents[
                0]  # list of strings

            article_withunks = data.show_art_oovs(original_article,
                                                  self._vocab)  # string
            abstract_withunks = data.show_abs_oovs(
                original_abstract, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))  # string

            # Run beam search to get best Hypothesis
            best_hyp = beam_search.run_beam_search(self._sess, self._model,
                                                   self._vocab, batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_hyp.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self._vocab,
                (batch.art_oovs[0] if FLAGS.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(
                    data.STOP_DECODING)  # index of the (first) [STOP] symbol
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:  # no [STOP] token was produced
                pass  # keep the full decoded sequence
            decoded_output = ' '.join(decoded_words)  # single string

            if FLAGS.single_pass:
                self.write_for_rouge(
                    original_abstract_sents, decoded_words, counter
                )  # write ref summary and decoded summary to file, to eval with pyrouge later
                counter += 1  # this is how many examples we've decoded
            else:
                print_results(article_withunks, abstract_withunks,
                              decoded_output)  # log output to screen
                self.write_for_attnvis(
                    article_withunks, abstract_withunks, decoded_words,
                    best_hyp.attn_dists, best_hyp.p_gens
                )  # write info to .json file for visualization tool

                # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
                t1 = time.time()
                if t1 - t0 > SECS_UNTIL_NEW_CKPT:
                    tf.logging.info(
                        "We've been decoding with the same checkpoint for %i seconds. Time to load a new checkpoint",
                        t1 - t0)
                    _ = util.load_ckpt(self._saver, self._sess)
                    t0 = time.time()
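The [STOP]-token trimming above relies on list.index raising ValueError when the token is absent, which is easy to misread. The same logic as a self-contained helper, with a usage check (trim_at_stop is an illustrative name; the real code uses data.STOP_DECODING rather than the literal '[STOP]'):

def trim_at_stop(decoded_words, stop_token='[STOP]'):
    """Return the tokens before the first stop_token, or all tokens if it never appears."""
    try:
        return decoded_words[:decoded_words.index(stop_token)]
    except ValueError:  # list.index raises instead of returning a sentinel like -1
        return decoded_words

assert trim_at_stop(['the', 'cat', '[STOP]', 'sat']) == ['the', 'cat']
assert trim_at_stop(['the', 'cat']) == ['the', 'cat']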
def run_epoch(lm,
              session,
              batches,
              summary_writer,
              train_dir,
              train_step,
              saver,
              hps,
              best_loss,
              avg_loss,
              save_all=True):
    print_interval = 10.0
    total_batches = 0
    total_words = 0

    start_time = time.time()
    tick_time = start_time  # for showing status
    i = 0
    exception_count = 0
    init_exception_count = 0
    batches_skipped = 0

    for batch in batches:
        try:
            results = runTrainStep(lm, session, batch)

            loss = results['loss']
            coverage_loss = results['coverage_loss']
            if not np.isfinite(loss):
                raise Exception("Loss is not finite.")

            summaries = results['summaries']
            train_step = results['global_step']
            summary_writer.add_summary(summaries, train_step)

            # float() replaces np.asscalar(), which was deprecated and removed in NumPy 1.23
            avg_loss = util.running_avg_loss(float(loss), avg_loss,
                                             summary_writer, train_step)
            if best_loss is None or avg_loss < best_loss:
                #saver.save(session, train_dir, global_step=train_step, latest_filename='checkpoint_best')
                best_loss = avg_loss

            total_batches = i + 1
            total_words += len(batch.original_articles)  # batch size, so this actually counts examples
            i += 1

            if time.time() - tick_time >= print_interval:
                avg_eps = total_words / (time.time() - start_time)  # examples per second
                print(
                    "    [batch {:d}]: seen {:d} examples : {:.1f} eps, Loss: {:.3f}, Avg loss: {:.3f}, Best loss: {:.3f}, cov loss: {:.3f}"
                    .format(i, total_words, avg_eps, loss, avg_loss, best_loss,
                            coverage_loss))
                tick_time = time.time()  # reset time ticker
                if save_all:
                    saver.save(session,
                               train_dir,
                               global_step=train_step,
                               latest_filename='checkpoint')

            if train_step % 100 == 0:
                summary_writer.flush()

        except Exception as e:
            if exception_count <= 10:
                print('    [EXCEPTION]:', str(e), '; restoring model params')
                exception_count += 1
                batches_skipped += 1
                util.load_ckpt(saver, session, hps, hps.log_root)
                continue
            else:
                print('    [EXCEPTION LIMIT EXCEEDED]: Batches skipped:',
                      batches_skipped, '; Error:', str(e))
                raise  # re-raise without losing the original traceback
    time_total = pretty_timedelta(since=start_time)
    print(
        f"    [END] Training complete: Total examples : {total_words}; Total time: {time_total}"
    )
    return avg_loss, best_loss, train_step
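run_epoch delegates its "Avg loss" to util.running_avg_loss. In the pointer-generator family of codebases this is typically an exponential moving average that is also mirrored to TensorBoard; the following is a sketch under that assumption (the 0.99 decay, the clip at 12, and the summary tag follow that lineage and may differ in this repository's util module):

import tensorflow as tf

def running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    """Exponential moving average of the training loss, logged to TensorBoard."""
    if running_avg_loss == 0:  # first call: seed the average with the raw loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip so a loss spike doesn't dominate the plot
    loss_sum = tf.Summary()
    loss_sum.value.add(tag='running_avg_loss/decay=%f' % decay,
                       simple_value=running_avg_loss)
    summary_writer.add_summary(loss_sum, step)
    return running_avg_loss

The clip keeps one divergent batch from stretching the TensorBoard y-axis; the smoothed value is also what run_epoch compares against best_loss when deciding whether a new best has been reached.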