Example #1
0
def generate_samples(hparams, data, id_to_word, log_dir, output_file):
    """Generate samples.

    Builds the MaskGAN graph in evaluation mode, restores it through a
    `tf.train.Supervisor`, and for `FLAGS.number_epochs` passes over `data`
    masks each batch and writes the model's samples to `output_file` with
    `generate_logs`.

    Args:
      hparams:  Hyperparameters for the MaskGAN.
      data: Data to evaluate.
      id_to_word: Dictionary of indices to words.
      log_dir: Log directory.
      output_file:  Output file for the samples.  This function closes it,
        including when sample generation raises.
    """
    # Boolean indicating operational mode.
    is_training = False

    # Set a random seed to keep fixed mask.
    np.random.seed(0)

    with tf.Graph().as_default():
        # Construct the model.
        model = train_mask_gan.create_MaskGAN(hparams, is_training)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        is_chief = FLAGS.task == 0

        # Create the supervisor.  It will take care of initialization and
        # recovery.  Note: the class lives in `tf.train`; there is no
        # top-level `tf.Supervisor`.
        sv = tf.train.Supervisor(logdir=log_dir,
                                 is_chief=is_chief,
                                 saver=model.saver,
                                 global_step=model.global_step,
                                 recovery_wait_secs=30,
                                 summary_op=None,
                                 init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  The standard
        # services (checkpointing, summaries, step counting) are deliberately
        # not started: this function only generates samples.
        with sv.managed_session(FLAGS.master,
                                start_standard_services=False) as sess:
            try:
                # Generator statefulness over the epoch.
                [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                    [model.eval_initial_state, model.fake_gen_initial_state])

                for n in xrange(FLAGS.number_epochs):
                    print('Epoch number: %d' % n)
                    iterator = get_iterator(data)
                    for x, y, _ in iterator:
                        # In language-model evaluation no tokens are revealed.
                        if FLAGS.eval_language_model:
                            is_present_rate = 0.
                        else:
                            is_present_rate = FLAGS.is_present_rate
                        tf.logging.info('Evaluating on is_present_rate=%.3f.' %
                                        is_present_rate)

                        model_utils.assign_percent_real(
                            sess, model.percent_real_update, model.new_rate,
                            is_present_rate)

                        # Randomly mask out tokens.
                        p = model_utils.generate_mask()

                        eval_feed = {
                            model.inputs: x,
                            model.targets: y,
                            model.present: p
                        }

                        # Only for PTB are the recurrent states carried from
                        # one batch to the next (presumably because PTB
                        # batches are contiguous -- confirm against loader).
                        if FLAGS.data_set == 'ptb':
                            # Statefulness for *evaluation* Generator.
                            for i, (c, h) in enumerate(
                                    model.eval_initial_state):
                                eval_feed[c] = gen_initial_state_eval[i].c
                                eval_feed[h] = gen_initial_state_eval[i].h

                            # Statefulness for the Generator.
                            for i, (c, h) in enumerate(
                                    model.fake_gen_initial_state):
                                eval_feed[c] = fake_gen_initial_state_eval[i].c
                                eval_feed[h] = fake_gen_initial_state_eval[i].h

                        # Advance the carried states; the fetched global step
                        # is discarded.
                        [gen_initial_state_eval, fake_gen_initial_state_eval,
                         _] = sess.run([
                             model.eval_final_state, model.fake_gen_final_state,
                             model.global_step
                         ], feed_dict=eval_feed)

                        generate_logs(sess, model, output_file, id_to_word,
                                      eval_feed)
            finally:
                # Close even on failure so samples written so far are flushed.
                output_file.close()
            print('Closing output_file.')
            return
Example #2
0
def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """Generate samples.

  Builds the MaskGAN graph in evaluation mode, restores it through a
  `tf.train.Supervisor`, and for `FLAGS.number_epochs` passes over `data`
  masks each batch and writes the model's samples to `output_file` with
  `generate_logs`.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file:  Output file for the samples.  This function closes it,
      including when sample generation raises.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor.  It will take care of initialization and
    # recovery.  Note: the class lives in `tf.train`; there is no top-level
    # `tf.Supervisor`.
    sv = tf.train.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)

    # Get an initialized, and possibly recovered session.  The standard
    # services (checkpointing, summaries, step counting) are deliberately not
    # started: this function only generates samples.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:
      try:
        # Generator statefulness over the epoch.
        [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
            [model.eval_initial_state, model.fake_gen_initial_state])

        for n in xrange(FLAGS.number_epochs):
          print('Epoch number: %d' % n)
          iterator = get_iterator(data)
          for x, y, _ in iterator:
            # In language-model evaluation no tokens are revealed.
            if FLAGS.eval_language_model:
              is_present_rate = 0.
            else:
              is_present_rate = FLAGS.is_present_rate
            tf.logging.info(
                'Evaluating on is_present_rate=%.3f.' % is_present_rate)

            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # Randomly mask out tokens.
            p = model_utils.generate_mask()

            eval_feed = {model.inputs: x, model.targets: y, model.present: p}

            # Only for PTB are the recurrent states carried from one batch to
            # the next (presumably because PTB batches are contiguous --
            # confirm against the loader).
            if FLAGS.data_set == 'ptb':
              # Statefulness for *evaluation* Generator.
              for i, (c, h) in enumerate(model.eval_initial_state):
                eval_feed[c] = gen_initial_state_eval[i].c
                eval_feed[h] = gen_initial_state_eval[i].h

              # Statefulness for the Generator.
              for i, (c, h) in enumerate(model.fake_gen_initial_state):
                eval_feed[c] = fake_gen_initial_state_eval[i].c
                eval_feed[h] = fake_gen_initial_state_eval[i].h

            # Advance the carried states; the fetched global step is discarded.
            [gen_initial_state_eval, fake_gen_initial_state_eval,
             _] = sess.run(
                 [
                     model.eval_final_state, model.fake_gen_final_state,
                     model.global_step
                 ],
                 feed_dict=eval_feed)

            generate_logs(sess, model, output_file, id_to_word, eval_feed)
      finally:
        # Close even on failure so samples written so far are flushed.
        output_file.close()
      print('Closing output_file.')
      return
Example #3
0
def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
                           data_ngram_counts, is_chief):
    """Pretrain the discriminator until the global step reaches its budget.

    Repeatedly iterates over `data`, running `model.dis_pretrain_op` on
    randomly masked batches.  It stops once the global step reaches
    `FLAGS.dis_pretrain_steps` plus `FLAGS.gen_pretrain_steps` (the global
    step also counts any generator pretraining steps already taken).

    Args:
      sv: Supervisor used to record summaries.
      sess: Session in which to run the model.
      model: The MaskGAN model.
      data: Dataset consumed by the PTB or IMDB iterator.
      log: Readable log for the experiment.
      id_to_word: Dictionary of indices to words.
      data_ngram_counts: Dictionary mapping n (as a string) to the hashed
        n-gram counts of the data set.
      is_chief: Whether this replica writes summaries and logs.

    Raises:
      ValueError: If `FLAGS.data_set` is neither 'ptb' nor 'imdb', or if a
        full pass over the data yields no batches (which would otherwise loop
        forever).
    """
    print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
    log.write(
        '\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)

    # Global step at which pretraining ends.
    final_step = FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0)

    is_pretraining = True

    while is_pretraining:

        cumulative_costs = 0.
        iters = 0
        if FLAGS.data_set == 'ptb':
            iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                               FLAGS.sequence_length,
                                               FLAGS.epoch_size_override)
        elif FLAGS.data_set == 'imdb':
            iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                                 FLAGS.sequence_length)
        else:
            raise ValueError('Unsupported data_set: %s' % FLAGS.data_set)

        for x, y, _ in iterator:
            is_present_rate = FLAGS.is_present_rate
            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)
            # Randomly mask out tokens.
            p = model_utils.generate_mask()

            pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

            [_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                [
                    model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
                    model.global_step
                ],
                feed_dict=pretrain_feed)

            cumulative_costs += gen_log_perplexity_eval
            iters += 1

            # Calculate rolling perplexity over the current pass.
            perplexity = np.exp(cumulative_costs / iters)

            # Summaries.
            if is_chief and step % FLAGS.summaries_every == 0:
                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=pretrain_feed)
                sv.SummaryComputed(sess, summary_str)

                # Additional summary.
                for n, data_ngram_count in data_ngram_counts.items():
                    avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                        sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
                        int(n))
                    summary_percent_str = tf.Summary(value=[
                        tf.Summary.Value(
                            tag='general/%s-grams_percent_correct' % n,
                            simple_value=avg_percent_captured)
                    ])
                    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

            # Printing and logging
            if is_chief and step % FLAGS.print_every == 0:
                print('global_step: %d' % step)
                print(' discriminator loss: %.3f' % dis_loss_eval)
                print(' perplexity: %.3f' % perplexity)
                log.write('global_step: %d\n' % step)
                log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
                log.write(' perplexity: %.3f\n' % perplexity)

                for n, data_ngram_count in data_ngram_counts.items():
                    avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                        sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
                        int(n))
                    print(' percent of %s-grams captured: %.3f.\n' %
                          (n, avg_percent_captured))
                    log.write(' percent of %s-grams captured: %.3f.\n\n' %
                              (n, avg_percent_captured))

                evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                               pretrain_feed)

            if step >= final_step:
                is_pretraining = False
                break

        # A pass that yields no batches can never advance the global step, so
        # the outer loop would spin forever.
        if iters == 0:
            raise ValueError(
                'Data iterator yielded no batches; cannot pretrain the '
                'discriminator.')
    return
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate model for a number of steps.

  Restores the latest checkpoint in `train_dir` and runs one pass over `data`.
  It then prints, logs and writes summaries for the losses, the perplexity
  and the averaged n-gram capture rates.  Returns early, with a warning, when
  there is no checkpoint yet or when `data` yields no batches.

  Args:
    data:  Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    eval_saver:  Evaluation saver.

  Raises:
    FloatingPointError: If the perplexity is not finite or is at least
      `FLAGS.perplexity_threshold`.
  """
  tf.logging.info('Evaluate Once.')
  # Load the last model checkpoint, or initialize the graph.
  model_save_path = tf.train.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s' % model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Run the requested number of evaluation steps
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Average percent captured for each of the n-grams.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus):  Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')

  iterator = get_iterator(data)

  for x, y, _ in iterator:
    if FLAGS.eval_language_model:
      is_present_rate = 0.
    else:
      is_present_rate = FLAGS.is_present_rate
      tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()

    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    for n, data_ngram_count in data_ngram_counts.items():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval

    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)

    gen_iters += 1

  # With no batches there is nothing to average.  `step` and `eval_feed` would
  # also be undefined below, so bail out like the no-checkpoint case.
  if not gen_iters:
    tf.logging.warning('No evaluation batches produced; skipping evaluation.')
    return

  # Calculate rolling metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  for n in avg_percent_captured:
    avg_percent_captured[n] /= gen_iters

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss, avg_epoch_gen_loss)
  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent))

  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(model.text_summary_op,
                         {model.text_summary_placeholder: '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary:  n-gram
  for n, n_gram_percent in avg_percent_captured.items():
    n = int(n)
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % n,
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary:  geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary:  arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary:  perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Builds the MaskGAN graph, using the rollout variant when
  `FLAGS.num_rollouts > 1`, and optionally pretrains the generator and the
  discriminator.  It then alternates `hparams.dis_train_iterations`
  discriminator updates with one generator update per batch until the
  Supervisor requests a stop.  `log` is closed when training finishes
  normally.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Raises:
    ValueError: If `FLAGS.num_rollouts` is smaller than 1.
    FloatingPointError: If the training perplexity is not finite or reaches
      `FLAGS.perplexity_threshold`.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError(
              'num_rollouts must be >= 1, got %r' % FLAGS.num_rollouts)

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:

          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word, data_ngram_counts,
                                                 is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Initial indicators for printing and summarizing.  They hold the
          # last `step // every` bucket in which we printed/summarized.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate

            # NOTE(review): the decay is applied once to the flag value on
            # every outer iteration; it does not compound over time.
            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)

            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            # Each replica starts the Discriminator at a different position in
            # the data.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)

            for _ in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

                # Determine the state had the Generator run over real data.  We
                # use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            # Offset the Generator's own iterator for this replica.  On
            # exhaustion, restart *this* iterator; resetting the Discriminator
            # iterator here would leave `iterator` empty and skip generator
            # training for the whole epoch.
            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for _ in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  iterator = get_iterator(data)
                  next(iterator)

            for x, y, _ in iterator:
              for _ in xrange(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y, model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Exponentially decay the learning rate once past
              # `gen_full_learning_rate_steps`.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess, model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state,
                   model.fake_gen_final_state], train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              # Floor division so summaries fire once per `summaries_every`
              # steps under both Python 2 and Python 3.
              if is_chief and (step // FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step // FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary:  n-gram.  Record each batch value so the averages
                # below are computed from real data rather than zeros.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  avg_percent_captured[n] = batch_percent_captured
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary:  geometric_avg
                geometric_avg = compute_geometric_average(avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg', simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary:  arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary:  perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step // FLAGS.print_every > print_step_division):
                print_step_division = step // FLAGS.print_every
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f\n' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.\n' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.\n' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()
Example #6
0
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate the latest checkpoint in `train_dir` over one pass of `data`.

  Restores the newest checkpoint with `eval_saver` and runs the model on every
  batch produced by `get_iterator(data)`, using a fixed mask seed. It then
  prints and logs the average losses, perplexity and n-gram capture rates, and
  records them as supervisor summaries. Returns early without doing anything
  when no checkpoint exists yet, and skips reporting when the data yields no
  batches.

  Args:
    data:  Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    eval_saver:  Evaluation saver.

  Raises:
    FloatingPointError: If the evaluation perplexity is not finite or is at
      least `FLAGS.perplexity_threshold`.
  """
  tf.logging.info('Evaluate Once.')
  # Load the last model checkpoint; bail out if none has been written yet.
  model_save_path = tf.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s', model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Per-batch losses and the running sum of per-batch log-perplexities.
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Running sum (averaged after the loop) of percent captured per n-gram.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus):  Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  # The masking rate does not change between batches, so resolve it once.
  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')
  else:
    is_present_rate = FLAGS.is_present_rate
    tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

  for x, y, _ in get_iterator(data):
    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()

    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    # `.items()` works on both Python 2 and 3 (`iteritems` is Python 2 only).
    for n, data_ngram_count in data_ngram_counts.items():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval

    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)

    gen_iters += 1

  # With no batches there is nothing to average, and `step` / `eval_feed`
  # would be undefined below; previously this crashed with ZeroDivisionError.
  if gen_iters == 0:
    tf.logging.warning('No evaluation batches produced; skipping summaries.')
    return

  # Calculate epoch-level metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  avg_percent_captured = {
      n: total / gen_iters for n, total in avg_percent_captured.items()
  }

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss, avg_epoch_gen_loss)
  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.items():
    n_int = int(n)
    print(' percent of %d-grams captured: %.3f.' % (n_int, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' %
              (n_int, n_gram_percent))

  # Sample text is generated from the last evaluated batch's feed.
  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(model.text_summary_op,
                         {model.text_summary_placeholder: '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary:  n-gram
  for n, n_gram_percent in avg_percent_captured.items():
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % int(n),
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary:  geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary:  arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary:  perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
Example #7
0
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train the MaskGAN until the supervisor requests a stop.

  Optionally pretrains the generator and discriminator. It then repeatedly
  makes passes over `data`. For each generator batch it runs
  `hparams.dis_train_iterations` discriminator updates followed by one
  generator update. The chief replica writes summaries every
  `FLAGS.summaries_every` steps and readable progress to `log` every
  `FLAGS.print_every` steps. `log` is closed when training ends normally.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Raises:
    ValueError: If `FLAGS.num_rollouts` is less than 1.
    FloatingPointError: If the rolling training perplexity is not finite or
      reaches `FLAGS.perplexity_threshold`.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError(
              'FLAGS.num_rollouts must be >= 1, got %r' % FLAGS.num_rollouts)

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.train.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:

          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word, data_ngram_counts,
                                                 is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(
                sv, sess, model, data, log, id_to_word, data_ngram_counts,
                is_chief)

          # Initial indicators for printing and summarizing.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate

            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)

            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)

            for _ in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

                # Determine the state had the Generator run over real data.  We
                # use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              # Each replica starts the generator at a different data offset.
              gen_offset = FLAGS.task * 1000 + 1
              for _ in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  # Wrap around the *generator's* iterator. Otherwise it stays
                  # exhausted and the epoch below trains on no batches at all.
                  iterator = get_iterator(data)
                  next(iterator)

            for x, y, _ in iterator:
              for _ in range(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {model.inputs: x, model.targets: y, model.present: p}

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Determine whether to decay learning rate.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess, model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state,
                   model.fake_gen_final_state], train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              # Floor division gives "once every N steps". True division
              # (Python 3 / `from __future__ import division`) would fire
              # on every step.
              if is_chief and (step // FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step // FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary:  n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  # Record the rate so the geometric/arithmetic averages
                  # below use real values instead of the zero placeholders.
                  avg_percent_captured[n] = batch_percent_captured
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary:  geometric_avg
                geometric_avg = compute_geometric_average(avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg', simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary:  arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary:  perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step // FLAGS.print_every > print_step_division):
                print_step_division = step // FLAGS.print_every
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f\n' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.items():
                  batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
                      sess, model.fake_sequence, log, train_feed,
                      data_ngram_count, int(n))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.\n' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.\n' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(
                    log, step, is_present_rate, avg_epoch_dis_loss,
                    avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                                 train_feed)
                log.flush()

  log.close()