Example #1
def estimate_popstats(unused_sv, sess, m, dataset, unused_hparams):
  """Averages over mini batches for population statistics for batch norm."""
  print('Estimating population statistics...')
  tfbatchstats, tfpopstats = zip(*m.popstats_by_batchstat.items())

  nepochs = 3
  # One running mean per population statistic, averaged over nepochs passes.
  nppopstats = [lib_util.AggregateMean('') for _ in tfpopstats]
  for _ in range(nepochs):
    batches = (
        dataset.get_featuremaps().batches(size=m.batch_size, shuffle=True))
    for batch in batches:
      feed_dict = batch.get_feed_dict(m.placeholders)
      npbatchstats = sess.run(tfbatchstats, feed_dict=feed_dict)
      for nppopstat, npbatchstat in zip(nppopstats, npbatchstats):
        nppopstat.add(npbatchstat)
  nppopstats = [nppopstat.mean for nppopstat in nppopstats]

  _print_popstat_info(tfpopstats, nppopstats)

  # Update tfpopstat variables.
  for tfpopstat, nppopstat in zip(tfpopstats, nppopstats):
    tfpopstat.load(nppopstat)
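
Both examples rely on lib_util.AggregateMean to accumulate running means. A minimal sketch of the behavior these snippets assume (the real class lives in the project-local lib_util module; the attribute names here are illustrative, not the actual implementation):

class AggregateMean(object):
  """Running weighted mean: add() accumulates, .mean reads the average."""

  def __init__(self, name):
    self.name = name
    self.total = 0.0
    self.counts = 0

  def add(self, value, counts=1):
    # value is expected to be pre-multiplied by counts when counts > 1.
    self.total += value
    self.counts += counts

  @property
  def mean(self):
    # Average of everything added so far; elementwise for numpy arrays.
    return self.total / self.counts

With the default counts=1, .mean is a plain average over calls, which is how estimate_popstats averages batch statistics across mini-batches; Example #2 below also passes explicit counts to form weighted means.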
Example #2
# Module-level imports the snippet below relies on; in the original file,
# lib_util is a project-local module and FLAGS holds the parsed command-line
# flags (assumed here, since the snippet is extracted from a larger file).
import time

import numpy as np
import six
import tensorflow as tf


def run_epoch(supervisor, sess, m, dataset, hparams, eval_op, experiment_type,
              epoch_count):
  """Runs an epoch of training or evaluates the model on the given data."""
  # Reduce variance in validation loss by fixing the seed.
  data_seed = 123 if experiment_type == 'valid' else None
  with lib_util.numpy_seed(data_seed):
    batches = (
        dataset.get_featuremaps().batches(
            size=m.batch_size, shuffle=True, shuffle_rng=data_seed))

  losses = lib_util.AggregateMean('losses')
  losses_total = lib_util.AggregateMean('losses_total')
  losses_mask = lib_util.AggregateMean('losses_mask')
  losses_unmask = lib_util.AggregateMean('losses_unmask')

  start_time = time.time()
  for batch in batches:
    # Evaluate the graph and run back propagation.
    fetches = [
        m.loss, m.loss_total, m.loss_mask, m.loss_unmask, m.reduced_mask_size,
        m.reduced_unmask_size, m.learning_rate, eval_op
    ]
    feed_dict = batch.get_feed_dict(m.placeholders)
    (loss, loss_total, loss_mask, loss_unmask, reduced_mask_size,
     reduced_unmask_size, learning_rate, _) = sess.run(
         fetches, feed_dict=feed_dict)

    # Aggregate performances.
    losses_total.add(loss_total, 1)
    # Weight the mean loss_mask by reduced_mask_size when aggregating, since
    # the mask size can differ between batches (see the worked example below).
    losses_mask.add(loss_mask * reduced_mask_size, reduced_mask_size)
    losses_unmask.add(loss_unmask * reduced_unmask_size, reduced_unmask_size)

    if hparams.optimize_mask_only:
      losses.add(loss * reduced_mask_size, reduced_mask_size)
    else:
      losses.add(loss, 1)

  # Collect run statistics.
  run_stats = dict()
  run_stats['loss_mask'] = losses_mask.mean
  run_stats['loss_unmask'] = losses_unmask.mean
  run_stats['loss_total'] = losses_total.mean
  run_stats['loss'] = losses.mean
  if experiment_type == 'train':
    run_stats['learning_rate'] = float(learning_rate)

  # Make summaries.
  if FLAGS.log_progress:
    summaries = tf.Summary()
    for stat_name, stat in six.iteritems(run_stats):
      value = summaries.value.add()
      value.tag = '%s_%s' % (stat_name, experiment_type)
      value.simple_value = stat
    supervisor.summary_computed(sess, summaries, epoch_count)

  tf.logging.info(
      '%s, epoch %d: loss (mask): %.4f, loss (unmask): %.4f, '
      'loss (total): %.4f, log lr: %.4f, time taken: %.4f',
      experiment_type, epoch_count, run_stats['loss_mask'],
      run_stats['loss_unmask'], run_stats['loss_total'],
      np.log(run_stats['learning_rate']) if 'learning_rate' in run_stats else 0,
      time.time() - start_time)

  return run_stats['loss']
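
The mask-loss aggregation above computes a size-weighted mean: each batch contributes its mean loss multiplied by the number of (un)masked positions, so batches with larger masks weigh more. A tiny worked example with made-up numbers, using the AggregateMean sketch from Example #1:

losses_mask = AggregateMean('losses_mask')
# Batch 1: mean masked loss 2.0 over 10 masked positions.
losses_mask.add(2.0 * 10, 10)
# Batch 2: mean masked loss 4.0 over 30 masked positions.
losses_mask.add(4.0 * 30, 30)
# Size-weighted mean: (20 + 120) / (10 + 30) = 3.5, whereas the naive mean of
# per-batch means would be (2.0 + 4.0) / 2 = 3.0.
print(losses_mask.mean)  # 3.5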