示例#1
0
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

    Sorts the checkpoint files naturally, returning the highest-valued
    file, e.g.:
      ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
      ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
      ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

    Args:
      ckpt_dir: str: directory of checkpoints to restore from.
      target: matching object to rebuild via deserialized state-dict.
      step: int: step number to load, or None to load the latest. Step 0 is a
        valid explicit step.
      prefix: str: name prefix of checkpoint files.

    Returns:
      Restored `target` updated from checkpoint file, or if no step specified
      and no checkpoint files present, returns the passed-in `target`
      unchanged.

    Raises:
      ValueError: if an explicit `step` is given but its file does not exist.
    """
    if step is not None:
        # Compare against None explicitly: `if step:` would treat step 0 as
        # "load latest" and silently restore a different checkpoint.
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        # The temporary save file also matches the glob; it is not a
        # completed checkpoint, so exclude it.
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
示例#2
0
    def __init__(self, path, dataset, *args, **kwargs):
        """Build the vocabulary and load/encode the splits of a corpus.

        Args:
          path: directory holding the corpus files.
          dataset: one of "ptb", "wt2", "wt103", "enwik8", "text8" or "lm1b".
            Any other value only builds the vocabulary, sets no splits and
            sets empty cutoffs.
          *args, **kwargs: forwarded unchanged to `Vocab`.
        """
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        def in_dir(fname):
            return os.path.join(path, fname)

        # Phase 1: feed files to the vocabulary counter.
        if dataset in ("ptb", "wt2", "enwik8", "text8"):
            for fname in ("train.txt", "valid.txt", "test.txt"):
                self.vocab.count_file(in_dir(fname))
        elif dataset == "wt103":
            # Only the training split is counted for wt103.
            self.vocab.count_file(in_dir("train.txt"))
        elif dataset == "lm1b":
            # Training shards are not counted here; the vocab is expected to
            # be loaded from file when build_vocab() is called.
            train_paths = glob(os.path.join(
                path, "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled", "news.en-*"))

        self.vocab.build_vocab()

        # Phase 2: encode the splits.
        if dataset in ("ptb", "wt2", "wt103", "enwik8", "text8"):
            encode_kwargs = {"ordered": True}
            if dataset in ("enwik8", "text8"):
                # These two corpora are encoded without an end-of-sentence token.
                encode_kwargs["add_eos"] = False
            self.train = self.vocab.encode_file(in_dir("train.txt"),
                                                **encode_kwargs)
            self.valid = self.vocab.encode_file(in_dir("valid.txt"),
                                                **encode_kwargs)
            self.test = self.vocab.encode_file(in_dir("test.txt"),
                                               **encode_kwargs)
        elif dataset == "lm1b":
            # Training data stays as a list of shard paths; the "test" split
            # is encoded from the validation file.
            self.train = train_paths
            valid_path = in_dir("valid.txt")
            self.valid = self.vocab.encode_file(valid_path,
                                                ordered=True,
                                                add_double_eos=True)
            self.test = self.vocab.encode_file(valid_path,
                                               ordered=True,
                                               add_double_eos=True)

        # Adaptive-softmax style cluster boundaries, closed by the vocab size.
        if dataset == "wt103":
            head = [0, 20000, 40000, 200000]
        elif dataset == "lm1b":
            head = [0, 60000, 100000, 640000]
        else:
            head = None
        self.cutoffs = [] if head is None else head + [len(self.vocab)]
示例#3
0
def maybe_restore_params(output_dir, policy_and_value_net_params, state):
    """Maybe restore the params from the checkpoint dir.

    `model-??????.pkl` files are tried from the lexicographically largest name
    down. A file whose unpickling raises `EOFError` (e.g. truncated) is logged
    and the next older one is tried.

    Args:
      output_dir: Directory where saved model checkpoints are stored.
      policy_and_value_net_params: Default params, returned if no model is found.
      state: Default policy state, returned if no model is found.

    Returns:
      tuple (restored (bool), params, state, iter (int), opt_step (int)) where
      iter is the epoch parsed from the restored file name (0 if restored is
      False), and opt_step is the total optimization step (sum of all
      optimization steps made up to the current epoch).
    """
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    for model_file in sorted(model_files, reverse=True):
        logging.info("Trying to restore model from %s", model_file)
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # this on trusted checkpoint directories.
            with gfile.GFile(model_file, "rb") as f:
                loaded_params, loaded_state, total_opt_step = pickle.load(f)
        except EOFError as e:
            logging.error("Unable to load model from: %s with %s", model_file,
                          e)
            # Try an older version.
            continue
        model_file_basename = os.path.basename(model_file)  # model-??????.pkl
        # filter() returns an iterator in Python 3, so the digits must be
        # joined back into a string before int() can parse them.
        epoch = int("".join(filter(str.isdigit, model_file_basename)))
        return True, loaded_params, loaded_state, epoch, total_opt_step
    return False, policy_and_value_net_params, state, 0, 0
    def load_weights(self, base_fn):
        """Load the newest `<base_fn>_*.npy` file into every layer.

        The file chosen is the lexicographically largest match. This equals
        the most recent one only if the suffixes sort that way (e.g.
        zero-padded steps); TODO confirm the naming scheme. Its flat array is:
          - split by `self.all_weights_flat_sizes`;
          - reshaped to `self.all_weights_flat_shapes`;
          - re-nested as `self.all_weights_structure`.
        The result is then passed to `load_weights` on `self.layers`,
        `self.loss_layer` and, if present, the values of `self.shared_params`.
        """
        pattern = base_fn + "_*.npy"
        newest_first = sorted(gfile.glob(pattern), reverse=True)
        assert len(newest_first) > 0, "No files matching {}".format(pattern)
        chosen = newest_first[0]

        with gfile.GFile(chosen, "rb") as fin:
            flat = np.load(fin)

        print(flat.shape, self.all_weights_flat_sizes)
        shaped_pieces = []
        for piece, shape in zip(tf.split(flat, self.all_weights_flat_sizes),
                                self.all_weights_flat_shapes):
            shaped_pieces.append(tf.reshape(piece, shape))
        nested = tf.nest.pack_sequence_as(self.all_weights_structure,
                                          shaped_pieces)

        targets = self.layers + [self.loss_layer]
        if self.shared_params is not None:
            targets += list(self.shared_params.values())
        for layer, layer_weights in zip(targets, nested):
            layer.load_weights(layer_weights)
示例#5
0
def maybe_restore_params(output_dir, policy_and_value_net_params):
    """Maybe restore the params from the checkpoint dir.

    `model-??????.pkl` files are tried from the lexicographically largest name
    down. A file whose unpickling raises `EOFError` is logged and the next
    older one is tried.

    Args:
      output_dir: Directory where saved model checkpoints are stored.
      policy_and_value_net_params: Default params, returned if no model is found.

    Returns:
      triple (restored (bool), params, iter (int)) where iter is the epoch
      parsed from the restored file name, 0 if restored is False.
    """
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    for model_file in sorted(model_files, reverse=True):
        logging.info("Trying to restore model from %s", model_file)
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # this on trusted checkpoint directories.
            with gfile.GFile(model_file, "rb") as f:
                loaded_params = pickle.load(f)
        except EOFError as e:
            logging.error("Unable to load model from: %s with %s", model_file,
                          e)
            # Try an older version.
            continue
        model_file_basename = os.path.basename(model_file)  # model-??????.pkl
        # filter() returns an iterator in Python 3; join the digits first.
        epoch = int("".join(filter(str.isdigit, model_file_basename)))
        return True, loaded_params, epoch
    return False, policy_and_value_net_params, 0
示例#6
0
def main(argv):
    """Build the SMU SQLite database from TFRecords and an optional SMILES CSV.

    Args:
      argv: command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.get_absl_handler().use_absl_log_file()

    logging.info('Opening %s', FLAGS.output_sqlite)
    db = smu_sqlite.SMUSQLite(FLAGS.output_sqlite, 'c')

    if FLAGS.bond_topology_csv:
        logging.info('Starting smiles to btid inserts')
        # Keep the CSV open only while it is parsed and inserted, then close
        # it instead of leaking the handle until process exit.
        with open(FLAGS.bond_topology_csv) as csv_file:
            smiles_id_dict = smu_utils_lib.smiles_id_dict_from_csv(csv_file)
            db.bulk_insert_smiles(smiles_id_dict.items())
        logging.info('Finished smiles to btid inserts')
    else:
        logging.info('Skipping smiles inserts')

    logging.info('Starting main inserts')
    dataset = tf.data.TFRecordDataset(gfile.glob(FLAGS.input_tfrecord))
    db.bulk_insert((raw.numpy() for raw in dataset), batch_size=10000)

    logging.info('Starting vacuuming')
    db.vacuum()
    logging.info('Vacuuming finished')
示例#7
0
def iterate_checkpoints(checkpoint_dir, min_global_step, max_global_step):
  """Iterates over all checkpoints whose step lies in [min, max).

  Args:
    checkpoint_dir: directory containing `ckpt_<step>` files.
    min_global_step: inclusive lower bound on the step, or None to yield every
      checkpoint (the upper bound is then ignored).
    max_global_step: exclusive upper bound on the step, or None for no upper
      bound.

  Yields:
    (path, step) tuples in the order returned by `gfile.glob`.
  """
  for checkpoint_path in gfile.glob(os.path.join(checkpoint_dir, 'ckpt_*')):
    step = int(checkpoint_path.split('_')[-1])
    in_range = min_global_step is None or (
        min_global_step <= step and
        (max_global_step is None or step < max_global_step))
    if in_range:
      # The glob pattern already starts with checkpoint_dir, so the returned
      # path is complete; joining it onto checkpoint_dir again would duplicate
      # a relative (or gs://) directory prefix.
      yield checkpoint_path, step
示例#8
0
def merge_shared_tsvs(filename):
  """Merge sharded `<filename>-*-of-*` tsv files into one file at `filename`.

  Shards are read in sorted filename order so the merged output is
  deterministic; glob itself makes no ordering promise.
  """
  all_examples = []
  for output_file in sorted(gfile.glob("%s-*-of-*" % filename)):
    all_examples.extend(read_tsv(output_file))
  write_tsv(all_examples, filename)
def main(argv):
    """Bulk-load every record from FLAGS.input_tfrecord into a SQLite DB.

    Args:
      argv: command-line arguments; anything beyond the program name is
        rejected with app.UsageError.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.info('Opening %s', FLAGS.output_sqlite)
    db = smu_sqlite.SMUSQLite(FLAGS.output_sqlite, 'c')

    record_files = gfile.glob(FLAGS.input_tfrecord)
    records = tf.data.TFRecordDataset(record_files)
    # Stream serialized bytes lazily into the batched insert.
    serialized = (record.numpy() for record in records)
    db.bulk_insert(serialized, batch_size=10000)
示例#10
0
def get_checkpoints_in_range(checkpoint_dir, lower_bound, upper_bound):
  """Get checkpoint paths in step range [lower_bound, upper_bound].

  Args:
    checkpoint_dir: directory containing `ckpt_<step>` files.
    lower_bound: inclusive minimum step.
    upper_bound: inclusive maximum step.

  Returns:
    List of matching checkpoint paths, in `gfile.glob` order.
  """
  pattern = os.path.join(checkpoint_dir, 'ckpt_*')
  # Glob results already start with checkpoint_dir, so they are used as-is;
  # re-joining them would duplicate a relative (or gs://) directory prefix.
  return [
      ckpt_path for ckpt_path in gfile.glob(pattern)
      if lower_bound <= int(ckpt_path.split('_')[-1]) <= upper_bound
  ]
示例#11
0
    def test_delete_old_checkpoints(self):
        """Saving twice with max_to_keep=1 leaves exactly one file behind."""
        for step, (global_step, epochs) in enumerate([(5, 4), (10, 8)]):
            state = dict(params=self.params, global_step=global_step,
                         completed_epochs=epochs)
            checkpoint.save_checkpoint(self.test_dir, step, state,
                                       max_to_keep=1)
        remaining = gfile.glob(os.path.join(self.test_dir, '*'))
        self.assertLen(remaining, 1)
示例#12
0
def write_backgrounds_csv(new_dataset_dir, backgrounds_dir):
    """Write `backgrounds.csv` mapping an integer index to each background name.

    The CSV has the header `int,label`. Each row holds the entry's position
    in `gfile.glob` order and its basename (the text after the last '/').

    Args:
      new_dataset_dir: directory the CSV is written into.
      backgrounds_dir: directory whose direct entries are listed.
    """
    bg_paths = gfile.glob(path.join(backgrounds_dir, '*'))
    # NOTE(review): indices follow unsorted glob order; confirm this matches
    # how background images are indexed wherever this CSV is consumed.
    bg_filenames = [p.split('/')[-1] for p in bg_paths]
    csv_filepath = path.join(new_dataset_dir, 'backgrounds.csv')
    # newline='' lets the csv module control line endings, as the csv docs
    # require, avoiding blank rows on platforms that translate newlines.
    with open(csv_filepath, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['int', 'label'])
        writer.writerows(enumerate(bg_filenames))
示例#13
0
def merge_predictions(examples, filename):
    """Merge sharded prediction files into one text file ordered like `examples`.

    Reads every `<filename>-*-of-*` TSV of (source, prediction) rows, in glob
    order, with later rows overriding earlier ones for the same source. It
    then writes the prediction for each example's source (`example[0]`) to
    `filename`. A source absent from every shard raises KeyError.
    """
    lookup = {}
    for shard_path in gfile.glob("%s-*-of-*" % filename):
        lookup.update(
            (src, pred) for src, pred in tsv_utils.read_tsv(shard_path))
    ordered = [lookup[example[0]] for example in examples]
    txt_utils.write_txt(ordered, filename)
示例#14
0
 def save(self):
   """Save the agent parameters and remove previously saved model files.

   Writes `self._policy_and_value_net_params` to
   `<output_dir>/model-<epoch:06d>.pkl`. It then deletes every other
   `model-??????.pkl` file that existed before the write, resets
   `self._n_trajectories_done` and records `self._last_saved_at`.
   """
   logging.vlog(1, "Epoch [% 6d] saving model.", self._epoch)
   old_model_files = gfile.glob(
       os.path.join(self._output_dir, "model-??????.pkl"))
   params_file = os.path.join(self._output_dir, "model-%06d.pkl" % self._epoch)
   with gfile.GFile(params_file, "wb") as f:
     pickle.dump(self._policy_and_value_net_params, f)
   # Remove the old model files, but never the one just written. If this
   # epoch was saved before, its name is also in `old_model_files`, and
   # deleting it would leave no checkpoint at all.
   new_basename = os.path.basename(params_file)
   for old_path in old_model_files:
     if os.path.basename(old_path) != new_basename:
       gfile.remove(old_path)
   # Reset this number.
   self._n_trajectories_done = 0
   self._last_saved_at = self._epoch
示例#15
0
    def _load_foregrounds(self, foregrounds_dir):
        """Loads foregrounds from a directory.

        Args:
          foregrounds_dir: path to directory containing foregrounds.
            Directory of the form `foregrounds_dir`/$OBJECT_CLASS/$FILE_NAME.

        Produces:
          self.fg_classes: sorted list of foreground object class names (the
            sub-directory names), e.g. ['ambulance', 'bagel', ...]
          self.num_fgs_per_class: a dict of the form {foreground_obj_class_name:
            num_fgs_in_that_class}
          self.num_fgs_per_class_list: those counts, in `fg_classes` order.
          self.fgs: a list [fg0, fg1, ...] of images returned by `load_image`,
            in glob order.
          self.fgs_dict: a dict of the form {fg_class_name: [img0, img1, ...]}.

        Raises:
          ValueError: if `foregrounds_dir` does not exist.
        """
        if not gfile.exists(foregrounds_dir):
            raise ValueError(
                f'Foregrounds directory {foregrounds_dir} does not exist.')
        fg_fnames = gfile.glob(path.join(foregrounds_dir, '*/*'))
        fg_labels = [x.split('/')[-2] for x in fg_fnames]  # e.g. 'car', 'cow'
        self.fg_classes = sorted(set(fg_labels))
        # Count per class from the listing already in hand rather than issuing
        # one extra glob per class directory (costly on remote filesystems).
        self.num_fgs_per_class = dict.fromkeys(self.fg_classes, 0)
        for label in fg_labels:
            self.num_fgs_per_class[label] += 1
        self.num_fgs_per_class_list = [
            self.num_fgs_per_class[fg_class] for fg_class in self.fg_classes
        ]
        self.fgs = self._thread_pool.map(load_image, fg_fnames)
        self.fgs_dict = {fg_class: [] for fg_class in self.fg_classes}
        for i, label in enumerate(fg_labels):
            self.fgs_dict[label].append(self.fgs[i])

        print('Foregrounds loaded.')
示例#16
0
 def iterate_checkpoints(self):
     """Iterates over checkpoints to evaluate, yielding (path, step) pairs.

     If `self.ckpt_to_evaluate` is set, only that checkpoint (joined onto
     `self.checkpoint_dir`) is yielded. Otherwise every `ckpt_*` file in
     `self.checkpoint_dir` is yielded, skipping steps below `self.min_step`
     when it is truthy.
     """
     if self.ckpt_to_evaluate:
         step = int(self.ckpt_to_evaluate.split('_')[-1])
         full_path = os.path.join(self.checkpoint_dir,
                                  self.ckpt_to_evaluate)
         yield full_path, step
     else:
         for checkpoint_path in gfile.glob(
                 os.path.join(self.checkpoint_dir, 'ckpt_*')):
             step = int(checkpoint_path.split('_')[-1])
             if self.min_step and step < int(self.min_step):
                 continue
             # Glob results already include checkpoint_dir; re-joining would
             # duplicate a relative (or gs://) directory prefix.
             yield checkpoint_path, step
def read_dir():
    """Collect the .wav files for every word in WANTED_WORDS.

    Expects one sub-directory per word under SOUNDS_DIR.

    Returns:
      A list of {'word': word, 'file': wav_path} dicts, grouped by word in
      WANTED_WORDS order.

    Raises:
      FileNotFoundError: if SOUNDS_DIR or a word's sub-directory is missing
        (a subclass of Exception, so existing `except Exception` still works).
    """
    if not os.path.isdir(SOUNDS_DIR):
        raise FileNotFoundError('Sound directory with name \'' + SOUNDS_DIR + '\' not found!')

    data = []
    for word in WANTED_WORDS:
        # os.path.join adds a separator only when SOUNDS_DIR lacks a trailing
        # one; plain concatenation silently produced a wrong path otherwise.
        word_dir = os.path.join(SOUNDS_DIR, word)
        if not os.path.isdir(word_dir):
            raise FileNotFoundError('Sounds directory for \'' + word + '\' not found at ' + word_dir + '!')

        search_path = os.path.join(word_dir, '*.wav')
        data.extend({'word': word, 'file': wav_path}
                    for wav_path in gfile.glob(search_path))

    return data
示例#18
0
  def __init__(self, path, add_eos, add_beos, *args, **kwargs):
    """Build the vocab, list training shards and encode valid/test splits.

    Args:
      path: corpus directory with `train/train.??` shards, `valid.txt` and
        `test.txt`.
      add_eos: stored on the instance and forwarded to `encode_file`.
      add_beos: stored on the instance and forwarded to `encode_file`.
      *args, **kwargs: forwarded unchanged to `Vocab`.
    """
    self.vocab = Vocab(*args, **kwargs)
    self.add_eos = add_eos
    self.add_beos = add_beos
    self.vocab.build_vocab()
    # Training data is kept as a list of shard paths, not encoded here.
    self.train = glob(os.path.join(path, "train", "train.??"))

    def encode_split(fname):
      return self.vocab.encode_file(os.path.join(path, fname), ordered=True,
                                    add_eos=self.add_eos,
                                    add_beos=self.add_beos)

    self.valid = encode_split("valid.txt")
    self.test = encode_split("test.txt")
示例#19
0
    def load_from_directory(trajectory_dir, epoch=None, n_trajectories=None):
        """Load trajectories from specified dir and epoch.

        Args:
          trajectory_dir: directory with pickled lists of `Trajectory` objects
            whose names match TRAJECTORY_FILE_GLOB.
          epoch: if truthy, only files for this epoch (`epoch_%06d`) are read;
            otherwise files for all epochs are read.
          n_trajectories: if truthy and positive, the not-done trajectories are
            resampled uniformly with replacement to this many.

        Returns:
          A BatchTrajectory, or None if no files match or they contain no
          trajectories.
        """
        trajectory_file_glob = TRAJECTORY_FILE_GLOB

        # If there is a desired epoch, modify the glob to get that instead.
        # NOTE(review): `if epoch:` treats epoch 0 like None (all epochs);
        # confirm that is intended.
        if epoch:
            trajectory_file_glob = trajectory_file_glob.replace(
                "epoch_*", "epoch_%06d" % epoch)

        trajectory_files = gfile.glob(
            os.path.join(trajectory_dir, trajectory_file_glob))

        if not trajectory_files:
            return None

        # We read and load all the files, revisit if this becomes a problem.
        trajectories_buffer = []
        completed_trajectories_buffer = []
        for trajectory_file in trajectory_files:
            # NOTE(review): pickle.load can execute arbitrary code; only read
            # trusted directories.
            with gfile.GFile(trajectory_file, "rb") as f:
                list_trajectories = pickle.load(f)
                assert isinstance(list_trajectories, list)
                if not list_trajectories:
                    continue
                assert isinstance(list_trajectories[0], Trajectory)
                for trajectory in list_trajectories:
                    if trajectory.done:
                        completed_trajectories_buffer.append(trajectory)
                    else:
                        trajectories_buffer.append(trajectory)

        if not trajectories_buffer and not completed_trajectories_buffer:
            return None

        # Randomly sample `n_trajectories` if needed.
        n_trajectories = None if not n_trajectories else int(n_trajectories)
        if n_trajectories and n_trajectories > 0 and trajectories_buffer:
            # Draw `n_trajectories` indices (with replacement) and pick those
            # objects. Sampling indices keeps numpy from trying to build an
            # array out of Trajectory instances.
            indices = np.random.choice(len(trajectories_buffer),
                                       n_trajectories)
            trajectories_buffer = [trajectories_buffer[i] for i in indices]

        # Construct and return a new BatchTrajectory object.
        return BatchTrajectory(
            batch_size=len(trajectories_buffer),
            trajectories=trajectories_buffer,
            completed_trajectories=completed_trajectories_buffer)
示例#20
0
 def _get_tfrecords(self, name):
     """Return the TFRecord files for split `name`.

     The data directory is the first entry of the ':'-separated
     `self.params.data_dir` that contains `name`. Each ','-separated glob in
     `self.params.data_pattern` is resolved under `<data_dir>/<name>`.

     Raises:
       AssertionError: if no candidate directory contains `name`.
       IOError: if the patterns match no files.
     """
     # NOTE(review): splitting on ':' breaks URL-style paths such as gs://;
     # confirm data_dir never holds those.
     data_dir = None
     for candidate in self.params.data_dir.split(':'):
         if gfile.exists(join(candidate, name)):
             data_dir = candidate
             break
     assert data_dir is not None, "data_dir not found"
     patterns = [join(data_dir, name, pattern)
                 for pattern in self.params.data_pattern.split(',')]
     files = gfile.glob(patterns)
     if not files:
         raise IOError("Unable to find files. data_pattern='{}'.".format(
             self.params.data_pattern))
     # Lazy %-style args: the message is only formatted if it is emitted.
     logging.info("Number of TFRecord files: %d.", len(files))
     return files
示例#21
0
def maybe_restore_opt_state(output_dir, policy_and_value_opt_state,
                            policy_and_value_state):
    """Maybe restore the optimization state from the checkpoint dir.

    Optimization state includes parameters and optimizer slots.
    `model-??????.pkl` files are tried from the lexicographically largest name
    down. A file whose unpickling raises `EOFError` is logged and skipped.

    Args:
      output_dir: Directory where saved model checkpoints are stored.
      policy_and_value_opt_state: Default optimization state, returned if model
        isn't found.
      policy_and_value_state: state of the policy and value network.

    Returns:
      tuple (restored (bool), opt_state, state, epoch (int),
      opt_step (int)) where epoch is the epoch from which we restored the
      optimization state (0 if restored is False), and opt_step is the total
      optimization step (sum of all optimization steps made up to the current
      epoch).
    """
    restored = False
    epoch = 0
    total_opt_step = 0
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    for model_file in sorted(model_files, reverse=True):
        logging.info("Trying to restore model from %s", model_file)
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # this on trusted checkpoint directories.
            with gfile.GFile(model_file, "rb") as f:
                policy_and_value_opt_state, policy_and_value_state, total_opt_step = (
                    pickle.load(f))
        except EOFError as e:
            logging.error("Unable to load model from: %s with %s", model_file,
                          e)
            # Try an older version.
            continue
        model_file_basename = os.path.basename(model_file)  # model-??????.pkl
        restored = True
        # filter() returns an iterator in Python 3; join the digits first.
        epoch = int("".join(filter(str.isdigit, model_file_basename)))
        break
    return (
        restored,
        policy_and_value_opt_state,
        policy_and_value_state,
        epoch,
        total_opt_step,
    )
示例#22
0
def latest_checkpoint(ckpt_dir, prefix='checkpoint_'):
    """Return the path of the newest checkpoint in `ckpt_dir`, if any.

    Args:
      ckpt_dir: str: directory of checkpoints to restore from.
      prefix: str: name prefix of checkpoint files.

    Returns:
      The last `<ckpt_dir>/<prefix>*` path in natural sort order, ignoring the
      temporary save file, or None when there is no such file.
    """
    found = natural_sort(gfile.glob(os.path.join(ckpt_dir, f'{prefix}*')))
    tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    candidates = [p for p in found if p != tmp_path]
    return candidates[-1] if candidates else None
    def to_dataframe(self, split: str, max_rows: int = 100) -> pd.DataFrame:
        """Returns dataframe representation of the artifact.

        Reads the GZIP-compressed TFRecord files under `<self.uri>/<split>/`.

        Args:
          split: The name of the datasplit to be returned.
          max_rows: The maximum number of rows to be returned. If None, all
            rows in the split will be returned.

        Raises:
          ValueError: if `max_rows` is negative.
        """
        self._validate_payload()

        if max_rows and max_rows < 0:
            raise ValueError('`max_rows` must not be negative. Got: %d' %
                             max_rows)
        filepaths = gfile.glob(os.path.join(self.uri, split, '*'))
        ds = tf.data.TFRecordDataset(filepaths, compression_type='GZIP')
        # `Dataset.take(-1)` keeps every element, whereas `take(None)` is not
        # a valid count. Map the documented "None means all rows" onto -1.
        ds = ds.take(-1 if max_rows is None else max_rows)
        return self._load_table(ds)
def main(argv):
    """Compare regenerated atomic-input files with the expected ones.

    Each stage2 conformer comes from files matching FLAGS.input_glob. For
    each one, the output of AtomicInputWriter is checked against the file of
    the same name in FLAGS.atomic_input_dir. Mismatches are printed. When
    FLAGS.output_dir is set, the regenerated contents are also written there.
    A summary line is logged and printed at the end.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    atomic_writer = smu_writer_lib.AtomicInputWriter()

    file_count = 0
    conformer_count = 0
    mismatches = 0

    for filepath in gfile.glob(FLAGS.input_glob):
        logging.info('Processing file %s', filepath)
        file_count += 1
        smu_parser = smu_parser_lib.SmuParser(filepath)
        for conformer, _ in smu_parser.process_stage2():
            conformer_count += 1

            actual_contents = atomic_writer.process(conformer)

            expected_fn = atomic_writer.get_filename_for_atomic_input(
                conformer)
            with gfile.GFile(os.path.join(FLAGS.atomic_input_dir,
                                          expected_fn)) as expected_f:
                expected_contents = expected_f.readlines()

            try:
                smu_writer_lib.check_dat_formats_match(
                    expected_contents, actual_contents.splitlines())
            except smu_writer_lib.DatFormatMismatchError as e:
                mismatches += 1
                print(e)
                if FLAGS.output_dir:
                    # Reuse the filename computed above instead of asking the
                    # writer for it a second time.
                    with gfile.GFile(
                            os.path.join(FLAGS.output_dir, expected_fn),
                            'w') as f:
                        f.write(actual_contents)

    status_str = ('COMPLETE: Read %d files, %d conformers, %d mismatches\n' %
                  (file_count, conformer_count, mismatches))

    logging.info(status_str)
    print(status_str)
示例#25
0
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files. A temporary file (e.g. from a concurrent
  # save) is not a finished checkpoint and is left alone. The file just
  # written is never removed, even when it does not sort last (for example
  # when saving a lower step than an existing checkpoint).
  base_path = os.path.join(ckpt_dir, f'{prefix}')
  checkpoint_files = [
      f for f in natural_sort(gfile.glob(base_path + '*'))
      if f != ckpt_tmp_path
  ]
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for old_path in old_ckpts:
      if old_path == ckpt_path:
        continue
      logging.info('Removing checkpoint at %s', old_path)
      gfile.remove(old_path)

  return ckpt_path
示例#26
0
def tensorboard_event_to_dataframe(path: str) -> pd.DataFrame:
  """Collect scalar summaries written by tests into a DataFrame.

  Args:
    path: Directory where the tensorboard records were saved; every
      `events.*.v2` file in it is read.

  Returns:
    A DataFrame with one record per summary value, holding `step`, `metric`
    (the summary tag) and `value` (converted to float). Events whose step is
    0 are skipped.
  """
  rows = []
  for event_file in gfile.glob(os.path.join(path, 'events.*.v2')):
    for event in tf.compat.v1.train.summary_iterator(event_file):
      if not event.step:
        continue
      rows.extend(
          dict(step=event.step, metric=v.tag,
               value=float(tf.make_ndarray(v.tensor)))
          for v in event.summary.value)
  return pd.DataFrame.from_records(rows)
示例#27
0
    def _load_backgrounds(self, backgrounds_dir):
        """Loads and preprocesses background images from a directory.

        Args:
          backgrounds_dir: path to a directory of background images; every
            entry matching `backgrounds_dir/*` is loaded with `load_image`.

        Produces:
          self.bgs: the loaded images, each passed through
            `self._preprocess_background`, in glob order.
          self.num_bgs: int, number of backgrounds.

        Raises:
          ValueError: if `backgrounds_dir` does not exist.
        """
        if not gfile.exists(backgrounds_dir):
            raise ValueError(
                f'Backgrounds directory {backgrounds_dir} does not exist.')
        fnames = gfile.glob(path.join(backgrounds_dir, '*'))
        pool = self._thread_pool
        # Two parallel passes: decode every file, then preprocess each image.
        self.bgs = pool.map(load_image, fnames)
        self.bgs = pool.map(self._preprocess_background, self.bgs)
        self.num_bgs = len(self.bgs)

        print('Backgrounds loaded.')
示例#28
0
def maybe_restore_params(output_dir, policy_and_value_net_params):
    """Maybe restore the params from the checkpoint dir.

    Only the lexicographically largest `model-??????.pkl` file is loaded.

    Args:
      output_dir: Directory where saved model checkpoints are stored.
      policy_and_value_net_params: Default params, returned if no model is found.

    Returns:
      triple (restored (bool), params, iter (int)) where iter is the epoch
      parsed from the restored file name, 0 if restored is False.
    """
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    if not model_files:
        return False, policy_and_value_net_params, 0

    model_file = sorted(model_files)[-1]
    model_file_basename = os.path.basename(model_file)  # model-??????.pkl
    # filter() returns an iterator in Python 3, so the digits must be joined
    # back into a string before int() can parse them.
    i = int("".join(filter(str.isdigit, model_file_basename)))
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # trusted checkpoint directories.
    with gfile.GFile(model_file, "rb") as f:
        policy_and_value_net_params = pickle.load(f)
    return True, policy_and_value_net_params, i
示例#29
0
    def test_whole_pipeline(self):
        test_subdirectory = self.create_tempdir()
        output_stem = os.path.join(test_subdirectory, 'testout')
        input_stage1_dat_glob = os.path.join(TESTDATA_PATH,
                                             'pipeline_input_stage1.dat')
        input_stage2_dat_glob = os.path.join(TESTDATA_PATH,
                                             'pipeline_input_stage2.dat')
        input_equivalent_glob = os.path.join(TESTDATA_PATH,
                                             'pipeline_equivalent.dat')
        input_bond_topology_csv = os.path.join(TESTDATA_PATH,
                                               'pipeline_bond_topology.csv')
        with flagsaver.flagsaver(
                input_stage1_dat_glob=input_stage1_dat_glob,
                input_stage2_dat_glob=input_stage2_dat_glob,
                input_equivalent_glob=input_equivalent_glob,
                input_bond_topology_csv=input_bond_topology_csv,
                output_stem=output_stem,
                output_shards=1):
            # If you have custom beam options, add them here.
            beam_options = None
            with beam.Pipeline(beam_options) as root:
                pipeline.pipeline(root)

        metrics = root.result.metrics().query()
        counters_dict = {
            m.key.metric.name: m.committed
            for m in metrics['counters']
        }

        self.assertEqual(counters_dict['attempted_topology_matches'], 3)
        # Conformer 620517 will not match because bond lengths are not extracted
        # from conformers with serious errors like this.
        self.assertEqual(counters_dict['no_topology_matches'], 1)
        self.assertNotIn('topology_match_smiles_failure', counters_dict)

        logging.info(
            'Files in output: %s',
            '\n'.join(gfile.glob(os.path.join(test_subdirectory, '*'))))
        for stage in ['stage1', 'stage2']:
            self.assertTrue(
                gfile.exists(output_stem + '_' + stage +
                             '_original_known_error-00000-of-00001.dat'))
            self.assertTrue(
                gfile.exists(output_stem + '_' + stage +
                             '_original_unknown_error-00000-of-00001.dat'))
            self.assertTrue(
                gfile.exists(output_stem + '_' + stage +
                             '_mismatched_original-00000-of-00001.dat'))
            self.assertTrue(
                gfile.exists(output_stem + '_' + stage +
                             '_mismatched_regen-00000-of-00001.dat'))

        # Check the merge conflicts file
        with gfile.GFile(output_stem + '_conflicts-00000-of-00001.csv') as f:
            conflicts_lines = f.readlines()
            self.assertIn('conformer_id,', conflicts_lines[0])
            self.assertEqual(
                conflicts_lines[1], '618451001,1,1,1,1,'
                '-406.51179,9.999999,-406.522079,9.999999,True,True,'
                '-406.51179,0.052254,-406.522079,2.5e-05,True,True\n')

        # Check a couple of the stats.
        with gfile.GFile(output_stem + '_stats-00000-of-00001.csv') as f:
            stats_lines = f.readlines()
            self.assertIn('errors.status,0,2\n', stats_lines)
            self.assertIn('errors.warn_t1,0,4\n', stats_lines)
            self.assertIn('fate,FATE_SUCCESS,2\n', stats_lines)
            self.assertIn('fate,FATE_DUPLICATE_DIFFERENT_TOPOLOGY,1\n',
                          stats_lines)
            self.assertIn('num_initial_geometries,1,4\n', stats_lines)
            self.assertIn('num_duplicates,1,1\n', stats_lines)
            self.assertIn('zero_field,single_point_energy_pbe0d3_6_311gd,1\n',
                          stats_lines)

        # Check the smiles comparison output
        with gfile.GFile(output_stem +
                         '_smiles_compare-00000-of-00001.csv') as f:
            smiles_lines = f.readlines()
            self.assertIn(
                '620517002,MISMATCH,NotAValidSmilesString,'
                '[H]C1=C2OC2=C(F)O1,FC1=C2OC2=CO1\n', smiles_lines)
            # Make sure that a bond topology with a matching smiles doesn't show
            for line in smiles_lines:
                self.assertNotIn('618451001', line)

        # Check the bond topology summary
        with gfile.GFile(output_stem + '_bt_summary-00000-of-00001.csv') as f:
            bt_summary_lines = f.readlines()
            # Check part of the header line
            self.assertIn('bt_id', bt_summary_lines[0])
            self.assertIn('count_attempted_conformers', bt_summary_lines[0])
            # This is the bond topology that has no conformer
            self.assertIn('10,0,0,0,0,0,0,0,0,0,0,0,0,0\n', bt_summary_lines)
            # This is a bond topology with 1 conformer
            self.assertIn('620517,1,0,0,0,1,0,1,0,0,0,0,0,0\n',
                          bt_summary_lines)
            # This is a bond topology with 2 conformers
            self.assertIn('618451,2,0,0,0,2,0,0,0,2,0,0,0,0\n',
                          bt_summary_lines)

        # Check the bond lengths file
        with gfile.GFile(output_stem + '_bond_lengths.csv') as f:
            bond_length_lines = f.readlines()
            self.assertEqual(
                'atom_char_0,atom_char_1,bond_type,length_str,count\n',
                bond_length_lines[0])
            self.assertIn('c,c,2,1.336,1\n', bond_length_lines)
            self.assertIn('c,o,1,1.422,2\n', bond_length_lines)

        # For the gzip files below, we check >100 because even an empty gzip file
        # has non-zero length. 100 is kind of arbitrary to be bigger than the
        # expected header of 20.

        # Check that the generated TFRecord files contain some expected outputs
        standard_dataset = tf.data.TFRecordDataset(
            output_stem + '_standard_tfrecord-00000-of-00001')
        standard_output = [
            dataset_pb2.Conformer.FromString(raw)
            for raw in standard_dataset.as_numpy_iterator()
        ]
        self.assertCountEqual([c.conformer_id for c in standard_output],
                              [618451001, 618451123])
        # Check that fields are filtered the way we expect
        self.assertFalse(
            standard_output[0].properties.HasField('compute_cluster_info'))
        self.assertFalse(
            standard_output[0].properties.HasField('homo_pbe0_aug_pc_1'))
        self.assertTrue(
            standard_output[0].properties.HasField('rotational_constants'))

        complete_dataset = tf.data.TFRecordDataset(
            output_stem + '_complete_tfrecord-00000-of-00001')
        complete_output = [
            dataset_pb2.Conformer.FromString(raw)
            for raw in complete_dataset.as_numpy_iterator()
        ]
        self.assertCountEqual([c.conformer_id for c in complete_output],
                              [618451001, 618451123, 620517002, 79593005])
        # Check that fields are filtered the way we expect
        # The DirectRunner randomizes the order of output so we need to make sure
        # that we get a full record.
        complete_entry = [
            c for c in complete_output if c.conformer_id == 618451001
        ][0]
        self.assertFalse(
            complete_entry.properties.HasField('compute_cluster_info'))
        self.assertTrue(
            complete_entry.properties.HasField('homo_pbe0_aug_pc_1'))
        self.assertTrue(
            complete_entry.properties.HasField('rotational_constants'))

        complete_entry_for_smiles = [
            c for c in complete_output if c.conformer_id == 620517002
        ][0]
        self.assertEqual(complete_entry_for_smiles.properties.smiles_openbabel,
                         'NotAValidSmilesString')
示例#30
0
def get_policy_model_files(output_dir):
    """Return policy model files found in `output_dir`, highest name first.

    Globs for files named `model-??????.pkl`, where exactly six characters sit
    between `model-` and `.pkl`. The results are sorted in descending
    lexicographic order.

    That variable part has a fixed width. So lexicographic order matches
    numeric order when those characters are zero-padded digits. Presumably
    they are a step counter; confirm against the code that writes these files.

    Args:
      output_dir: str: directory to search for model files.

    Returns:
      list of str: matching paths, highest-sorting name first. The list is
      empty if nothing matches.
    """
    pattern = os.path.join(output_dir, "model-??????.pkl")
    # For string paths, sorted(..., reverse=True) gives the same sequence as
    # list(reversed(sorted(...))), without the extra reversal pass and copy.
    return sorted(gfile.glob(pattern), reverse=True)