Example #1
def _write_checkpoint_metrics(checkpoint_path, metrics_and_values, eval_dir):
  """Writes a JSON of metrics for checkpoint_path in eval_dir.

  This function writes out metrics to a JSON for a checkpoint into eval_dir. The
  exact path of this file will be computed with:

    `checkpoint_metrics_path(checkpoint_path, eval_dir)`

  and the values for metrics_and_values (a dict of strings => objects) written
  out as key: str(object) into a JSON file.

  Args:
    checkpoint_path: str; a path to the checkpoint we computed metrics on.
    metrics_and_values: dict[string,object]; a dictionary of key/value pairs
      containing our metrics. These will be converted to a JSON of key/string
      pairs and written out to disk.
    eval_dir: str; a path to a directory where we will write out our checkpoint
      metrics.
  """
  path = checkpoint_metrics_path(checkpoint_path, eval_dir)
  serializable = {k: str(v) for k, v in metrics_and_values.iteritems()}
  logging.info('Writing checkpoint metrics %s', path)
  try:
    with tf.gfile.GFile(path, 'w') as fout:
      json.dump(serializable, fout, sort_keys=True, indent=4)
  except:  # pylint: disable=bare-except
    # Note we have a bare exception here as there's no clear TF base
    # exception to catch that will cover all of the potential issues that
    # might arise trying to write our metrics to our metrics file.
    logging.warning('Failed to write checkpoint metrics to path %s', path)
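A minimal usage sketch (hypothetical paths and metric names; `checkpoint_metrics_path` is assumed to derive the output filename from its two arguments). Because every value is passed through `str`, the file maps metric names to strings:

metrics = {'accuracy': 0.981, 'global_step': 5000}
_write_checkpoint_metrics('/tmp/model.ckpt-5000', metrics, eval_dir='/tmp/eval')
# The JSON written to disk would look like:
# {
#     "accuracy": "0.981",
#     "global_step": "5000"
# }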
Example #2
 def build_inference_for_training(self):
   """Invokes depth and ego-motion networks and computes clouds if needed."""
   (self.image_stack, self.intrinsic_mat, self.intrinsic_mat_inv) = (
       self.reader.read_data())
   with tf.name_scope('egomotion_prediction'):
     self.egomotion, _ = nets.egomotion_net(self.image_stack, is_training=True,
                                            legacy_mode=self.legacy_mode)
   with tf.variable_scope('depth_prediction'):
     # Organized by ...[i][scale].  Note that the order is flipped in
     # variables in build_loss() below.
     self.disp = {}
     self.depth = {}
     if self.icp_weight > 0:
       self.cloud = {}
     for i in range(self.seq_length):
       image = self.image_stack[:, :, :, 3 * i:3 * (i + 1)]
       multiscale_disps_i, _ = nets.disp_net(image, is_training=True)
       multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
       self.disp[i] = multiscale_disps_i
       self.depth[i] = multiscale_depths_i
       if self.icp_weight > 0:
         multiscale_clouds_i = [
             project.get_cloud(d,
                               self.intrinsic_mat_inv[:, s, :, :],
                               name='cloud%d_%d' % (s, i))
             for (s, d) in enumerate(multiscale_depths_i)
         ]
         self.cloud[i] = multiscale_clouds_i
       # Reuse the same depth graph for all images.
       tf.get_variable_scope().reuse_variables()
   logging.info('disp: %s', util.info(self.disp))
Example #3
  def incremental_save(self, log_info=False):
    """Write new entries to disk.

    This performs an append operation on the `save_file` given in the
    constructor. Any entries added since the last call to `incremental_save`
    will be appended to the file.

    If a new RouletteWheel is constructed with the same `save_file`, all the
    entries written there will be automatically loaded into the instance.
    This is useful when a job resumes after preemption.

    Args:
      log_info: If True, info about this operation will be logged.

    Raises:
      RuntimeError: If `save_file` given in the constructor is None.
    """
    if self.save_file is None:
      raise RuntimeError('Cannot call incremental_save. `save_file` is None.')
    if log_info:
      logging.info('Saving %d new samples to disk.',
                   len(self.save_to_disk_buffer))
    with tf.gfile.FastGFile(self.save_file, 'a') as f:
      for entry in self.save_to_disk_buffer:
        cPickle.dump(entry, f)
    # Clear the buffer.
    self.save_to_disk_buffer = []
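Since `incremental_save` appends one pickled object per buffered entry, the file can be read back by calling `pickle.load` in a loop until EOF. A minimal sketch (hypothetical path; the original class instead reloads the file in its constructor):

import pickle  # the code above uses cPickle under Python 2

entries = []
with open('/tmp/roulette_save.pickle', 'rb') as f:
  while True:
    try:
      entries.append(pickle.load(f))  # one object per saved entry
    except EOFError:
      break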
Example #4
  def load(self):
    """Loads GA state from disk.

    Loads whatever is on disk, which will be whatever the most recent call
    to `write` wrote.

    Returns:
      gen: Generation number.
      population: List of Individual objects.
      halloffame: Hall-of-fame buffer. Typically a priority queue.
    """
    with tf.gfile.FastGFile(self.checkpoint_file, 'r') as f:
      raw = f.read()
    objs = cPickle.loads(raw)
    # Validate data.
    assert isinstance(objs, tuple) and len(objs) == 3, (
        'Expecting a 3-tuple, but got %s instead.' % (objs,))
    gen, population, halloffame = objs
    assert isinstance(gen, int), (
        'Expecting `gen` to be an integer, got %s' % (gen,))
    assert (
        isinstance(population, list)
        and len(population) == self.population_size
    ), (
        'Expecting `population` to be a list with size %d, got %s'
        % (self.population_size, population))
    assert halloffame is None or len(halloffame) == 2, (
        'Expecting hall-of-fame object to have length two, got length %d'
        % len(halloffame))
    logging.info('Loaded pop from checkpoint file: "%s".',
                 self.checkpoint_file)
    return gen, population, halloffame
Example #5
  def construct_lookup_variables(self):
    # Materialize negatives for fast lookup sampling.
    start_time = timeit.default_timer()
    inner_bounds = np.argwhere(self._train_pos_users[1:] -
                               self._train_pos_users[:-1])[:, 0] + 1
    (upper_bound,) = self._train_pos_users.shape
    index_bounds = [0] + inner_bounds.tolist() + [upper_bound]
    self._negative_table = np.zeros(shape=(self._num_users, self._num_items),
                                    dtype=rconst.ITEM_DTYPE)

    # Set the table to the max value to make sure the embedding lookup will fail
    # if we go out of bounds, rather than just overloading item zero.
    self._negative_table += np.iinfo(rconst.ITEM_DTYPE).max
    assert self._num_items < np.iinfo(rconst.ITEM_DTYPE).max

    # Reuse arange during generation. np.delete will make a copy.
    full_set = np.arange(self._num_items, dtype=rconst.ITEM_DTYPE)

    self._per_user_neg_count = np.zeros(
        shape=(self._num_users,), dtype=np.int32)

    # Threading does not improve this loop. For some reason, the np.delete
    # call does not parallelize well. Multiprocessing incurs too much
    # serialization overhead to be worthwhile.
    for i in range(self._num_users):
      positives = self._train_pos_items[index_bounds[i]:index_bounds[i+1]]
      negatives = np.delete(full_set, positives)
      self._per_user_neg_count[i] = self._num_items - positives.shape[0]
      self._negative_table[i, :self._per_user_neg_count[i]] = negatives

    logging.info("Negative sample table built. Time: {:.1f} seconds".format(
        timeit.default_timer() - start_time))
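The `inner_bounds`/`index_bounds` computation simply records where the sorted user column changes value, so `index_bounds[i]:index_bounds[i+1]` slices out user i's positive items. A small self-contained illustration:

import numpy as np

train_pos_users = np.array([0, 0, 0, 1, 1, 2])  # sorted by user id
inner_bounds = np.argwhere(train_pos_users[1:] -
                           train_pos_users[:-1])[:, 0] + 1
index_bounds = [0] + inner_bounds.tolist() + [train_pos_users.shape[0]]
# inner_bounds == [3, 5] and index_bounds == [0, 3, 5, 6]: user 0 owns rows
# 0:3, user 1 owns rows 3:5, and user 2 owns row 5:6.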
Example #6
def make_examples_runner(options):
  """Runs examples creation stage of deepvariant."""
  logging.info('Preparing inputs')
  regions = processing_regions_from_options(options)

  # Create a processor to create candidates and examples for each region.
  region_processor = RegionProcessor(options)

  logging.info('Writing examples to %s', options.examples_filename)
  if options.candidates_filename:
    logging.info('Writing candidates to %s', options.candidates_filename)
  if options.gvcf_filename:
    logging.info('Writing gvcf records to %s', options.gvcf_filename)

  n_regions, n_candidates = 0, 0
  with OutputsWriter(options) as writer:
    for region in regions:
      candidates, examples, gvcfs = region_processor.process(region)
      n_candidates += len(candidates)
      n_regions += 1

      writer.write_candidates(*candidates)

      # If we have any gvcf records, write them out. This if also serves to
      # protect us from trying to write to the gvcfs output of writer when gvcf
      # generation is turned off. In that case, gvcfs will always be empty and
      # we'll never execute the write.
      if gvcfs:
        writer.write_gvcfs(*gvcfs)
      writer.write_examples(*examples)

  logging.info('Found %s candidate variants', n_candidates)
Example #7
def model_init_function(model, num_classes, checkpoint_path):
  """Creates an init_fn for slim.learning.train.

  Args:
    model: DeepVariantModel. The model we want an init_fn for.
    num_classes: int. The number of class labels we want to predict with this
      model.
    checkpoint_path: str or ''/None. A path to a model checkpoint file that we
      will load our model parameters from. If bool(checkpoint_path) == False, we
      not load a checkpoint but rather return None, indicating no initialization
      is needed.

  Returns:
    An init_fn suitable for use with slim.learning.train, or None if
    bool(checkpoint_path) == False.
  """
  # If the special value "model_default" was passed, ask the model for
  # its default.
  if checkpoint_path == 'model_default':
    checkpoint_path = model.pretrained_model_path

  # If the path is non-False, use it.
  if checkpoint_path:
    logging.info('Initializing model from checkpoint at %s', checkpoint_path)
    return model.initialize_from_checkpoint(
        checkpoint_path, num_classes, is_training=True)
  else:
    logging.info('Initializing model with random parameters')
    return None
Example #8
 def _init_fn(unused_scaffold, sess):
   # First initialize all variables.
   sess.run(init_op)
   logging.info('\n'.join([var.name for var in restore_var_list]))
   # Then overwrite variables saved in previous stage.
   if prev_ckpt is not None:
     saver_for_restore.restore(sess, prev_ckpt)
Example #9
def main(unused_argv):
  logging.set_verbosity(FLAGS.log)
  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
  for input_file in sorted(os.listdir(FLAGS.input_dir)):
    if not input_file.endswith('.wav'):
      continue
    wav_filename = input_file
    midi_filename = input_file.replace('.wav', '.mid')
    logging.info('Aligning %s to %s', midi_filename, wav_filename)

    samples = audio_io.load_audio(
        os.path.join(FLAGS.input_dir, wav_filename), align_fine_lib.SAMPLE_RATE)
    ns = midi_io.midi_file_to_sequence_proto(
        os.path.join(FLAGS.input_dir, midi_filename))

    aligned_ns, unused_stats = align_fine_lib.align_cpp(
        samples,
        align_fine_lib.SAMPLE_RATE,
        ns,
        align_fine_lib.CQT_HOP_LENGTH_FINE,
        sf2_path=FLAGS.sf2_path,
        penalty_mul=FLAGS.penalty_mul)

    midi_io.sequence_proto_to_midi_file(
        aligned_ns, os.path.join(FLAGS.output_dir, midi_filename))

  logging.info('Done')
Example #10
 def compile_file_list(self, data_dir, split, load_pose=False):
   """Creates a list of input files."""
   logging.info('data_dir: %s', data_dir)
   with gfile.Open(os.path.join(data_dir, '%s.txt' % split), 'r') as f:
     frames = f.readlines()
     frames = [k.rstrip() for k in frames]
   subfolders = [x.split(' ')[0] for x in frames]
   frame_ids = [x.split(' ')[1] for x in frames]
   image_file_list = [
       os.path.join(data_dir, subfolders[i], frame_ids[i] + '.' +
                    self.file_extension)
       for i in range(len(frames))
   ]
   segment_file_list = [
       os.path.join(data_dir, subfolders[i], frame_ids[i] + '-fseg.' +
                    self.file_extension)
       for i in range(len(frames))
   ]
   cam_file_list = [
       os.path.join(data_dir, subfolders[i], frame_ids[i] + '_cam.txt')
       for i in range(len(frames))
   ]
   file_lists = {}
   file_lists['image_file_list'] = image_file_list
   file_lists['segment_file_list'] = segment_file_list
   file_lists['cam_file_list'] = cam_file_list
   if load_pose:
     pose_file_list = [
         os.path.join(data_dir, subfolders[i], frame_ids[i] + '_pose.txt')
         for i in range(len(frames))
     ]
     file_lists['pose_file_list'] = pose_file_list
   self.steps_per_epoch = len(image_file_list) // self.batch_size
   return file_lists
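For reference, each line of the '<split>.txt' file is expected to hold a subfolder name and a frame id separated by a space; the concrete values below are hypothetical:

# A line such as
#   seq_01 000042
# with file_extension = 'png' expands to:
#   <data_dir>/seq_01/000042.png        (image_file_list)
#   <data_dir>/seq_01/000042-fseg.png   (segment_file_list)
#   <data_dir>/seq_01/000042_cam.txt    (cam_file_list)
#   <data_dir>/seq_01/000042_pose.txt   (pose_file_list, only if load_pose)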
Example #11
 def _load_lidar_cloud(self):
   lidar_cloud_path = os.path.join(FLAGS.test_srcdir,
                                   icp_util.LIDAR_CLOUD_PATH)
   lidar_cloud = np.load(lidar_cloud_path)
   lidar_cloud = tf.expand_dims(lidar_cloud, axis=0)  # Add batch.
   logging.info('lidar_cloud.shape: %s', lidar_cloud.shape)
   return lidar_cloud
Example #12
def maybe_resolve_conflicting_variants(sorted_variants):
  """Yields Variant protos in sorted order after fixing conflicting haplotypes.

  The input is an iterable of Variants in chromosome and position sorted order,
  with potential incompatibilities as described in this module's docstring. This
  function tries to resolve variants into valid haplotypes, though is not
  guaranteed to do so if the variant composition is not amenable to this or it
  would be computationally intractable.

  Args:
    sorted_variants: Iterable of Variant protos. Sorted in coordinate order, but
      with potentially incompatible haplotypes.

  Yields:
    Variant protos in coordinate-sorted order with no incompatible haplotypes.
  """
  if FLAGS.disable_haplotype_resolution:
    logging.info('disable_haplotype_resolution is True. '
                 '`maybe_resolve_conflicting_variants` has no effect.')
    for v in sorted_variants:
      yield v
  else:
    for overlapping_candidates in _group_overlapping_variants(sorted_variants):
      for resolved_candidate in _maybe_resolve_mixed_calls(
          overlapping_candidates):
        yield resolved_candidate
Example #13
def _WriteFile(output_file, file_string):
  try:
    with open(output_file, 'w') as output:
      logging.info('writing file: %s', output_file)
      output.write(file_string)
  except IOError:
    logging.warn('error while writing file: %s', output_file)
    raise
Example #14
def default_config_with_updates(config_string, do_logging=True):
  if do_logging:
    logging.info('Config string: "%s"', config_string)
  config = default_config()
  config.strict_update(config_lib.Config.parse(config_string))
  if do_logging:
    logging.info('Config:\n%s', config.pretty_str())
  return config
Example #15
 def _wait_to_construct_train_epoch(self):
   count = 0
   while self._train_dataset.buffer_reached() and not self._stop_loop:
     time.sleep(0.01)
     count += 1
     if count >= 100 and np.log10(count) == np.round(np.log10(count)):
       logging.info(
           "Waited {} times for training data to be consumed".format(count))
Example #16
 def setUpClass(cls):
   random_seed = int(time.time())
   value = os.environ.get('TEST_RANDOM_SEED', '')
   try:
     random_seed = int(value)
   except ValueError:
     pass
   logging.info('Seeding random generator with seed %d', random_seed)
   random.seed(random_seed)
Example #17
  def save_replay_buffer(self):
    """Save replay buffer to disk.

    Call this periodically so that training can recover if jobs go down.
    """
    if self.model.experience_replay is not None:
      logging.info('Saving experience replay buffer to "%s".',
                   self.model.experience_replay.save_file)
      self.model.experience_replay.incremental_save(True)
Example #18
def main(argv):
    del argv  # unused

    filter_class_ids = list(map(int, FLAGS.class_ids))
    logging.info('calculating scores for classes %s from %s'
                 % (', '.join(map(str, filter_class_ids)),
                    FLAGS.dataset))

    dataset_class = DATASET_MAP[FLAGS.dataset]
    data_dir = DATA_DIR_MAP[FLAGS.dataset]

    logging.info('loading data...')
    input_image_size = 224
    dataset = dataset_class.load_all_data(
        data_dir, input_image_size, filter_class_ids)

    logging.info('loaded %d images' % len(dataset))

    logging.info('loading model...')
    model = _make_cuda(Vgg16())

    context = Context(
        model=model,
        dataset=dataset,
        layer_idx=FLAGS.layer_idx)

    out_filename = '-'.join([
        f'vgg16_layer{FLAGS.layer_idx}',
        FLAGS.dataset,
        '_'.join(map(str, filter_class_ids)),
        'scores.npz'])
    out_dirpath = os.path.join(SCRATCH_DIR, 'scores')
    os.makedirs(out_dirpath, exist_ok=True)
    out_filepath = os.path.join(out_dirpath, out_filename)

    logging.info('saving output to %s' % out_filepath)

    all_scores_matrix = None
    all_column_ids = list()
    for image_idx in trange(len(context.dataset)):
        if all_scores_matrix is None:
            scores, cols = _get_score_matrix_for_image(
                image_idx, FLAGS.num_max_proposals, context)

            all_scores_matrix = scores
            all_column_ids += cols
        else:
            scores, cols = _get_score_matrix_for_image(
                image_idx, FLAGS.num_max_proposals, context)

            all_scores_matrix = np.concatenate(
                (all_scores_matrix, scores), axis=1)
            all_column_ids += cols

        np.savez(out_filepath, scores=all_scores_matrix, cols=all_column_ids)

    notify(f'Finished: {FLAGS.dataset} - {FLAGS.class_ids}', namespace='scores')
Example #19
def log_msg(msg):
  """Include timestamp info when logging messages to a file."""
  if flags.FLAGS.redirect_logs:
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    absl_logging.info("[{}] {}".format(timestamp, msg))
  else:
    absl_logging.info(msg)
  sys.stdout.flush()
  sys.stderr.flush()
Example #20
def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
                         deterministic=False, epoch_dir=None):
  # type: (str, str, dict, typing.Optional[str], bool, typing.Optional[str]) -> (int, int, data_pipeline.BaseDataConstructor)
  """Load and digest data CSV into a usable form.

  Args:
    dataset: The name of the dataset to be used.
    data_dir: The root directory of the dataset.
    params: dict of parameters for the run.
    constructor_type: The name of the constructor subclass that should be used
      for the input pipeline.
    deterministic: Tell the data constructor to produce deterministically.
    epoch_dir: Directory in which to store the training epochs.
  """
  logging.info("Beginning data preprocessing.")

  st = timeit.default_timer()
  raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
  cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)

  raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
  user_map, item_map = raw_data["user_map"], raw_data["item_map"]
  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]

  if num_users != len(user_map):
    raise ValueError("Expected to find {} users, but found {}".format(
        num_users, len(user_map)))
  if num_items != len(item_map):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

  producer = data_pipeline.get_constructor(constructor_type or "materialized")(
      maximum_number_epochs=params["train_epochs"],
      num_users=num_users,
      num_items=num_items,
      user_map=user_map,
      item_map=item_map,
      train_pos_users=raw_data[rconst.TRAIN_USER_KEY],
      train_pos_items=raw_data[rconst.TRAIN_ITEM_KEY],
      train_batch_size=params["batch_size"],
      batches_per_train_step=params["batches_per_step"],
      num_train_negatives=params["num_neg"],
      eval_pos_users=raw_data[rconst.EVAL_USER_KEY],
      eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
      eval_batch_size=params["eval_batch_size"],
      batches_per_eval_step=params["batches_per_step"],
      stream_files=params["use_tpu"],
      deterministic=deterministic,
      epoch_dir=epoch_dir
  )

  run_time = timeit.default_timer() - st
  logging.info("Data preprocessing complete. Time: {:.1f} sec."
               .format(run_time))

  print(producer)
  return num_users, num_items, producer
Example #21
def main(argv):
    del argv  # unused

    preds_out_dirpath = os.path.join(SCRATCH_DIR, 'preds_snorkel')
    os.makedirs(preds_out_dirpath, exist_ok=True)
    class_ids = list(map(int, FLAGS.class_ids))

    preds_out_filename = '-'.join([
        FLAGS.dataset,
        '_'.join(map(str, class_ids)),
        'run_%02d' % FLAGS.run,
        'preds_snorkel.npz'])

    preds_out_filepath = os.path.join(
        preds_out_dirpath, preds_out_filename)

    assert not os.path.exists(preds_out_filepath), \
        'Predictions for this run already exist at %s' % preds_out_filepath

    input_image_size = 224

    if FLAGS.dataset == 'cub':
        dataset = CUBDataset.load_all_data(
            CUB_DATA_DIR, input_image_size, class_ids)
    elif FLAGS.dataset == 'awa2':
        dataset = AwA2Dataset.load_all_data(
            AWA2_DATA_DIR, input_image_size, class_ids)

    y_true = [v[1] for v in dataset]

    seed = sum(v * (10 ** (3 * i))
               for i, v in enumerate(class_ids + [FLAGS.run]))
    random.seed(seed)
    np.random.seed(seed)

    scores, col_ids = load_scores(
        os.path.join(
            SCRATCH_DIR, 'scores',
            f'vgg16_layer30-{FLAGS.dataset}-%d_%d-scores.npz'
            % tuple(class_ids)))

    new_scores_np = get_labeling_matrix_for_GOOGGLES(scores)

    L_tr, L_te = new_scores_np, new_scores_np
    _, y_snorkel, _ = train_snorkel_gen_model(L_tr, L_te)

    np.savez(preds_out_filepath,
             y_true=y_true, y_snorkel=y_snorkel)

    logging.info(f'saved predictions at {preds_out_filepath}')

    snorkel_acc = best_acc(y_true, y_snorkel)

    notify(f'`{FLAGS.dataset}` - `%s` - `run {FLAGS.run}`: '
           f'{snorkel_acc}'
           % ', '.join(map(str, class_ids)),
           namespace='goggles-snorkel')
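The seed expression packs the class ids and the run number into three-digit decimal groups. A worked example with hypothetical values:

class_ids = [3, 7]
run = 2
seed = sum(v * (10 ** (3 * i)) for i, v in enumerate(class_ids + [run]))
# seed == 3 + 7000 + 2000000 == 2007003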
Example #22
def _run_inference():
  """Runs all images through depth model and saves depth maps."""
  ckpt_basename = os.path.basename(FLAGS.model_ckpt)
  ckpt_modelname = os.path.basename(os.path.dirname(FLAGS.model_ckpt))
  output_dir = os.path.join(FLAGS.output_dir,
                            FLAGS.kitti_video.replace('/', '_') + '_' +
                            ckpt_modelname + '_' + ckpt_basename)
  if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)
  inference_model = model.Model(is_training=False,
                                seq_length=FLAGS.seq_length,
                                batch_size=FLAGS.batch_size,
                                img_height=FLAGS.img_height,
                                img_width=FLAGS.img_width)
  vars_to_restore = util.get_vars_to_restore(FLAGS.model_ckpt)
  saver = tf.train.Saver(vars_to_restore)
  sv = tf.train.Supervisor(logdir='/tmp/', saver=None)
  with sv.managed_session() as sess:
    saver.restore(sess, FLAGS.model_ckpt)
    if FLAGS.kitti_video == 'test_files_eigen':
      im_files = util.read_text_lines(
          util.get_resource_path('dataset/kitti/test_files_eigen.txt'))
      im_files = [os.path.join(FLAGS.kitti_dir, f) for f in im_files]
    else:
      video_path = os.path.join(FLAGS.kitti_dir, FLAGS.kitti_video)
      im_files = gfile.Glob(os.path.join(video_path, 'image_02/data', '*.png'))
      im_files = [f for f in im_files if 'disp' not in f]
      im_files = sorted(im_files)
    for i in range(0, len(im_files), FLAGS.batch_size):
      if i % 100 == 0:
        logging.info('Generating from %s: %d/%d', ckpt_basename, i,
                     len(im_files))
      inputs = np.zeros(
          (FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, 3),
          dtype=np.uint8)
      for b in range(FLAGS.batch_size):
        idx = i + b
        if idx >= len(im_files):
          break
        im = scipy.misc.imread(im_files[idx])
        inputs[b] = scipy.misc.imresize(im, (FLAGS.img_height, FLAGS.img_width))
      results = inference_model.inference(inputs, sess, mode='depth')
      for b in range(FLAGS.batch_size):
        idx = i + b
        if idx >= len(im_files):
          break
        if FLAGS.kitti_video == 'test_files_eigen':
          depth_path = os.path.join(output_dir, '%03d.png' % idx)
        else:
          depth_path = os.path.join(output_dir, '%04d.png' % idx)
        depth_map = results['depth'][b]
        depth_map = np.squeeze(depth_map)
        colored_map = _normalize_depth_for_display(depth_map, cmap=CMAP)
        input_float = inputs[b].astype(np.float32) / 255.0
        vertical_stack = np.concatenate((input_float, colored_map), axis=0)
        scipy.misc.imsave(depth_path, vertical_stack)
Example #23
 def _compute_best_reward(self):
   io_seqs = self.task.make_io_set()
   reward = 0.0
   for _, output_seq in io_seqs:
     reward += self.reward_fn(output_seq, output_seq, self.task.base)
     reward += self.correct_bonus
     reward += self.code_length_bonus  # Bonus for shortest code.
   self.best_reward = reward
   self.good_reward = 0.75 * reward
   logging.info('Known best reward: %.4f', self.best_reward)
Example #24
 def assertRealisticImage(self, image_path):
   logging.info('Testing %s for realism.', image_path)
   # If the normalization is off or forgotten, then the generated image is
   # all one pixel value. This tests that different pixel values are achieved.
   input_np = np.asarray(PIL.Image.open(image_path))
   self.assertEqual(len(input_np.shape), 3)
   self.assertGreaterEqual(input_np.shape[0], 50)
   self.assertGreaterEqual(input_np.shape[1], 50)
   self.assertGreater(np.mean(input_np), 20)
   self.assertGreater(np.var(input_np), 100)
Example #25
  def delete_replay_buffer(self):
    """Delete replay buffer from disk.

    Call this at the end of training to clean up. Replay buffer can get very
    large.
    """
    if self.model.experience_replay is not None:
      logging.info('Deleting experience replay buffer at "%s".',
                   self.model.experience_replay.save_file)
      tf.gfile.Remove(self.model.experience_replay.save_file)
Example #26
  def save_topk_buffer(self):
    """Save top-k buffer to disk.

    Call this periodically so that training can recover if jobs go down.
    """
    if self.model.top_episodes is not None:
      logging.info('Saving top-k buffer to "%s".', self.topk_file)
      # Overwrite previous data each time.
      with tf.gfile.FastGFile(self.topk_file, 'w') as f:
        f.write(cPickle.dumps(self.model.top_episodes))
Example #27
File: util.py Project: pcm17/models
def get_vars_to_save_and_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
             'mu' in v.op.name or 'sigma' in v.op.name or
             'global_scale_var' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  mapping = {}
  if ckpt is not None:
    ckpt_var = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
    ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
    not_loaded = list(ckpt_var_names)
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        # For backward compatibility, try additional matching.
        v_additional_name = v.op.name.replace('egomotion_prediction/', '')
        if v_additional_name in ckpt_var_names:
          # Check if shapes match.
          ind = ckpt_var_names.index(v_additional_name)
          if ckpt_var_shapes[ind] == v.get_shape():
            mapping[v_additional_name] = v
            not_loaded.remove(v_additional_name)
            continue
          else:
            logging.warn('Shape mismatch, will not restore %s.', v.op.name)
        logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
      else:
        # Check if shapes match.
        ind = ckpt_var_names.index(v.op.name)
        if ckpt_var_shapes[ind] == v.get_shape():
          mapping[v.op.name] = v
          not_loaded.remove(v.op.name)
        else:
          logging.warn('Shape mismatch, will not restore %s.', v.op.name)
    if not_loaded:
      logging.warn('The following variables in the checkpoint were not loaded:')
      for varname_not_loaded in not_loaded:
        logging.info('%s', varname_not_loaded)
  else:  # just get model vars.
    for v in model_vars:
      mapping[v.op.name] = v
  return mapping
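Since the returned dict maps checkpoint variable names to graph variables, it can be passed straight to a TF1 `tf.train.Saver`. A minimal sketch (hypothetical checkpoint path):

ckpt = '/tmp/model.ckpt-100000'
vars_to_restore = get_vars_to_save_and_restore(ckpt)
saver = tf.train.Saver(var_list=vars_to_restore)
with tf.Session() as sess:
  saver.restore(sess, ckpt)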
Example #28
  def _TranslatePolicy(self, pol, exp_info):
    self.juniper_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    for header, terms in pol.filters:
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)
      filter_name = header.FilterName(self._PLATFORM)

      # Check for the position independent options and remove them from
      # the list.
      interface_specific = 'not-interface-specific' not in filter_options[1:]
      enable_dsmo = 'enable_dsmo' in filter_options[1:]
      noverbose = 'noverbose' in filter_options[1:]

      if not interface_specific:
        filter_options.remove('not-interface-specific')
      if enable_dsmo:
        filter_options.remove('enable_dsmo')

      # default to inet4 filters
      filter_type = 'inet'
      if len(filter_options) > 1:
        filter_type = filter_options[1]

      term_names = set()
      new_terms = []
      for term in terms:
        term.name = self.FixTermLength(term.name)
        if term.name in term_names:
          raise JuniperDuplicateTermError('You have multiple terms named: %s' %
                                          term.name)
        term_names.add(term.name)

        term = self.FixHighPorts(term, af=filter_type)
        if not term:
          continue

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        new_terms.append(self._TERM(term, filter_type, enable_dsmo, noverbose))

      self.juniper_policies.append((header, filter_name, filter_type,
                                    interface_specific, new_terms))
Example #29
  def __init__(self, output_path, **kwargs):
    super(DispatchingGenomicsWriter, self).__init__()
    self.header = kwargs.get('header', None)

    if '.tfrecord' in output_path:
      self._writer = TFRecordWriter(output_path, header=self.header)
    else:
      self._writer = self._native_writer(output_path, **kwargs)
    logging.info('Writing %s with %s',
                 output_path, self._writer.__class__.__name__)
    self._post_init_hook()
Example #30
 def testBadAddressFamily(self):
   cases = [
       'chain_name input 0 mixed',
   ]
   for case in cases:
     logging.info('Testing bad address family case %s.', case)
     header = BAD_HEADER % case
     pol = policy.ParsePolicy(header + GOOD_TERM_1, self.mock_naming)
     self.assertRaises(aclgenerator.UnsupportedAF,
                       nftables.Nftables.__init__,
                       nftables.Nftables.__new__(nftables.Nftables),
                       pol, EXP_INFO)
Example #31
def mixtures_same_family(draw,
                         batch_shape=None,
                         event_dim=None,
                         enable_vars=False,
                         depth=None):
  """Strategy for drawing `MixtureSameFamily` distributions.

  The component distribution is drawn from the `distributions` strategy.

  The Categorical mixture distributions are either shared across all batch
  members, or drawn independently for the full batch (as required by
  `MixtureSameFamily`).

  Args:
    draw: Hypothesis MacGuffin.  Supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `MixtureSameFamily` distribution.  The component distribution will have a
      batch shape of 1 rank higher (for the components being mixed).  Hypothesis
      will pick a batch shape if omitted.
    event_dim: Optional Python int giving the size of each of the component
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` giving maximum nesting depth of compound Distributions.

  Returns:
    dists: A strategy for drawing `MixtureSameFamily` distributions with the
      specified `batch_shape` (or an arbitrary one if omitted).
  """
  if depth is None:
    depth = draw(depths())

  if batch_shape is None:
    # Ensure the components dist has at least one batch dim (a component dim).
    batch_shape = draw(tfp_hps.shapes(min_ndims=1, min_lastdimsize=2))
  else:  # This mixture adds a batch dim to its underlying components dist.
    batch_shape = tensorshape_util.concatenate(
        batch_shape,
        draw(tfp_hps.shapes(min_ndims=1, max_ndims=1, min_lastdimsize=2)))

  component_dist = draw(
      distributions(
          batch_shape=batch_shape,
          event_dim=event_dim,
          enable_vars=enable_vars,
          depth=depth - 1))
  logging.info(
      'component distribution: %s; parameters used: %s', component_dist,
      [k for k, v in six.iteritems(component_dist.parameters) if v is not None])
  # scalar or same-shaped categorical?
  mixture_batch_shape = draw(
      hps.one_of(hps.just(batch_shape[:-1]), hps.just(tf.TensorShape([]))))
  mixture_dist = draw(base_distributions(
      dist_name='Categorical',
      batch_shape=mixture_batch_shape,
      event_dim=tensorshape_util.as_list(batch_shape)[-1],
      enable_vars=enable_vars))
  logging.info(
      'mixture distribution: %s; parameters used: %s', mixture_dist,
      [k for k, v in six.iteritems(mixture_dist.parameters) if v is not None])
  result_dist = tfd.MixtureSameFamily(
      components_distribution=component_dist,
      mixture_distribution=mixture_dist,
      validate_args=True)
  if batch_shape[:-1] != result_dist.batch_shape:
    msg = ('MixtureSameFamily strategy generated a bad batch shape '
           'for {}, should have been {}.').format(result_dist, batch_shape[:-1])
    raise AssertionError(msg)
  return result_dist
Example #32
def main(args):
    # If output_model path is relative and in cwd, make it absolute from root
    output_model = FLAGS.output_model
    if str(Path(output_model).parent) == '.':
        output_model = str((Path.cwd() / output_model))

    output_fld = Path(output_model).parent
    output_model_name = Path(output_model).name
    output_model_stem = Path(output_model).stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create output directory if it does not exist
    Path(output_model).parent.mkdir(parents=True, exist_ok=True)

    if FLAGS.channels_first:
        K.set_image_data_format('channels_first')
    else:
        K.set_image_data_format('channels_last')

    model = load_input_model(FLAGS.input_model, FLAGS.input_model_json,
                             FLAGS.input_model_yaml)

    # TODO(amirabdi): Support networks with multiple inputs
    orig_output_node_names = [node.op.name for node in model.outputs]
    if FLAGS.output_nodes_prefix:
        num_output = len(orig_output_node_names)
        pred = [None] * num_output
        converted_output_node_names = [None] * num_output

        # Create dummy tf nodes to rename output
        for i in range(num_output):
            converted_output_node_names[i] = '{}{}'.format(
                FLAGS.output_nodes_prefix, i)
            pred[i] = tf.identity(model.outputs[i],
                                  name=converted_output_node_names[i])
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if FLAGS.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if FLAGS.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(),
                             str(output_fld),
                             output_model_pbtxt_name,
                             as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(Path(output_fld) / output_model_pbtxt_name))

    if FLAGS.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
                                               converted_output_node_names,
                                               transforms)
        constant_graph = graph_util.convert_variables_to_constants(
            sess, transformed_graph_def, converted_output_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), converted_output_node_names)

    graph_io.write_graph(constant_graph,
                         str(output_fld),
                         output_model_name,
                         as_text=False)
    logging.info('Saved the frozen graph at %s',
                 str(Path(output_fld) / output_model_name))
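A minimal sketch (hypothetical path, TF1-style API matching the code above) of loading the frozen graph back for inference:

graph_def = tf.GraphDef()
with tf.gfile.GFile('/tmp/output_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
# Output tensors can then be fetched by name, e.g.
# graph.get_tensor_by_name(converted_output_node_names[0] + ':0').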
Example #33
def LogMsg(msg):
    logging.info(msg)
Example #34
def convert_examples_to_features(examples,
                                 tokenizer,
                                 max_seq_length,
                                 doc_stride,
                                 max_query_length,
                                 is_training,
                                 output_fn,
                                 batch_size=None):
    """Loads a data file into a list of `InputBatch`s."""

    base_id = 1000000000
    unique_id = base_id
    feature = None
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position +
                                                     1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position,
                tokenizer, example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take
        # chunks of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans,
                                                       doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start
                        and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0

            if example_index < 20:
                logging.info("*** Example ***")
                logging.info("unique_id: %s", (unique_id))
                logging.info("example_index: %s", (example_index))
                logging.info("doc_span_index: %s", (doc_span_index))
                logging.info(
                    "tokens: %s",
                    " ".join([tokenization.printable_text(x) for x in tokens]))
                logging.info(
                    "token_to_orig_map: %s", " ".join([
                        "%d:%d" % (x, y)
                        for (x, y) in six.iteritems(token_to_orig_map)
                    ]))
                logging.info(
                    "token_is_max_context: %s", " ".join([
                        "%d:%s" % (x, y)
                        for (x, y) in six.iteritems(token_is_max_context)
                    ]))
                logging.info("input_ids: %s",
                             " ".join([str(x) for x in input_ids]))
                logging.info("input_mask: %s",
                             " ".join([str(x) for x in input_mask]))
                logging.info("segment_ids: %s",
                             " ".join([str(x) for x in segment_ids]))
                if is_training and example.is_impossible:
                    logging.info("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(
                        tokens[start_position:(end_position + 1)])
                    logging.info("start_position: %d", (start_position))
                    logging.info("end_position: %d", (end_position))
                    logging.info("answer: %s",
                                 tokenization.printable_text(answer_text))

            feature = InputFeatures(unique_id=unique_id,
                                    example_index=example_index,
                                    doc_span_index=doc_span_index,
                                    tokens=tokens,
                                    token_to_orig_map=token_to_orig_map,
                                    token_is_max_context=token_is_max_context,
                                    input_ids=input_ids,
                                    input_mask=input_mask,
                                    segment_ids=segment_ids,
                                    start_position=start_position,
                                    end_position=end_position,
                                    is_impossible=example.is_impossible)

            # Run callback
            if is_training:
                output_fn(feature)
            else:
                output_fn(feature, is_padding=False)

            unique_id += 1

    if not is_training and feature:
        assert batch_size
        num_padding = 0
        num_examples = unique_id - base_id
        if unique_id % batch_size != 0:
            num_padding = batch_size - (num_examples % batch_size)
        logging.info("Adding padding examples to make sure no partial batch.")
        logging.info("Adds %d padding examples for inference.", num_padding)
        dummy_feature = copy.deepcopy(feature)
        for _ in range(num_padding):
            dummy_feature.unique_id = unique_id

            # Run callback
            output_fn(feature, is_padding=True)
            unique_id += 1
    return unique_id - base_id
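The sliding-window logic above can be isolated into a small helper; this sketch mirrors the `doc_spans` loop on token counts only (the helper name is illustrative, not part of the original code):

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_tokens, max_tokens_for_doc, doc_stride):
    """Enumerates overlapping windows exactly like the loop above."""
    doc_spans = []
    start_offset = 0
    while start_offset < num_tokens:
        length = min(num_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

# make_doc_spans(10, 4, 2) yields spans starting at 0, 2, 4 and 6; the last
# span covers tokens 6-9, so every token falls in at least one window.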
Example #35
def main(_argv):
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    if len(physical_devices) > 0:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)

    if FLAGS.tiny:
        model = YoloV3Tiny(FLAGS.size, training=True,
                           classes=FLAGS.num_classes)
        anchors = yolo_tiny_anchors
        anchor_masks = yolo_tiny_anchor_masks
    else:
        model = YoloV3(FLAGS.size, training=True, classes=FLAGS.num_classes)
        anchors = yolo_anchors
        anchor_masks = yolo_anchor_masks

    train_dataset = dataset.load_fake_dataset()
    if FLAGS.dataset:
        train_dataset = dataset.load_tfrecord_dataset(
            FLAGS.dataset, FLAGS.classes)
    train_dataset = train_dataset.shuffle(buffer_size=1024)  # TODO: not 1024
    train_dataset = train_dataset.batch(FLAGS.batch_size)
    train_dataset = train_dataset.map(lambda x, y: (
        dataset.transform_images(x, FLAGS.size),
        dataset.transform_targets(y, anchors, anchor_masks, 80)))
    train_dataset = train_dataset.prefetch(
        buffer_size=tf.data.experimental.AUTOTUNE)

    val_dataset = dataset.load_fake_dataset()
    if FLAGS.val_dataset:
        val_dataset = dataset.load_tfrecord_dataset(
            FLAGS.val_dataset, FLAGS.classes)
    val_dataset = val_dataset.batch(FLAGS.batch_size)
    val_dataset = val_dataset.map(lambda x, y: (
        dataset.transform_images(x, FLAGS.size),
        dataset.transform_targets(y, anchors, anchor_masks, 80)))

    if FLAGS.transfer != 'none':
        model.load_weights(FLAGS.weights)
        if FLAGS.transfer == 'fine_tune':
            # freeze darknet
            darknet = model.get_layer('yolo_darknet')
            freeze_all(darknet)
        elif FLAGS.transfer == 'frozen':
            # freeze everything
            freeze_all(model)
        else:
            # reset top layers
            if FLAGS.tiny:  # get initial weights
                init_model = YoloV3Tiny(
                    FLAGS.size, training=True, classes=FLAGS.num_classes)
            else:
                init_model = YoloV3(
                    FLAGS.size, training=True, classes=FLAGS.num_classes)

            if FLAGS.transfer == 'darknet':
                for l in model.layers:
                    if l.name != 'yolo_darknet' and l.name.startswith('yolo_'):
                        l.set_weights(init_model.get_layer(
                            l.name).get_weights())
                    else:
                        freeze_all(l)
            elif FLAGS.transfer == 'no_output':
                for l in model.layers:
                    if l.name.startswith('yolo_output'):
                        l.set_weights(init_model.get_layer(
                            l.name).get_weights())
                    else:
                        freeze_all(l)

    optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
    loss = [YoloLoss(anchors[mask], classes=FLAGS.num_classes)
            for mask in anchor_masks]

    if FLAGS.mode == 'eager_tf':
        # Eager mode is great for debugging
        # Non eager graph mode is recommended for real training
        avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)
        avg_val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)

        for epoch in range(1, FLAGS.epochs + 1):
            for batch, (images, labels) in enumerate(train_dataset):
                with tf.GradientTape() as tape:
                    outputs = model(images, training=True)
                    regularization_loss = tf.reduce_sum(model.losses)
                    pred_loss = []
                    for output, label, loss_fn in zip(outputs, labels, loss):
                        pred_loss.append(loss_fn(label, output))
                    total_loss = tf.reduce_sum(pred_loss) + regularization_loss

                grads = tape.gradient(total_loss, model.trainable_variables)
                optimizer.apply_gradients(
                    zip(grads, model.trainable_variables))

                logging.info("{}_train_{}, {}, {}".format(
                    epoch, batch, total_loss.numpy(),
                    list(map(lambda x: np.sum(x.numpy()), pred_loss))))
                avg_loss.update_state(total_loss)

            for batch, (images, labels) in enumerate(val_dataset):
                outputs = model(images)
                regularization_loss = tf.reduce_sum(model.losses)
                pred_loss = []
                for output, label, loss_fn in zip(outputs, labels, loss):
                    pred_loss.append(loss_fn(label, output))
                total_loss = tf.reduce_sum(pred_loss) + regularization_loss

                logging.info("{}_val_{}, {}, {}".format(
                    epoch, batch, total_loss.numpy(),
                    list(map(lambda x: np.sum(x.numpy()), pred_loss))))
                avg_val_loss.update_state(total_loss)

            logging.info("{}, train: {}, val: {}".format(
                epoch,
                avg_loss.result().numpy(),
                avg_val_loss.result().numpy()))

            avg_loss.reset_states()
            avg_val_loss.reset_states()
            model.save_weights(
                'checkpoints/yolov3_train_{}.tf'.format(epoch))
    else:
        model.compile(optimizer=optimizer, loss=loss,
                      run_eagerly=(FLAGS.mode == 'eager_fit'))

        callbacks = [
            ReduceLROnPlateau(verbose=1),
            EarlyStopping(patience=3, verbose=1),
            ModelCheckpoint('checkpoints/yolov3_train_{epoch}.tf',
                            verbose=1, save_weights_only=True),
            TensorBoard(log_dir='logs')
        ]

        history = model.fit(train_dataset,
                            epochs=FLAGS.epochs,
                            callbacks=callbacks,
                            validation_data=val_dataset)
Example #36
def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic
    # between `pred_text` and `orig_text` to get a character-to-character
    # alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose:
            logging.info("Unable to find text: '%s' in '%s'", pred_text,
                         orig_text)
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose:
            logging.info(
                "Length not equal after stripping spaces: '%s' vs '%s'",
                orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose:
            logging.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose:
            logging.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
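To make the whitespace-stripping alignment concrete, here is a self-contained copy of the nested `_strip_spaces` helper applied to a short string (the sample value is illustrative):

import collections

def _strip_spaces(text):
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for (i, c) in enumerate(text):
        if c == " ":
            continue
        ns_to_s_map[len(ns_chars)] = i
        ns_chars.append(c)
    return "".join(ns_chars), ns_to_s_map

ns_text, ns_to_s_map = _strip_spaces("Steve Smith")
# ns_text == "SteveSmith"; ns_to_s_map maps stripped indices back to the
# original string, e.g. ns_to_s_map[5] == 6 (the 'S' of "Smith").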
Example #37
  def Read(self,
           split: str,
           batch_size: int = None,
           input_filepattern=None,
           shuffle_buffer_size: int = None,
           num_examples: int = -1,
           num_epochs: int = -1,
           read_parallelism: int = 16) -> tf.data.Dataset:
    """Reads the specified `split` as a `tf.data.Dataset`.

    Reads the tfrecords at the path configured for `split`, parses the examples,
    and optionally shuffles and batches them.

    By default, the tfrecords path is taken from the `split_paths` passed to
    the constructor. Callers can optionally override this path by passing
    `input_filepattern` explicitly.

    If `shuffle_buffer_size` is set, examples are randomly shuffled with a
    buffer of the given size prior to batching. Shuffling is on by default for
    the `TRAIN` split, and can be disabled by setting shuffle_buffer_size=0.

    Args:
      split: Name of the split to read.
      batch_size: If set, the number of examples to batch together.
      input_filepattern: If given, read the tfrecords at this path instead of
        the one specified in `split_paths`.
      shuffle_buffer_size: If > 0, size of the buffer used to shuffle examples.
        Defaults to 4096 for the `TRAIN` split and 0 otherwise.
      num_examples: Number of examples to read from the underlying tfrecords.
        Defaults to -1, meaning all examples are read.
      num_epochs: Number of epochs (full passes) through the dataset to make.
        Defaults to -1, meaning the dataset repeats indefinitely. If set to n >
        0, a `tf.errors.OutOfRangeError` is raised at the end of the n'th epoch.
      read_parallelism: Number of input files to read in parallel.

    Returns:
      The split as a `tf.data.Dataset` object.
    """
    if input_filepattern is None:
      input_filepattern = self._split_paths.get(split)
      assert input_filepattern, 'Unsupported split {}'.format(split)

    if shuffle_buffer_size is None:
      shuffle_buffer_size = 4096 if split == constants.Split.TRAIN else 0
    shuffle_files = shuffle_buffer_size > 0

    logging.info(
        'Reading inputs %s with batch_size=%s, shuffle_buffer_size=%s, '
        'num_examples=%d, num_epochs=%d', input_filepattern, batch_size,
        shuffle_buffer_size, num_examples, num_epochs)

    per_file_dataset_factory = functools.partial(
        tf.data.TFRecordDataset, buffer_size=_TFRECORD_READER_BUFFER_SIZE_BYTES)
    dataset = (
        tf.data.Dataset.list_files(input_filepattern,
                                   shuffle=shuffle_files).apply(
                                       tf.data.experimental.parallel_interleave(
                                           per_file_dataset_factory,
                                           cycle_length=read_parallelism,
                                           sloppy=shuffle_files)))

    if shuffle_buffer_size > 0:
      dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.take(num_examples).repeat(num_epochs)
    if batch_size is not None:
      dataset = dataset.batch(batch_size, drop_remainder=True)

    dataset = dataset.map(
        functools.partial(tf.io.parse_example, features=self._schema),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if self._transform is not None:
      dataset = dataset.map(self._transform)
    return dataset
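
For reference, the ordering of the pipeline above (list files → interleave records → shuffle → take/repeat → batch → parse) can be sketched on its own. The file pattern, feature spec, and constants below are placeholders rather than values from this module, and the non-experimental `interleave` stands in for the deprecated `parallel_interleave` transform:

import functools

import tensorflow as tf

# Placeholder inputs -- not values from this module.
input_filepattern = '/tmp/data/train-*.tfrecord'
schema = {'label': tf.io.FixedLenFeature([], tf.int64)}

make_reader = functools.partial(
    tf.data.TFRecordDataset, buffer_size=8 * 1024 * 1024)  # arbitrary read buffer
dataset = tf.data.Dataset.list_files(input_filepattern, shuffle=True)
dataset = dataset.interleave(  # equivalent of the sloppy parallel_interleave above
    make_reader, cycle_length=16, deterministic=False,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(4096)            # TRAIN default shuffle_buffer_size
dataset = dataset.take(-1).repeat(-1)      # all examples, repeated indefinitely
dataset = dataset.batch(32, drop_remainder=True)
dataset = dataset.map(
    functools.partial(tf.io.parse_example, features=schema),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)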
Ejemplo n.º 38
0
def eval_step(model, batch):
    eval_keys = ['inputs', 'targets']
    (inputs, targets) = [batch.get(k, None) for k in eval_keys]
    logits = model(inputs, train=False)
    logging.info(logits)
    return compute_metrics(logits, targets, None)
Ejemplo n.º 39
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    max_length = config.max_length

    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    train_ds = train_ds.repeat()

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_length)

    model_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': config.emb_dim,
        'num_heads': config.num_heads,
        'num_layers': config.num_layers,
        'qkv_dim': config.qkv_dim,
        'mlp_dim': config.mlp_dim,
        'max_len': max_length,
        'classifier': True,
        'num_classes': CLASS_MAP[FLAGS.task_name],
        'classifier_pool': config.classifier_pool
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.process_index())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    model = train_utils.get_model(model_type, create_model, model_kwargs,
                                  init_rng, input_shape)

    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if config.restore_checkpoints or FLAGS.test_only:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # p_pred_step = jax.pmap(predict_step, axis_name='batch')

    def run_eval(eval_ds, num_eval_steps=-1):
        eval_metrics = []
        eval_iter = iter(eval_ds)
        if num_eval_steps == -1:
            num_iter = itertools.count()
        else:
            num_iter = range(num_eval_steps)
        for _, eval_batch in zip(num_iter, eval_iter):
            # pylint: disable=protected-access
            eval_batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), eval_batch))
            # pylint: enable=protected-access
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        # Calculate (clipped) perplexity after averaging log-perplexities:
        eval_summary['perplexity'] = jnp.clip(jnp.exp(eval_summary['loss']),
                                              a_max=1.0e4)
        return eval_summary

    if FLAGS.test_only:
        with tf.io.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.json'),
                               'w') as f:
            test_summary = run_eval(test_ds)
            json.dump(jax.tree_map(lambda x: x.tolist(), test_summary), f)
        return

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.process_index() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.process_index() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_summary = run_eval(eval_ds, num_eval_steps)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.process_index() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'eval_{key}', val, step)
                summary_writer.flush()
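
The periodic metric handling above sums each accumulated metric and normalizes by the summed `denominator` before logging; perplexity is the (clipped) exponential of the averaged loss. A minimal sketch of that reduction with made-up numbers, omitting the device axis that `common_utils.get_metrics` stacks:

import jax
import jax.numpy as jnp

metrics_all = [
    {'loss': jnp.array(12.0), 'accuracy': jnp.array(7.0), 'denominator': jnp.array(16.0)},
    {'loss': jnp.array(10.0), 'accuracy': jnp.array(9.0), 'denominator': jnp.array(16.0)},
]
# Sum each metric over the accumulated steps, then divide by the summed denominator.
metrics_sums = jax.tree_util.tree_map(lambda *xs: sum(xs), *metrics_all)
denominator = metrics_sums.pop('denominator')          # 32.0
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
# summary == {'loss': 0.6875, 'accuracy': 0.5}
summary['perplexity'] = jnp.minimum(jnp.exp(summary['loss']), 1.0e4)  # clipped perplexity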
Ejemplo n.º 40
0
def create_if_not_exists(path):
    try:
        tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
        logging.info('Skipping creation of directory [%s], already exists',
                     path)
def summarize_logs(dir, markdown=False, github_log=False):
  build_log_files = glob.glob(os.path.join(dir, BUILD_FILE_PATTERN))
  test_log_files = glob.glob(os.path.join(dir, TEST_FILE_PATTERN))
  # Replace the "*" in the file glob with a regex capture group,
  # so we can report the test name.
  build_log_name_re = re.escape(
      os.path.join(dir,BUILD_FILE_PATTERN)).replace("\\*", "(.*)")
  test_log_name_re = re.escape(
      os.path.join(dir,TEST_FILE_PATTERN)).replace("\\*", "(.*)")

  success_or_only_flakiness = True
  log_data = {}
  # log_data format:
  #   {testapp: {"build": [configs],
  #              "test": {"errors": [configs],
  #                       "failures": {failed_test: [configs]},
  #                       "flakiness": {flaky_test: [configs]}}}}
  all_tested_configs = { "build_configs": [], "test_configs": []}
  for build_log_file in build_log_files:
    configs = get_configs_from_file_name(build_log_file, build_log_name_re)
    all_tested_configs["build_configs"].append(configs)
    with open(build_log_file, "r") as log_reader:
      log_text = log_reader.read()
      if "__SUMMARY_MISSING__" in log_text:
        success_or_only_flakiness = False
        log_data.setdefault(MISSING_LOG, {}).setdefault("build", []).append(configs)
      else:
        log_reader_data = json.loads(log_text)
        for (testapp, _) in log_reader_data["errors"].items():
          success_or_only_flakiness = False
          log_data.setdefault(testapp, {}).setdefault("build", []).append(configs)

  for test_log_file in test_log_files:
    configs = get_configs_from_file_name(test_log_file, test_log_name_re)
    all_tested_configs["test_configs"].append(configs)
    with open(test_log_file, "r") as log_reader:
      log_text = log_reader.read()
      if "__SUMMARY_MISSING__" in log_text:
        success_or_only_flakiness = False
        log_data.setdefault(MISSING_LOG, {}).setdefault("test", {}).setdefault("errors", []).append(configs)
      else:
        log_reader_data = json.loads(log_text)
        for (testapp, _) in log_reader_data["errors"].items():
          success_or_only_flakiness = False
          log_data.setdefault(testapp, {}).setdefault("test", {}).setdefault("errors", []).append(configs)
        for (testapp, failures) in log_reader_data["failures"].items():
          for (test, _) in failures["failed_tests"].items():
            success_or_only_flakiness = False
            log_data.setdefault(testapp, {}).setdefault("test", {}).setdefault("failures", {}).setdefault(test, []).append(configs)
        for (testapp, flakiness) in log_reader_data["flakiness"].items():
          if flakiness["flaky_tests"].items():
            for (test, _) in flakiness["flaky_tests"].items():
              log_data.setdefault(testapp, {}).setdefault("test", {}).setdefault("flakiness", {}).setdefault(test, []).append(configs)
          else:
            log_data.setdefault(testapp, {}).setdefault("test", {}).setdefault("flakiness", {}).setdefault("CRASH/TIMEOUT", []).append(configs)

  if success_or_only_flakiness and not log_data:
    # No failures and no flakiness occurred, nothing to log.
    return (success_or_only_flakiness, None)

  # if failures (include flakiness) exist:
  # log_results format:
  #   { testapps: {configs: [failed tests]} }
  all_tested_configs = reorganize_all_tested_configs(all_tested_configs)
  logging.info("all_tested_configs: %s", all_tested_configs)
  log_results = reorganize_log(log_data, all_tested_configs)
  log_lines = []
  if markdown:
    log_lines = print_markdown_table(log_results)
    # If outputting Markdown, don't bother justifying the table.
  elif github_log:
    log_lines = print_github_log(log_results)
  else:
    log_lines = print_log(log_results)

  log_summary = "\n".join(log_lines)
  print(log_summary)
  return (success_or_only_flakiness, log_summary)
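
The nested `log_data` structure described in the comment above is built entirely with chained `setdefault` calls. A tiny illustration with hypothetical testapp, test, and config names:

log_data = {}
configs = ["ubuntu", "openssl-1.1.1"]   # hypothetical build/test config
(log_data.setdefault("firestore_testapp", {})
         .setdefault("test", {})
         .setdefault("failures", {})
         .setdefault("TestWrite", [])
         .append(configs))
# log_data == {"firestore_testapp":
#                 {"test": {"failures": {"TestWrite": [["ubuntu", "openssl-1.1.1"]]}}}}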
Ejemplo n.º 42
0
def continuously_collect_trajectories(output_dir,
                                      train_env,
                                      eval_env,
                                      trajectory_dump_dir=None,
                                      env_id=None,
                                      max_trajectories_to_collect=None,
                                      try_abort=True):
    """Instantiates a PPO trainer and collects trajectories."""

    # Make the PPO trainer.
    ppo_trainer = rl_trainers.PPO(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
    )

    # TODO(afrozm): Update base_trainer interface to support SimPLe as well.
    assert isinstance(ppo_trainer, rl_trainers.PPO)

    assert env_id is not None

    # Get an initial policy, waiting forever for it if needed.
    policy_and_epoch = get_newer_policy_model_file(output_dir,
                                                   wait_forever=True)
    assert policy_and_epoch
    policy_file, epoch = policy_and_epoch
    logging.info('Read initial policy for epoch [%s] -> [%s]', epoch,
                 policy_file)

    # Returns immediately if there is a newer epoch available.
    def is_newer_policy_file_available(epoch_, sleep_time_secs_=0.1):
        return get_newer_policy_model_file(output_dir,
                                           min_epoch=epoch_,
                                           sleep_time_secs=sleep_time_secs_)

    # Does a __done__ file exist?
    def done_file_exists():
        return gfile.exists(os.path.join(output_dir, '__done__'))

    assert 1 == train_env.batch_size
    assert 1 == eval_env.batch_size

    temperature = 1.0

    trajectories_collected = 0

    train_env_trajectory_dump_dir = os.path.join(output_dir,
                                                 'trajectories/train')
    eval_env_trajectory_dump_dir = os.path.join(output_dir,
                                                'trajectories/eval')

    gfile.makedirs(train_env_trajectory_dump_dir)
    gfile.makedirs(eval_env_trajectory_dump_dir)

    while max_trajectories_to_collect is None or trajectories_collected < int(
            max_trajectories_to_collect):
        logging.info('Collecting a trajectory, trajectories_collected = %s',
                     trajectories_collected)

        # Abort function -- if something newer is available, then abort the
        # current computation and reload.

        # Useful if env.step is long.
        def long_abort_fn():
            # We want this to be as quick as possible.
            return (is_newer_policy_file_available(epoch, 0)
                    is not None) or (done_file_exists())

        abort_fn = long_abort_fn if try_abort else None

        # Collect a training trajectory.
        trajs, n_done, unused_timing_info, unused_model_state = (
            ppo_trainer.collect_trajectories(train=True,
                                             temperature=temperature,
                                             abort_fn=abort_fn,
                                             raw_trajectory=True))

        if done_file_exists():
            logging.info('__done__ file found in %s, we are done here.',
                         output_dir)
            break

        if trajs and n_done > 0:
            assert 1 == n_done
            trajectories_collected += n_done

            # Write the trajectory down.
            logging.info(
                'Dumping the collected trajectory, trajectories_collected = %s',
                trajectories_collected)
            dump_trajectory(train_env_trajectory_dump_dir, epoch,
                            env_id, temperature,
                            str(random.randint(0, 2**31 - 1)), trajs)
        else:
            logging.info('Computation was aborted, a new policy is available.')

        # This may be redundant, since `abort_fn` will take care of it. We keep it
        # here for the case where `try_abort` is False and `abort_fn` is None.
        # Do we have a newer policy?
        policy_file_and_epoch = is_newer_policy_file_available(epoch)
        if policy_file_and_epoch is None:
            # Continue churning out these policies.
            logging.info(
                "We don't have a newer policy, continuing with the old one.")
            continue

        # We have a newer policy, read it and update the parameters.
        policy_file, epoch = policy_file_and_epoch
        logging.info(
            'We have a newer policy epoch [%s], file [%s], updating parameters.',
            epoch, policy_file)
        ppo_trainer.update_optimization_state(output_dir)
        logging.info('Parameters of PPOTrainer updated.')

        # Check that the epochs match.
        assert epoch == ppo_trainer.epoch
Ejemplo n.º 43
0
def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
    """Get a var map for restoring from pretrained checkpoints.

  Args:
    ckpt_path: string. A pretrained checkpoint path.
    ckpt_scope: string. Scope name for checkpoint variables.
    var_scope: string. Scope name for model variables.
    skip_mismatch: skip variables if shape mismatch.

  Returns:
    var_map: a dictionary from checkpoint name to model variables.
  """
    logging.info('Init model from checkpoint {}'.format(ckpt_path))
    if not ckpt_scope.endswith('/') or not var_scope.endswith('/'):
        raise ValueError('Please specify scope names ending with /')
    if ckpt_scope.startswith('/'):
        ckpt_scope = ckpt_scope[1:]
    if var_scope.startswith('/'):
        var_scope = var_scope[1:]

    var_map = {}
    # Get the list of vars to restore.
    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope=var_scope)
    reader = tf.train.load_checkpoint(ckpt_path)
    ckpt_var_name_to_shape = reader.get_variable_to_shape_map()
    ckpt_var_names = set(reader.get_variable_to_shape_map().keys())

    for i, v in enumerate(model_vars):
        if not v.op.name.startswith(var_scope):
            logging.info('skip {} -- does not match scope {}'.format(
                v.op.name, var_scope))
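        # Map the model variable name to its checkpoint name by swapping the scope
        # prefix. E.g. (illustrative scopes) with var_scope='model/efficientnet-b0/'
        # and ckpt_scope='efficientnet-b0/', the variable
        # 'model/efficientnet-b0/stem/conv2d/kernel' maps to
        # 'efficientnet-b0/stem/conv2d/kernel'.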
        ckpt_var = ckpt_scope + v.op.name[len(var_scope):]
        if (ckpt_var not in ckpt_var_names
                and v.op.name.endswith('/ExponentialMovingAverage')):
            ckpt_var = ckpt_scope + v.op.name[:-len('/ExponentialMovingAverage'
                                                    )]

        if ckpt_var not in ckpt_var_names:
            if 'Momentum' in ckpt_var or 'RMSProp' in ckpt_var:
                # Skip optimizer variables.
                continue
            if skip_mismatch:
                logging.info('skip {} ({}) -- not in ckpt'.format(
                    v.op.name, ckpt_var))
                continue
            raise ValueError('{} is not in ckpt {}'.format(v.op, ckpt_path))

        if v.shape != ckpt_var_name_to_shape[ckpt_var]:
            if skip_mismatch:
                logging.info('skip {} ({} vs {}) -- shape mismatch'.format(
                    v.op.name, v.shape, ckpt_var_name_to_shape[ckpt_var]))
                continue
            raise ValueError('shape mismatch {} ({} vs {})'.format(
                v.op.name, v.shape, ckpt_var_name_to_shape[ckpt_var]))

        if i < 5:
            # Log the first few elements for sanity check.
            logging.info('Init {} from ckpt var {}'.format(
                v.op.name, ckpt_var))
        var_map[ckpt_var] = v

    return var_map
def main(argv):
    del argv  # unused arg

    # Parse command line arguments.
    seed = FLAGS.seed
    output_dir = FLAGS.output_dir
    data_dir = FLAGS.data_dir
    train_epochs = FLAGS.train_epochs
    checkpoint_interval = FLAGS.checkpoint_interval
    use_bfloat16 = FLAGS.use_bfloat16
    batch_size = FLAGS.batch_size
    eval_batch_size = FLAGS.eval_batch_size
    steps_per_epoch = 1  # TODO(filangel): function of FLAGS
    steps_per_eval = 1  # TODO(filangel): function of FLAGS

    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)
    tf.random.set_seed(seed)

    # TODO(filangel): enable TPU support.
    logging.info('Use GPU')
    strategy = tf.distribute.MirroredStrategy()

    if use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries'))

    dataset_train_builder = utils.load_diabetic_retinopathy_detection(
        split='train', data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=batch_size)
    dataset_train = strategy.experimental_distribute_dataset(dataset_train)
    dataset_test_builder = utils.load_diabetic_retinopathy_detection(
        split='test', data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)
    dataset_test = strategy.experimental_distribute_dataset(dataset_test)

    with strategy.scope():
        logging.info('Building Keras ResNet-50 model')

        # Shape tuple access depends on number of distributed devices
        try:
            shape_tuple = dataset_train.element_spec['features'].shape
        except AttributeError:  # Multiple TensorSpec in a (nested) PerReplicaSpec.
            tensor_spec_list = dataset_train.element_spec[  # pylint: disable=protected-access
                'features']._flat_tensor_specs
            shape_tuple = tensor_spec_list[0].shape

        model = ub.models.resnet50_deterministic(
            input_shape=shape_tuple.as_list()[1:],
            num_classes=1)  # binary classification task
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())

        optimizer = tf.keras.optimizers.Adam(1e-4)

        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.BinaryAccuracy(),
            'train/auc': tf.keras.metrics.AUC(),
            'train/loss': tf.keras.metrics.Mean(),
            'test/negative_log_likelihood': tf.keras.metrics.Mean(),
            'test/accuracy': tf.keras.metrics.BinaryAccuracy(),
            'test/auc': tf.keras.metrics.AUC()
        }
        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(output_dir)
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)

    def train_step(iterator):
        """Training step function."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.binary_crossentropy(y_true=tf.expand_dims(
                        labels, axis=-1),
                                                        y_pred=logits,
                                                        from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.squeeze(tf.nn.sigmoid(logits))
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auc'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    def test_step(iterator):
        """Evaluation step function."""
        def step_fn(inputs):
            """Per-replica step function."""
            images = inputs['features']
            labels = inputs['labels']
            logits = model(images, training=False)
            if use_bfloat16:
                logits = tf.cast(logits, tf.float32)

            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(y_true=tf.expand_dims(
                    labels, axis=-1),
                                                    y_pred=logits,
                                                    from_logits=True))
            probs = tf.squeeze(tf.nn.sigmoid(logits))
            metrics['test/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['test/accuracy'].update_state(labels, probs)
            metrics['test/auc'].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(dataset_train)
    test_iterator = iter(dataset_test)
    start_time = time.time()
    for epoch in range(train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1, train_epochs,
                           steps_per_sec, eta_seconds / 60, time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        for step in range(steps_per_eval):
            if step % 20 == 0:
                logging.info('Starting to run eval step %s of epoch: %s', step,
                             epoch)
            test_step(test_iterator)

        logging.info(
            'Train Loss (NLL+L2): %.4f, Accuracy: %.2f%%, AUC: %.2f%%',
            metrics['train/loss'].result(),
            metrics['train/accuracy'].result() * 100,
            metrics['train/auc'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%, AUC: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100,
                     metrics['test/auc'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
            checkpoint_name = checkpoint.save(
                os.path.join(output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(output_dir, 'checkpoint'), )
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
Ejemplo n.º 45
0
def log(s, stdout=True):
    logging.info(s)
    if stdout:
        print(s)
Ejemplo n.º 46
0
def main(unused_argv):
    del unused_argv
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                       FLAGS.train_batch_size, FLAGS.seq_len,
                                       FLAGS.query_len, strategy, True,
                                       FLAGS.train_tfrecord_path)

    test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      FLAGS.query_len, strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
    input_meta_data = {}
    input_meta_data["start_n_top"] = FLAGS.start_n_top
    input_meta_data["end_n_top"] = FLAGS.end_n_top
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["predict_dir"] = FLAGS.predict_dir
    input_meta_data["n_best_size"] = FLAGS.n_best_size
    input_meta_data["max_answer_length"] = FLAGS.max_answer_length
    input_meta_data["test_batch_size"] = FLAGS.test_batch_size
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["mem_len"] = FLAGS.mem_len
    model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                 FLAGS.start_n_top, FLAGS.end_n_top)
    eval_examples = squad_utils.read_squad_examples(FLAGS.predict_file,
                                                    is_training=False)
    if FLAGS.test_feature_path:
        logging.info("start reading pickle file...")
        with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
            eval_features = pickle.load(f)
        logging.info("finishing reading pickle file...")
    else:
        sp_model = spm.SentencePieceProcessor()
        sp_model.LoadFromSerializedProto(
            tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
        spm_basename = os.path.basename(FLAGS.spiece_model_file)
        eval_features = squad_utils.create_eval_data(
            spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
            FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

    with tf.io.gfile.GFile(FLAGS.predict_file) as f:
        original_data = json.load(f)["data"]
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_examples, eval_features, original_data,
                                eval_steps, input_meta_data)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
Ejemplo n.º 47
0
def write_predictions(all_examples,
                      all_features,
                      all_results,
                      n_best_size,
                      max_answer_length,
                      do_lower_case,
                      output_prediction_file,
                      output_nbest_file,
                      output_null_log_odds_file,
                      version_2_with_negative=False,
                      null_score_diff_threshold=0.0,
                      verbose=False):
    """Write final predictions to the json file and log-odds of null if needed."""
    logging.info("Writing predictions to: %s", (output_prediction_file))
    logging.info("Writing nbest to: %s", (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)
    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", [
            "feature_index", "start_index", "end_index", "start_logit",
            "end_logit"
        ])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[
                    0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(
                            start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(feature_index=min_null_feature_index,
                                  start_index=0,
                                  end_index=0,
                                  start_logit=null_start_logit,
                                  end_logit=null_end_logit))
        prelim_predictions = sorted(prelim_predictions,
                                    key=lambda x:
                                    (x.start_logit + x.end_logit),
                                    reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(pred.end_index +
                                                              1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end +
                                                                 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text,
                                            orig_text,
                                            do_lower_case,
                                            verbose=verbose)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(text=final_text,
                                 start_logit=pred.start_logit,
                                 end_logit=pred.end_logit))

        # If we didn't include the empty option in the n-best, include it.
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(text="",
                                     start_logit=null_start_logit,
                                     end_logit=null_end_logit))
        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # pytype: disable=attribute-error
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
            # pytype: enable=attribute-error

        all_nbest_json[example.qas_id] = nbest_json

    with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
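
To illustrate the SQuAD 2.0 null-answer rule used above (predict the empty string iff `score_null - best_non_null > null_score_diff_threshold`), here is a tiny worked example with made-up logits:

score_null = 1.2                       # start_logits[0] + end_logits[0] of the best null slice
best_start_logit, best_end_logit = 0.4, 0.3
null_score_diff_threshold = 0.0
score_diff = score_null - best_start_logit - best_end_logit   # 0.5
prediction = "" if score_diff > null_score_diff_threshold else "best non-null text"
# score_diff (0.5) > threshold (0.0), so the question is predicted unanswerable.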
Ejemplo n.º 48
0
def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
                   original_data, eval_steps, input_meta_data, model,
                   current_step, eval_summary_writer):
    """Run evaluation for SQUAD task.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_examples: tf.Examples of the evaluation set.
    eval_features: Feature objects of the evaluation set.
    original_data: The original json data for the evaluation set.
    eval_steps: total number of evaluation steps.
    input_meta_data: input meta data.
    model: keras model object.
    current_step: current training step.
    eval_summary_writer: summary writer used to record evaluation metrics.

  Returns:
    A float metric, F1 score.
  """
    def _test_step_fn(inputs):
        """Replicated validation step."""

        inputs["mems"] = None
        res = model(inputs, training=False)
        return res, inputs["unique_ids"]

    @tf.function
    def _run_evaluation(test_iterator):
        """Runs validation steps."""
        res, unique_ids = strategy.run(_test_step_fn,
                                       args=(next(test_iterator), ))
        return res, unique_ids

    test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
    cur_results = []
    for _ in range(eval_steps):
        results, unique_ids = _run_evaluation(test_iterator)
        unique_ids = strategy.experimental_local_results(unique_ids)

        for result_key in results:
            results[result_key] = (strategy.experimental_local_results(
                results[result_key]))
        for core_i in range(strategy.num_replicas_in_sync):
            bsz = int(input_meta_data["test_batch_size"] /
                      strategy.num_replicas_in_sync)
            for j in range(bsz):
                result = {}
                for result_key in results:
                    result[result_key] = results[result_key][core_i].numpy()[j]
                result["unique_ids"] = unique_ids[core_i].numpy()[j]
                # We appended a fake example to the dev set so that the data size
                # is divisible by test_batch_size. Ignore this fake example during
                # evaluation.
                if result["unique_ids"] == 1000012047:
                    continue
                unique_id = int(result["unique_ids"])

                start_top_log_probs = ([
                    float(x) for x in result["start_top_log_probs"].flat
                ])
                start_top_index = [
                    int(x) for x in result["start_top_index"].flat
                ]
                end_top_log_probs = ([
                    float(x) for x in result["end_top_log_probs"].flat
                ])
                end_top_index = [int(x) for x in result["end_top_index"].flat]

                cls_logits = float(result["cls_logits"].flat[0])
                cur_results.append(
                    squad_utils.RawResult(
                        unique_id=unique_id,
                        start_top_log_probs=start_top_log_probs,
                        start_top_index=start_top_index,
                        end_top_log_probs=end_top_log_probs,
                        end_top_index=end_top_index,
                        cls_logits=cls_logits))
                if len(cur_results) % 1000 == 0:
                    logging.info("Processing example: %d", len(cur_results))

    output_prediction_file = os.path.join(input_meta_data["predict_dir"],
                                          "predictions.json")
    output_nbest_file = os.path.join(input_meta_data["predict_dir"],
                                     "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
                                             "null_odds.json")

    results = squad_utils.write_predictions(
        eval_examples, eval_features, cur_results,
        input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
        output_prediction_file, output_nbest_file, output_null_log_odds_file,
        original_data, input_meta_data["start_n_top"],
        input_meta_data["end_n_top"])

    # Log current results.
    log_str = "Result | "
    for key, val in results.items():
        log_str += "{} {} | ".format(key, val)
    logging.info(log_str)
    with eval_summary_writer.as_default():
        tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
        tf.summary.scalar("best_exact",
                          results["best_exact"],
                          step=current_step)
        eval_summary_writer.flush()
    return results["best_f1"]
Ejemplo n.º 49
0
  def _download(
      self,
      resource: Union[str,
                      resource_lib.Resource]) -> promise.Promise[ReadOnlyPath]:
    """Download resource, returns Promise->path to downloaded file.

    This function:

    1. Reuses the cache (`_get_cached_path`) or downloads the file
    2. Registers or validates checksums (`_register_or_validate_checksums`)
    3. Renames the download to its final path (`_rename_and_get_final_dl_path`)

    Args:
      resource: The URL to download.

    Returns:
      path: The path to the downloaded resource.
    """
    # Normalize the input
    if isinstance(resource, str):
      resource = resource_lib.Resource(url=resource)
    url = resource.url

    expected_url_info = self._url_infos.get(url)

    # 3 possible destinations for the path:
    # * In `manual_dir` (manually downloaded data)
    # * In `downloads/url_path` (checksum unknown)
    # * In `downloads/checksum_path` (checksum registered)
    manually_downloaded_path = _get_manually_downloaded_path(
        manual_dir=self._manual_dir,
        expected_url_info=expected_url_info,
    )
    url_path = self._get_dl_path(
        url, sha256=hashlib.sha256(url.encode('utf-8')).hexdigest())
    checksum_path = self._get_dl_path(
        url, sha256=expected_url_info.checksum) if expected_url_info else None

    # Get the cached path and url_info (if they exist).
    dl_result = _get_cached_path(
        manually_downloaded_path=manually_downloaded_path,
        checksum_path=checksum_path,
        url_path=url_path,
        expected_url_info=expected_url_info,
    )
    if dl_result.path and not self._force_download:  # Download was cached
      logging.info(
          f'Skipping download of {url}: File cached in {dl_result.path}')
      # Still update the progress bar to indicate the file was downloaded.
      self._downloader.increase_tqdm(dl_result)
      future = promise.Promise.resolve(dl_result)
    else:
      # Download in an empty tmp directory (to avoid name collisions)
      # `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
      dirname = f'{resource_lib.get_dl_dirname(url)}.tmp.{uuid.uuid4().hex}'
      download_tmp_dir = self._download_dir / dirname
      download_tmp_dir.mkdir()
      logging.info(f'Downloading {url} into {download_tmp_dir}...')
      future = self._downloader.download(
          url, download_tmp_dir, verify=self._verify_ssl)

    # Post-process the result
    return future.then(lambda dl_result: self._register_or_validate_checksums(  # pylint: disable=g-long-lambda
        url=url,
        path=dl_result.path,
        computed_url_info=dl_result.url_info,
        expected_url_info=expected_url_info,
        checksum_path=checksum_path,
        url_path=url_path,
    ))
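
As a small illustration of the second destination above (`downloads/url_path`, used when no checksum is registered), the cache key is derived from the SHA-256 of the URL itself; the URL below is a placeholder, and the exact filename layout produced by `_get_dl_path` is not shown in this snippet:

import hashlib

url = 'https://example.com/data.zip'   # placeholder URL
url_sha256 = hashlib.sha256(url.encode('utf-8')).hexdigest()
# The download stays under a path keyed by `url_sha256` until a checksum for
# `url` is registered, after which the checksum-based path is used instead.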
Ejemplo n.º 50
0
def run_random_search(max_num_programs, checkpoint_dir, task_eval_fn,
                      timestep_limit):
    """Run uniform random search routine.

  Randomly samples programs from a uniform distribution until either a valid
  program is found, or the maximum NPE is reached. Results are written to disk
  and returned.

  Args:
    max_num_programs: Maximum NPE (number of programs executed). If no solution
        is found after this many programs are tried, the run is stopped and
        considered a failure.
    checkpoint_dir: Where to save state during the run.
    task_eval_fn: Function that maps code string to result containing total
        reward and info about success.
    timestep_limit: Maximum length of code strings.

  Returns:
    ga_lib.GaResult namedtuple instance. This contains the best code and highest
    reward found.
  """
    checkpoint_file = os.path.join(checkpoint_dir, 'random_search.txt')
    num_programs_seen = 0
    found_solution = False
    best_code = ''
    best_reward = 0.0
    if tf.gfile.Exists(checkpoint_file):
        try:
            with tf.gfile.FastGFile(checkpoint_file, 'r') as f:
                lines = list(f)
                num_programs_seen = int(lines[0])
                found_solution = bool(int(lines[1]))
                if found_solution:
                    best_code = lines[2]
                    best_reward = float(lines[3])
        except:  # pylint: disable=bare-except
            pass

    while not found_solution and num_programs_seen < max_num_programs:
        if num_programs_seen % 1000 == 0:
            logging.info('num_programs_seen = %d', num_programs_seen)
            with tf.gfile.FastGFile(checkpoint_file, 'w') as f:
                f.write(str(num_programs_seen) + '\n')
                f.write(str(int(found_solution)) + '\n')

        code = np.random.choice(ga_lib.GENES, timestep_limit).tolist()
        res = task_eval_fn(code)
        found_solution = res.correct
        num_programs_seen += 1

        if found_solution:
            best_code = ''.join(code)
            best_reward = res.reward

    logging.info('num_programs_seen = %d', num_programs_seen)
    logging.info('found solution: %s', found_solution)
    with tf.gfile.FastGFile(checkpoint_file, 'w') as f:
        f.write(str(num_programs_seen) + '\n')
        f.write(str(int(found_solution)) + '\n')
        if found_solution:
            f.write(best_code + '\n')
            f.write(str(best_reward) + '\n')

    return ga_lib.GaResult(population=[],
                           best_code=best_code,
                           reward=best_reward,
                           solution_found=found_solution,
                           generations=num_programs_seen,
                           num_programs=num_programs_seen,
                           max_generations=max_num_programs,
                           max_num_programs=max_num_programs)
def main(arg=None):
    st()
    FLAGS = flags.FLAGS  # pylint: disable=invalid-name,redefined-outer-name
    config = FLAGS
    FLAGS.__dict__['config'] = config

    FLAGS.logdir = FLAGS.logdir.format(name=FLAGS.name)

    logdir = FLAGS.logdir
    logging.info('logdir: %s', logdir)

    if os.path.exists(logdir) and FLAGS.overwrite:
        logging.info('"overwrite" is set to True. Deleting logdir at "%s".',
                     logdir)
        shutil.rmtree(logdir)

    # Build the graph
    with tf.Graph().as_default():

        model_dict = model_config.get(FLAGS)
        data_dict = data_config.get(FLAGS)

        lr = model_dict.lr
        opt = model_dict.opt
        model = model_dict.model
        trainset = data_dict.trainset
        validset = data_dict.validset

        lr = tf.convert_to_tensor(lr)
        tf.summary.scalar('learning_rate', lr)

        # Training setup
        global_step = tf.train.get_or_create_global_step()

        # Optimisation target
        validset = tools.maybe_convert_dataset(validset)
        trainset = tools.maybe_convert_dataset(trainset)
        target, gvs = model.make_target(trainset, opt)

        if gvs is None:
            gvs = opt.compute_gradients(target)

        suppress_inf_and_nans = (config.grad_value_clip > 0
                                 or config.grad_norm_clip > 0)
        report = tools.gradient_summaries(gvs, suppress_inf_and_nans)
        report['target'] = target
        valid_report = dict()

        gvs = tools.clip_gradients(gvs,
                                   value_clip=config.grad_value_clip,
                                   norm_clip=config.grad_norm_clip)

        try:
            report.update(model.make_report(trainset))
            valid_report.update(model.make_report(validset))
        except AttributeError:
            logging.warning('Model %s has no "make_report" method.',
                            str(model))
            raise

        plot_dict, plot_params = None, None
        if config.plot:
            try:
                plot_dict, plot_params = model.make_plot(trainset, 'train')
                valid_plot, valid_params = model.make_plot(validset, 'valid')

                plot_dict.update(valid_plot)
                if plot_params is not None:
                    plot_params.update(valid_params)

            except AttributeError:
                logging.warning('Model %s has no "make_plot" method.',
                                str(model))

        report = tools.scalar_logs(report,
                                   config.ema,
                                   'train',
                                   global_update=config.global_ema_update)
        report['lr'] = lr
        valid_report = tools.scalar_logs(
            valid_report,
            config.ema,
            'valid',
            global_update=config.global_ema_update)

        reports_keys = sorted(report.keys())

        def _format(k):
            if k in ('lr', 'learning_rate'):
                return '.2E'
            return '.3f'

        report_template = ', '.join([
            '{}: {}{}:{}{}'.format(k, '{', k, _format(k), '}')
            for k in reports_keys
        ])

        logging.info('Trainable variables:')
        tools.log_variables_by_scope()

        # inspect gradients
        for g, v in gvs:
            if g is None:
                logging.warning('No gradient for variable: %s.', v.name)

        tools.log_num_params()

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if FLAGS.check_numerics:
            update_ops += [tf.add_check_numerics_ops()]

        with tf.control_dependencies(update_ops):
            train_step = opt.apply_gradients(gvs, global_step=global_step)

        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        with tf.train.SingularMonitoredSession(hooks=create_hooks(
                FLAGS, plot_dict, plot_params),
                                               checkpoint_dir=logdir,
                                               config=sess_config) as sess:

            train_itr, _ = sess.run([global_step, update_ops])
            train_tensors = [global_step, train_step]
            report_tensors = [report, valid_report]
            all_tensors = report_tensors + train_tensors

            while train_itr < config.max_train_steps:
                # print('Doing train itr %d.' % train_itr)

                if train_itr % config.report_loss_steps == 0:
                    report_vals, valid_report_vals, train_itr, _ = sess.run(
                        all_tensors)

                    logging.info('')
                    logging.info('train:')
                    logging.info('#%s: %s', train_itr,
                                 report_template.format(**report_vals))

                    logging.info('valid:')
                    valid_logs = dict(report_vals)
                    valid_logs.update(valid_report_vals)
                    logging.info('#%s: %s', train_itr,
                                 report_template.format(**valid_logs))

                    vals_to_check = list(report_vals.values())
                    if np.isnan(vals_to_check).any():
                        logging.fatal('NaN in reports: %s; breaking...',
                                      report_template.format(**report_vals))

                else:
                    train_itr, _ = sess.run(train_tensors)
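
The `report_template` constructed in this example turns the sorted report keys into a single format string that the training loop fills with `report_vals`. A standalone illustration of what it produces (the keys here are made up):

reports_keys = ['kl', 'lr', 'target']
report_template = ', '.join([
    '{}: {}{}:{}{}'.format(k, '{', k, '.2E' if k == 'lr' else '.3f', '}')
    for k in reports_keys
])
# report_template == 'kl: {kl:.3f}, lr: {lr:.2E}, target: {target:.3f}'
print(report_template.format(kl=0.1234, lr=1e-4, target=5.6789))
# -> kl: 0.123, lr: 1.00E-04, target: 5.679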
Ejemplo n.º 52
0
def run_training(
        config=None,
        tuner=None,
        logdir=None,
        trial_name=None,  # pylint: disable=unused-argument
        is_chief=True):
    """Do all training runs.

  This is the top level training function for the genetic algorithm and random
  search agents. Run this from the main function.

  Args:
    config: config_lib.Config instance containing global config (agent and
        environment hparams). If None, config will be parsed from FLAGS.config.
    tuner: (unused) A tuner instance. Leave as None if not tuning.
    logdir: Parent directory where all data from all runs will be written. If
        None, FLAGS.logdir will be used.
    trial_name: (unused) If tuning, set this to a unique string that identifies
        this trial. If `tuner` is not None, this also must be set.
    is_chief: True if this worker is the chief.

  Returns:
    List of results dicts which were written to disk. Each training run gets a
    results dict. Results dict contains metrics, i.e. (name, value) pairs which
    give information about the training run.

  Raises:
    ValueError: If FLAGS.num_workers does not divide FLAGS.num_repetitions.
    ValueError: If results dicts read from disk contain invalid data.
  """
    if not config:
        # If custom config is not given, get it from flags.
        config = defaults.default_config_with_updates(FLAGS.config)
    if not logdir:
        logdir = FLAGS.logdir

    if FLAGS.num_repetitions % FLAGS.num_workers != 0:
        raise ValueError('Number of workers must divide number of repetitions')
    num_local_reps = FLAGS.num_repetitions // FLAGS.num_workers
    logging.info('Running %d reps globally.', FLAGS.num_repetitions)
    logging.info('This worker will run %d local reps.', num_local_reps)
    if FLAGS.max_npe:
        max_generations = FLAGS.max_npe // config.batch_size
        logging.info('Max samples per rep: %d', FLAGS.max_npe)
        logging.info('Max generations per rep: %d', max_generations)
    else:
        max_generations = sys.maxsize
        logging.info('Running unlimited generations.')

    assert FLAGS.num_workers > 0
    logging.info('Starting experiment. Directory: "%s"', logdir)
    results = results_lib.Results(logdir, FLAGS.task_id)
    local_results_list = results.read_this_shard()
    if local_results_list:
        if local_results_list[0]['max_npe'] != FLAGS.max_npe:
            raise ValueError(
                'Cannot resume training. Max-NPE changed. Was %s, now %s' %
                (local_results_list[0]['max_npe'], FLAGS.max_npe))
        if local_results_list[0][
                'max_global_repetitions'] != FLAGS.num_repetitions:
            raise ValueError(
                'Cannot resume training. Number of repetitions changed. '
                'Was %s, now %s' %
                (local_results_list[0]['max_global_repetitions'],
                 FLAGS.num_repetitions))
    start_rep = len(local_results_list)

    for rep in range(start_rep, num_local_reps):
        global_rep = num_local_reps * FLAGS.task_id + rep
        logging.info('Starting repetition: Rep = %d. (global rep = %d)', rep,
                     global_rep)

        # Saved data for each rep, such as checkpoints, goes into a separate folder.
        run_dir = os.path.join(logdir, 'run_%d' % global_rep)

        if not tf.gfile.IsDirectory(run_dir):
            tf.gfile.MakeDirs(run_dir)
        checkpoint_writer = CheckpointWriter(run_dir,
                                             population_size=config.batch_size)

        data_manager = data.DataManager(config, run_number=global_rep)
        task_eval_fn = ga_lib.make_task_eval_fn(data_manager.rl_task)

        if config.agent.algorithm == 'rand':
            logging.info('Running random search.')
            assert FLAGS.max_npe
            result = run_random_search(FLAGS.max_npe, run_dir, task_eval_fn,
                                       config.timestep_limit)
        else:
            assert config.agent.algorithm == 'ga'
            logging.info('Running genetic algorithm.')
            pop = ga_lib.make_population(ga_lib.random_individual(
                config.timestep_limit),
                                         n=config.batch_size)
            hof = utils.MaxUniquePriorityQueue(2)  # Hall of fame.
            result = ga_lib.ga_loop(pop,
                                    cxpb=config.agent.crossover_rate,
                                    mutpb=config.agent.mutation_rate,
                                    task_eval_fn=task_eval_fn,
                                    ngen=max_generations,
                                    halloffame=hof,
                                    checkpoint_writer=checkpoint_writer)

        logging.info('Finished rep. Num gens: %d', result.generations)

        results_dict = {
            'max_npe': FLAGS.max_npe,
            'batch_size': config.batch_size,
            'max_batches': FLAGS.max_npe // config.batch_size,
            'npe': result.num_programs,
            'max_global_repetitions': FLAGS.num_repetitions,
            'max_local_repetitions': num_local_reps,
            'code_solution': result.best_code if result.solution_found else '',
            'best_reward': result.reward,
            'num_batches': result.generations,
            'found_solution': result.solution_found,
            'task': data_manager.task_name,
            'global_rep': global_rep
        }
        logging.info('results_dict: %s', results_dict)
        results.append(results_dict)

    if is_chief:
        logging.info(
            'Worker is chief. Waiting for all workers to finish so that results '
            'can be reported to the tuner.')

        global_results_list, shard_stats = results.read_all(
            num_shards=FLAGS.num_workers)
        while not all(s.finished for s in shard_stats):
            logging.info(
                'Still waiting on these workers: %s', ', '.join([
                    '%d (%d reps left)' %
                    (i, s.max_local_reps - s.num_local_reps_completed)
                    for i, s in enumerate(shard_stats) if not s.finished
                ]))
            sleep(60)
            global_results_list, shard_stats = results.read_all(
                num_shards=FLAGS.num_workers)

        logging.info(
            '%d results obtained. Chief worker is exiting the experiment.',
            len(global_results_list))

        return global_results_list
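
As the docstring notes, `run_training` is driven largely by flags. A minimal sketch of a main function that might invoke it, assuming the `FLAGS` used in the body above (`logdir`, `task_id`) and an `absl.app` entry point:

from absl import app, flags, logging

def main(argv):
    del argv  # Unused.
    # Each worker runs its local share of repetitions; the chief additionally
    # waits for the other shards and returns the aggregated results list.
    results = run_training(logdir=flags.FLAGS.logdir,
                           is_chief=(flags.FLAGS.task_id == 0))
    if results is not None:
        logging.info('Chief collected %d results.', len(results))

if __name__ == '__main__':
    app.run(main)
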
  def _test_slicing(self, data, dist):
    batch_shape = dist.batch_shape
    slices = data.draw(valid_slices(batch_shape))
    slice_str = 'dist[{}]'.format(', '.join(stringify_slices(slices)))
    logging.info('slice used: %s', slice_str)
    # Make sure the slice string appears in Hypothesis' attempted example log,
    # by drawing and discarding it.
    data.draw(hps.just(slice_str))
    if not slices:  # Nothing further to check.
      return
    sliced_zeros = np.zeros(batch_shape)[slices]
    sliced_dist = dist[slices]

    # Check that slicing modifies batch shape as expected.
    self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

    # Check that sampling of sliced distributions executes.
    try:
      seed = data.draw(
          hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0))
      samples = self.evaluate(dist.sample(seed=maybe_seed(seed)))

      if not sliced_zeros.size:
        # TODO(b/128924708): Fix distributions that fail on degenerate empty
        #     shapes, e.g. Multinomial, DirichletMultinomial, ...
        return

      sliced_samples = self.evaluate(sliced_dist.sample(seed=maybe_seed(seed)))
    except NotImplementedError as e:
      raise
    except tf.errors.UnimplementedError as e:
      if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(e):
        # Some cases can fail with 'Unhandled input dimensions \d+' or
        # 'inputs rank not in [0,6]: \d+'
        return
      raise

    # Come up with the slices for samples (which must also include event dims).
    sample_slices = (
        tuple(slices) if isinstance(slices, collections.Sequence) else
        (slices,))
    if Ellipsis not in sample_slices:
      sample_slices += (Ellipsis,)
    sample_slices += tuple([slice(None)] *
                           tensorshape_util.rank(dist.event_shape))

    # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
    data.draw(hps.just(samples[sample_slices]))

    # Check that sampling a sliced distribution produces the same shape as
    # slicing the samples from the original.
    self.assertAllEqual(samples[sample_slices].shape, sliced_samples.shape)

    # Check that a sliced distribution can compute the log_prob of its own
    # samples (up to numerical validation errors).
    try:
      try:
        lp = self.evaluate(dist.log_prob(samples))
      except tf.errors.InvalidArgumentError:
        # TODO(b/129271256): d.log_prob(d.sample()) should not fail
        #     validate_args checks.
        # We only tolerate this case for the non-sliced dist.
        return
      sliced_lp = self.evaluate(sliced_dist.log_prob(samples[sample_slices]))
    except tf.errors.UnimplementedError as e:
      if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(e):
        # Some cases can fail with 'Unhandled input dimensions \d+' or
        # 'inputs rank not in [0,6]: \d+'
        return
      raise

    # Check that the sliced dist's log_prob agrees with slicing the original's
    # log_prob.
    # TODO(b/128708201): Better numerics for Geometric/Beta?
    # Eigen can return quite different results for packet vs non-packet ops.
    # To work around this, we use a much larger rtol for the last 3
    # (assuming packet size 4) elements.
    packetized_lp = lp[slices].reshape(-1)[:-3]
    packetized_sliced_lp = sliced_lp.reshape(-1)[:-3]
    rtol = (0.1 if any(
        x in dist.name for x in ('Geometric', 'Beta', 'Dirichlet')) else 0.02)
    self.assertAllClose(packetized_lp, packetized_sliced_lp, rtol=rtol)
    possibly_nonpacket_lp = lp[slices].reshape(-1)[-3:]
    possibly_nonpacket_sliced_lp = sliced_lp.reshape(-1)[-3:]
    self.assertAllClose(
        possibly_nonpacket_lp, possibly_nonpacket_sliced_lp,
        rtol=0.4, atol=1e-4)
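
The slice composition used above (the original batch-dims slices, extended with `Ellipsis` and one full slice per event dimension) follows ordinary NumPy indexing rules, so it can be illustrated without any distributions:

import numpy as np

samples = np.zeros([7, 5, 3, 2])   # batch shape [7, 5], event shape [3, 2]
slices = (slice(None, 4),)         # the analogue of dist[:4]
sample_slices = slices + (Ellipsis,) + (slice(None),) * 2  # 2 event dims
assert samples[sample_slices].shape == (4, 5, 3, 2)
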
Ejemplo n.º 54
0
    def join(self, timeout=_DEFAULT_TIMEOUT_SEC):
        """Joins all the processes with timeout.

    If any of the subprocesses does not exit within approximately `timeout`
    seconds of the `join` call, this raises a `SubprocessTimeoutError`.

    Note: At timeout, `join` uses SIGTERM to terminate the subprocesses so that
    they can log their stack traces as they exit. However, this can itself time
    out when the test runs with tsan (thread sanitizer); for test targets that
    rely on this timeout to assert information,
    `MultiProcessRunner.terminate_all()` must be called after `join()` and
    before the test exits, so that the subprocesses are terminated with SIGKILL
    and the data race is avoided.

    Args:
      timeout: optional integer or `None`. If provided as an integer, and not
        all processes report status within roughly `timeout` seconds, a
        `SubprocessTimeoutError` exception will be raised. If `None`, `join`
        never times out.

    Returns:
      A MultiProcessRunnerResult object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains the return
      values from the subprocesses. If `return_output` argument is True at
      `__init__`, `stdout` is available that contains a list of all messages
      from subprocesses' stdout and stderr.

    Raises:
      SubprocessTimeoutError: if not all processes report status approximately
        within `timeout` seconds. When this is raised, a
        `MultiProcessRunnerResult` object can be retrieved by
        `SubprocessTimeoutError`'s mpr_result attribute, which has the same
        structure as above 'Returns' section describes.
      UnexpectedSubprocessExitError: If any of the subprocesses did not exit
        properly (for example, they exit on SIGTERM or SIGKILL signal). When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes. If `max_run_time`
        is not `None`, it is expected that some subprocesses may be
        force-killed when `max_run_time` is up, and this is raised in those
        cases.
      Exception: if there is an Exception propagated from any subprocess. When
        this is raised, a `MultiProcessRunnerResult` object can be retrieved by
        `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the
        same structure as above 'Returns' section describes.
    """
        if timeout and not isinstance(timeout, int):
            raise ValueError('`timeout` must be an integer or `None`.')
        with self._process_lock:
            if self._joined:
                raise ValueError("MultiProcessRunner can't be joined twice.")
            self._joined = True

        self._watchdog_thread.join(timeout)
        if self._watchdog_thread.is_alive():
            # Timeout. Force termination to dump worker processes stack trace.
            with self._process_lock:
                self._auto_restart = False
            logging.error(
                'Timeout when joining for child processes. Terminating...')
            self.terminate_all(sig=signal.SIGTERM)
            # Wait for the processes to terminate by themselves first, so they have a
            # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them.
            self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC)
            if self._watchdog_thread.is_alive():
                logging.error('Timeout when waiting for child processes to '
                              'print stacktrace. Sending SIGKILL...')
                self.terminate_all()
                self._watchdog_thread.join()
            process_statuses = self._get_process_statuses()
            self._reraise_if_subprocess_error(process_statuses)
            raise SubprocessTimeoutError(
                'One or more subprocesses timed out, where timeout was set to {}s. '
                'Please change the `timeout` argument for '
                '`MultiProcessRunner.join()` or `multi_process_runner.run()` '
                'if it should be adjusted.'.format(timeout),
                self._get_mpr_result(process_statuses))

        for (task_type, task_id), p in self._processes.items():
            logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode)

        process_statuses = self._get_process_statuses()
        self._reraise_if_subprocess_error(process_statuses)

        # Checking all the processes that are expected to exit properly.
        for (task_type, task_id), p in self._processes.items():
            # Successfully exiting process has exit code 0. We ignore processes that
            # are terminated.
            assert p.exitcode is not None
            if (p.exitcode > 0
                    and (task_type, task_id) not in self._terminated):
                raise UnexpectedSubprocessExitError(
                    'Subprocess %s-%d exited with exit code %s. See logs for details.'
                    % (task_type, task_id, p.exitcode),
                    self._get_mpr_result(process_statuses))

        logging.info('Joining log reading threads.')
        for thread in self._reading_threads:
            thread.join()
        logging.info('Joined log reading threads.')

        # Clear the alarm.
        signal.alarm(0)

        return self._get_mpr_result(process_statuses)
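
A sketch of how the timeout contract described in the docstring might be exercised, assuming a started `MultiProcessRunner` instance named `runner` (its construction is not shown here) and that the exception classes used above are in scope:

from absl import logging

try:
    result = runner.join(timeout=120)  # seconds
    logging.info('Subprocess return values: %s', result.return_value)
except SubprocessTimeoutError as e:
    # Partial results are still available via the exception.
    logging.error('join() timed out; stdout so far: %s', e.mpr_result.stdout)
    # Under tsan, follow up with terminate_all() so the subprocesses receive
    # SIGKILL before the test exits (see the note in the docstring).
    runner.terminate_all()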
  def testDistribution(self, dist_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
    tf1.set_random_seed(
        data.draw(
            hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
    dist = data.draw(distributions(dist_name=dist_name, enable_vars=True))
    batch_shape = dist.batch_shape
    batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
    dist2 = data.draw(
        distributions(
            dist_name=dist_name,
            batch_shape=batch_shape2,
            event_dim=get_event_dim(dist),
            enable_vars=True))
    logging.info(
        'distribution: %s; parameters used: %s', dist,
        [k for k, v in six.iteritems(dist.parameters) if v is not None])
    self.evaluate([var.initializer for var in dist.variables])

    # Check that the distribution passes Variables through to the accessor
    # properties (without converting them to Tensor or anything like that).
    for k, v in six.iteritems(dist.parameters):
      if not tensor_util.is_ref(v):
        continue
      self.assertIs(getattr(dist, k), v)

    # Check that standard statistics do not read distribution parameters more
    # than once.
    for stat in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ])),
            min_size=3,
            max_size=3)):
      logging.info('%s.%s', dist_name, stat)
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'statistic `{}` of `{}`'.format(stat, dist)):
          getattr(dist, stat)()

      except NotImplementedError:
        pass

    # Check that `sample` doesn't read distribution parameters more than once,
    # and that it produces non-None gradients (if the distribution is fully
    # reparameterized).
    with tf.GradientTape() as tape:
      # TDs do bijector assertions twice (once by distribution.sample, and once
      # by bijector.forward).
      max_permissible = (
          3 if isinstance(dist, tfd.TransformedDistribution) else 2)
      with tfp_hps.assert_no_excessive_var_usage(
          'method `sample` of `{}`'.format(dist),
          max_permissible=max_permissible):
        sample = dist.sample()
    if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
      grads = tape.gradient(sample, dist.variables)
      for grad, var in zip(grads, dist.variables):
        var_name = var.name.rstrip('_0123456789:')
        if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
          continue
        if grad is None:
          raise AssertionError(
              'Missing sample -> {} grad for distribution {}'.format(
                  var_name, dist_name))

    # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
    # own samples.  Also, to relax conversion counts for KL (might do >2 w/
    # validate_args).
    dist = dist.copy(validate_args=False)
    dist2 = dist2.copy(validate_args=False)

    # Test that KL divergence reads distribution parameters at most once, and
    # that it produces non-None gradients.
    try:
      for d1, d2 in (dist, dist2), (dist2, dist):
        with tf.GradientTape() as tape:
          with tfp_hps.assert_no_excessive_var_usage(
              '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                  d1, d1.variables, d2, d2.variables),
              max_permissible=1):  # No validation => 1 convert per var.
            kl = d1.kl_divergence(d2)
        wrt_vars = list(d1.variables) + list(d2.variables)
        grads = tape.gradient(kl, wrt_vars)
        for grad, var in zip(grads, wrt_vars):
          if grad is None and dist_name not in NO_KL_PARAM_GRADS:
            raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                                 '{} vars: {}\n{} vars: {}'.format(
                                     d1, d2, var, d1, d1.variables, d2,
                                     d2.variables))
    except NotImplementedError:
      pass

    # Test that log_prob produces non-None gradients, except for distributions
    # on the NO_LOG_PROB_PARAM_GRADS blacklist.
    if dist_name not in NO_LOG_PROB_PARAM_GRADS:
      with tf.GradientTape() as tape:
        lp = dist.log_prob(tf.stop_gradient(sample))
      grads = tape.gradient(lp, dist.variables)
      for grad, var in zip(grads, dist.variables):
        if grad is None:
          raise AssertionError(
              'Missing log_prob -> {} grad for distribution {}'.format(
                  var, dist_name))

    # Test that all forms of probability evaluation avoid reading distribution
    # parameters more than once.
    for evaluative in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ])),
            min_size=3,
            max_size=3)):
      logging.info('%s.%s', dist_name, evaluative)
      try:
        # No validation => 1 convert. But for TD we allow 2:
        # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
        max_permissible = (
            2 if isinstance(dist, tfd.TransformedDistribution) else 1)
        with tfp_hps.assert_no_excessive_var_usage(
            'evaluative `{}` of `{}`'.format(evaluative, dist),
            max_permissible=max_permissible):
          getattr(dist, evaluative)(sample)
      except NotImplementedError:
        pass
Ejemplo n.º 56
0
def main(unused_argv):
  logging.info("Loading %s", FLAGS.game_name)
  game = pyspiel.load_game(FLAGS.game_name)

  deep_cfr_solver = deep_cfr.DeepCFRSolver(
      game,
      policy_network_layers=(32, 32),
      advantage_network_layers=(16, 16),
      num_iterations=FLAGS.num_iterations,
      num_traversals=FLAGS.num_traversals,
      learning_rate=1e-3,
      batch_size_advantage=None,
      batch_size_strategy=None,
      memory_capacity=int(1e7))

  _, advantage_losses, policy_loss = deep_cfr_solver.solve()
  for player, losses in advantage_losses.items():
    logging.info("Advantage for player %d: %s", player,
                 losses[:2] + ["..."] + losses[-2:])
    logging.info("Advantage Buffer Size for player %s: '%s'", player,
                 len(deep_cfr_solver.advantage_buffers[player]))
  logging.info("Strategy Buffer Size: '%s'",
               len(deep_cfr_solver.strategy_buffer))
  logging.info("Final policy loss: '%s'", policy_loss)

  average_policy = policy.tabular_policy_from_callable(
      game, deep_cfr_solver.action_probabilities)
  pyspiel_policy = policy.python_policy_to_pyspiel_policy(average_policy)
  conv = pyspiel.nash_conv(game, pyspiel_policy)
  logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)

  average_policy_values = expected_game_score.policy_value(
      game.new_initial_state(), [average_policy] * 2)
  logging.info("Computed player 0 value: %.2f (expected: %.2f).",
               average_policy_values[0], -1 / 18)
  logging.info("Computed player 1 value: %.2f (expected: %.2f).",
               average_policy_values[1], 1 / 18)
def transformed_distributions(draw,
                              batch_shape=None,
                              event_dim=None,
                              enable_vars=False,
                              depth=None):
  """Strategy for drawing `TransformedDistribution`s.

  The transforming bijector is drawn from the
  `bijectors.hypothesis_testlib.unconstrained_bijectors` strategy.

  The underlying distribution is drawn from the `distributions` strategy, except
  that it must be compatible with the bijector according to
  `bijectors.hypothesis_testlib.distribution_filter_for` (these generally check
  that vector bijectors are not combined with scalar distributions, etc).

  Args:
    draw: Hypothesis MacGuffin.  Supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `TransformedDistribution`.  The underlying distribution will sometimes
      have the same `batch_shape`, and sometimes have scalar batch shape.
      Hypothesis will pick a `batch_shape` if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` giving maximum nesting depth of compound Distributions.

  Returns:
    dists: A strategy for drawing `TransformedDistribution`s with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
  if depth is None:
    depth = draw(depths())

  bijector = draw(bijector_hps.unconstrained_bijectors())
  logging.info('TD bijector: %s', bijector)
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())
  underlying_batch_shape = batch_shape
  batch_shape_arg = None
  if draw(hps.booleans()):
    # Use batch_shape overrides.
    underlying_batch_shape = tf.TensorShape([])  # scalar underlying batch
    batch_shape_arg = batch_shape
  underlyings = distributions(
      batch_shape=underlying_batch_shape,
      event_dim=event_dim,
      enable_vars=enable_vars,
      depth=depth - 1).filter(
          bijector_hps.distribution_filter_for(bijector))
  to_transform = draw(underlyings)
  logging.info(
      'TD underlying distribution: %s; parameters used: %s', to_transform,
      [k for k, v in six.iteritems(to_transform.parameters) if v is not None])
  # TODO(bjp): Add test coverage for `event_shape` argument of
  # `TransformedDistribution`.
  result_dist = tfd.TransformedDistribution(
      bijector=bijector,
      distribution=to_transform,
      batch_shape=batch_shape_arg,
      validate_args=True)
  if batch_shape != result_dist.batch_shape:
    msg = ('TransformedDistribution strategy generated a bad batch shape '
           'for {}, should have been {}.').format(result_dist, batch_shape)
    raise AssertionError(msg)
  return result_dist
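
A sketch of how this composite strategy might be consumed in a Hypothesis test, assuming `transformed_distributions` is decorated with `@hps.composite` (as its `draw` argument and docstring indicate); the property checked, that a single sample carries the drawn batch shape, is only illustrative:

from hypothesis import given
from hypothesis import strategies as hps

@given(hps.data())
def test_td_sample_has_batch_shape(data):
  dist = data.draw(transformed_distributions(enable_vars=False))
  sample = dist.sample()
  # A single draw has shape batch_shape + event_shape.
  assert sample.shape[:dist.batch_shape.ndims] == dist.batch_shape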
Ejemplo n.º 58
0
    def _process_ns(self, ns):
        if self._filters:
            if ns.total_time > self._filters['max_total_time']:
                logging.info('Skipping %s: total_time=%f', ns.id,
                             ns.total_time)
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'filtered-too-long').inc()
                return
            if len(ns.notes) > self._filters['max_num_notes']:
                logging.info('Skipping %s: num_notes=%d', ns.id, len(ns.notes))
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'filtered-too-many-notes').inc()
                return

            try:
                qns = note_seq.quantize_note_sequence(ns, steps_per_quarter=16)
            except (note_seq.BadTimeSignatureError,
                    note_seq.NonIntegerStepsPerBarError,
                    note_seq.NegativeTimeError):
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'quantize-failed').inc()
                return

            vels = set()
            metric_positions = set()
            drums_only = True
            for note in qns.notes:
                drums_only &= note.is_drum
                if ((self._filters['is_drum'] is None
                     or note.is_drum == self._filters['is_drum'])
                        and note.velocity > 0):
                    vels.add(note.velocity)
                    metric_positions.add(note.quantized_start_step % 16)

            if len(vels) < self._filters['min_velocities']:
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'filtered-min-velocities').inc()
                return
            if len(metric_positions) < self._filters['min_metric_positions']:
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'filtered-min-metric-positions').inc()
                return
            if self._filters['drums_only'] and not drums_only:
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'filtered-drums-only').inc()
                return

        beam_metrics.counter('ExtractExamplesDoFn',
                             'unfiltered-sequences').inc()
        logging.info('Converting %s to tensors', ns.id)
        extracted_examples = self._config.data_converter.to_tensors(ns)
        if not extracted_examples.outputs:
            beam_metrics.counter('ExtractExamplesDoFn',
                                 'empty-extractions').inc()
            return
        beam_metrics.counter('ExtractExamplesDoFn', 'extracted-examples').inc(
            len(extracted_examples.outputs))
        for _, outputs, controls, _ in zip(*extracted_examples):
            if controls.size:
                example_ns = self._config.data_converter.from_tensors(
                    [outputs], [controls])[0]
            else:
                example_ns = self._config.data_converter.from_tensors(
                    [outputs])[0]
            # Try to re-encode.
            # TODO(adarob): For now we filter and count examples that cannot be
            # re-extracted, but ultimately the converter should filter these or
            # avoid producing them altogether.
            reextracted_examples = self._config.data_converter.to_tensors(
                example_ns).inputs
            assert len(reextracted_examples) <= 1
            if not reextracted_examples:
                logging.warning(
                    'Extracted example NoteSequence does not reproduce example. '
                    'Skipping: %s', example_ns)
                beam_metrics.counter('ExtractExamplesDoFn',
                                     'empty-reextraction').inc()
                continue
            # Extra checks if the code returns multiple segments.
            # TODO(fjord): should probably make this recursive for cases with more
            # than 1 level of hierarchy.
            if isinstance(outputs, list):
                if len(outputs) != len(reextracted_examples[0]):
                    logging.warning(
                        'Re-extracted example tensor has different number of segments. '
                        'ID: %s. original %d, reextracted %d. Skipping.',
                        ns.id, len(outputs), len(reextracted_examples[0]))
                    beam_metrics.counter('ExtractExamplesDoFn',
                                         'different-reextraction-count').inc()
                    continue
                for i in range(len(outputs)):
                    if not np.array_equal(reextracted_examples[0][i],
                                          outputs[i]):
                        logging.warning(
                            'Re-extracted example tensor does not equal original example. '
                            'ID: %s. Index %d. NoteSequence: %s', ns.id, i,
                            example_ns)
                        beam_metrics.counter('ExtractExamplesDoFn',
                                             'different-reextraction').inc()
            yield example_ns, ns.id
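
The counter names in this example suggest the method lives on a Beam `DoFn` called `ExtractExamplesDoFn`. A hedged sketch of how such a DoFn might be wired into a pipeline; the constructor arguments, `input_tfrecord`, and `output_path` are assumptions, not part of the original code:

import apache_beam as beam

with beam.Pipeline() as pipeline:
    _ = (pipeline
         | 'ReadNoteSequences' >> beam.io.ReadFromTFRecord(
             input_tfrecord,
             coder=beam.coders.ProtoCoder(note_seq.NoteSequence))
         | 'ExtractExamples' >> beam.ParDo(
             ExtractExamplesDoFn(config, filters))  # Hypothetical constructor.
         | 'WriteExamples' >> beam.io.WriteToTFRecord(output_path))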
Ejemplo n.º 59
0
def run(iterative_process: tff.templates.IterativeProcess,
        client_datasets_fn: Callable[[int], List[tf.data.Dataset]],
        validation_fn: Callable[[Any, int], Dict[str, float]],
        total_rounds: int,
        experiment_name: str,
        test_fn: Optional[Callable[[Any], Dict[str, float]]] = None,
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        rounds_per_eval: Optional[int] = 1,
        rounds_per_checkpoint: Optional[int] = 50,
        rounds_per_profile: Optional[int] = 0):
    """Runs federated training for a given `tff.templates.IterativeProcess`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process: A `tff.templates.IterativeProcess` instance to run.
    client_datasets_fn: Function accepting an integer argument (the round
      number) and returning a list of client datasets to use as federated data
      for that round.
    validation_fn: A callable accepting a `tff.learning.ModelWeights` and the
      current round number, and returning a dict of evaluation metrics. Used to
      compute validation metrics throughout the training process.
    total_rounds: The number of federated training rounds to perform.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    test_fn: An optional callable accepting a `tff.learning.ModelWeights` and
      returning a dict of test set metrics. Used to compute test metrics at the
      end of the training process.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    rounds_per_eval: How often to compute validation metrics.
    rounds_per_checkpoint: How often to checkpoint the iterative process state.
      If you expect the job to restart frequently, this should be small. If no
      interruptions are expected, this can be made larger.
    rounds_per_profile: Experimental setting. If set to a value greater than 0,
      this dictates how often a TensorFlow profiler is run.

  Returns:
    The final `state` of the iterative process after training.
  """
    _check_iterative_process_compatibility(iterative_process)
    if not callable(client_datasets_fn):
        raise TypeError('client_datasets_fn should be callable.')
    if not callable(validation_fn):
        raise TypeError('validation_fn should be callable.')
    if test_fn is not None and not callable(test_fn):
        raise TypeError('test_fn should be callable.')

    logging.info('Starting iterative_process training loop...')
    initial_state = iterative_process.initialize()

    checkpoint_mngr, metrics_mngr, tb_mngr, profiler = _setup_outputs(
        root_output_dir, experiment_name, rounds_per_profile)

    logging.info('Asking checkpoint manager to load checkpoint.')
    state, round_num = checkpoint_mngr.load_latest_checkpoint(initial_state)

    if state is None:
        logging.info('Initializing experiment from scratch.')
        state = initial_state
        round_num = 0
        metrics_mngr.clear_all_rounds()
    else:
        logging.info('Restarted from checkpoint round %d', round_num)
        round_num += 1  # Increment to avoid overwriting current checkpoint
        metrics_mngr.clear_rounds_after(last_valid_round_num=round_num - 1)

    current_model = iterative_process.get_model_weights(state)

    loop_start_time = time.time()
    loop_start_round = round_num
    while round_num < total_rounds:
        data_prep_start_time = time.time()
        federated_train_data = client_datasets_fn(round_num)
        train_metrics = {
            'prepare_datasets_secs': time.time() - data_prep_start_time
        }

        training_start_time = time.time()
        prev_model = current_model

        # TODO(b/145604851): This try/except is used to circumvent ambiguous TF
        # errors during training, and should be removed once the root cause is
        # determined (and possibly fixed).
        try:
            with profiler(round_num):
                state, round_metrics = iterative_process.next(
                    state, federated_train_data)
        except (tf.errors.FailedPreconditionError, tf.errors.NotFoundError,
                tf.errors.InternalError) as e:
            logging.warning(
                'Caught %s exception while running round %d:\n\t%s', type(e),
                round_num, e)
            continue  # restart the loop without incrementing the round number

        current_model = iterative_process.get_model_weights(state)
        train_metrics['training_secs'] = time.time() - training_start_time
        train_metrics['model_delta_l2_norm'] = _compute_numpy_l2_difference(
            current_model, prev_model)
        train_metrics.update(round_metrics)

        loop_time = time.time() - loop_start_time
        loop_rounds = (round_num - loop_start_round + 1)
        logging.info('Round {:2d}, {:.2f}s per round on average.'.format(
            round_num, loop_time / loop_rounds))

        if (round_num % rounds_per_checkpoint == 0
                or round_num == total_rounds - 1):
            save_checkpoint_start_time = time.time()
            checkpoint_mngr.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        metrics = {'train': train_metrics}

        if round_num % rounds_per_eval == 0:
            # Compute validation metrics
            evaluate_start_time = time.time()
            validation_metrics = validation_fn(current_model, round_num)
            validation_metrics['evaluate_secs'] = time.time(
            ) - evaluate_start_time
            metrics['eval'] = validation_metrics

        _write_metrics(metrics_mngr, tb_mngr, metrics, round_num)
        round_num += 1

    # Final metrics evaluation once the training has completed
    metrics = {}

    # Validation metrics
    evaluate_start_time = time.time()
    validation_metrics = validation_fn(current_model, round_num)
    validation_metrics['evaluate_secs'] = time.time() - evaluate_start_time
    metrics['eval'] = validation_metrics

    # Test set metrics
    if test_fn:
        test_start_time = time.time()
        test_metrics = test_fn(current_model)
        test_metrics['evaluate_secs'] = time.time() - test_start_time
        metrics['test'] = test_metrics
    _write_metrics(metrics_mngr, tb_mngr, metrics, total_rounds)

    return state
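
A sketch of the pieces a caller would wire together for `run`, assuming an iterative process built elsewhere (for example with `tff.learning`) and hypothetical helpers `make_client_datasets` and `evaluate_model` that are not part of the original code:

def client_datasets_fn(round_num):
    # Hypothetical helper: sample a fixed number of client datasets per round.
    return make_client_datasets(round_num, clients_per_round=10)

def validation_fn(model_weights, round_num):
    # Hypothetical helper: return a dict of metrics, e.g. {'accuracy': 0.91}.
    return evaluate_model(model_weights, split='validation')

final_state = run(iterative_process,
                  client_datasets_fn=client_datasets_fn,
                  validation_fn=validation_fn,
                  total_rounds=100,
                  experiment_name='fed_avg_baseline',
                  rounds_per_eval=5)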
Ejemplo n.º 60
0
def build_optimizer(
        optimizer_name: Text,
        base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
        params: Dict[Text, Any]):
    """Build the optimizer based on name.

  Args:
    optimizer_name: String representation of the optimizer name. Examples:
      sgd, momentum, rmsprop.
    base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
      base learning rate.
    params: String -> Any dictionary representing the optimizer params.
      This should contain optimizer specific parameters such as
      `base_learning_rate`, `decay`, etc.

  Returns:
    A tf.keras.Optimizer.

  Raises:
    ValueError: If the provided optimizer_name is not supported.
  """
    optimizer_name = optimizer_name.lower()
    logging.info('Building %s optimizer with params %s', optimizer_name,
                 params)

    if optimizer_name == 'sgd':
        logging.info('Using SGD optimizer')
        nesterov = params.get('nesterov', False)
        optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
                                            nesterov=nesterov)
    elif optimizer_name == 'momentum':
        logging.info('Using momentum optimizer')
        nesterov = params.get('nesterov', False)
        optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
                                            momentum=params['momentum'],
                                            nesterov=nesterov)
    elif optimizer_name == 'rmsprop':
        logging.info('Using RMSProp')
        rho = params.get('decay', None) or params.get('rho', 0.9)
        momentum = params.get('momentum', 0.9)
        epsilon = params.get('epsilon', 1e-07)
        optimizer = tf.keras.optimizers.RMSprop(
            learning_rate=base_learning_rate,
            rho=rho,
            momentum=momentum,
            epsilon=epsilon)
    elif optimizer_name == 'adam':
        logging.info('Using Adam')
        beta_1 = params.get('beta_1', 0.9)
        beta_2 = params.get('beta_2', 0.999)
        epsilon = params.get('epsilon', 1e-07)
        optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate,
                                             beta_1=beta_1,
                                             beta_2=beta_2,
                                             epsilon=epsilon)
    elif optimizer_name == 'adamw':
        logging.info('Using AdamW')
        weight_decay = params.get('weight_decay', 0.01)
        beta_1 = params.get('beta_1', 0.9)
        beta_2 = params.get('beta_2', 0.999)
        epsilon = params.get('epsilon', 1e-07)
        optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay,
                                         learning_rate=base_learning_rate,
                                         beta_1=beta_1,
                                         beta_2=beta_2,
                                         epsilon=epsilon)
    else:
        raise ValueError('Unknown optimizer %s' % optimizer_name)

    moving_average_decay = params.get('moving_average_decay', 0.)
    if moving_average_decay is not None and moving_average_decay > 0.:
        logging.info('Including moving average decay.')
        optimizer = tfa.optimizers.MovingAverage(
            optimizer,
            average_decay=params['moving_average_decay'],
            num_updates=None)
    if params.get('lookahead', None):
        logging.info('Using lookahead optimizer.')
        optimizer = tfa.optimizers.Lookahead(optimizer)
    return optimizer
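
A sketch of calling `build_optimizer` with a Keras learning-rate schedule; the parameter values are illustrative only:

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)
optimizer = build_optimizer(
    optimizer_name='rmsprop',
    base_learning_rate=schedule,
    params={'rho': 0.9, 'momentum': 0.9, 'epsilon': 1e-07,
            'moving_average_decay': 0.999, 'lookahead': True})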