Example #1
    def load_from_directory(trajectory_dir, epoch=None, n_trajectories=None):
        """Load trajectories from specified dir and epoch."""

        trajectory_file_glob = TRAJECTORY_FILE_GLOB

        # If there is a desired epoch, modify the glob to get that instead.
        if epoch:
            trajectory_file_glob = trajectory_file_glob.replace(
                "epoch_*", "epoch_%06d" % epoch)

        trajectory_files = gfile.glob(
            os.path.join(trajectory_dir, trajectory_file_glob))

        if not trajectory_files:
            return None

        # We read and load all the files, revisit if this becomes a problem.
        trajectories_buffer = []
        completed_trajectories_buffer = []
        for trajectory_file in trajectory_files:
            with gfile.GFile(trajectory_file, "rb") as f:
                list_trajectories = pickle.load(f)
                assert isinstance(list_trajectories, list)
                if not list_trajectories:
                    continue
                assert isinstance(list_trajectories[0], Trajectory)
                for trajectory in list_trajectories:
                    if trajectory.done:
                        completed_trajectories_buffer.append(trajectory)
                    else:
                        trajectories_buffer.append(trajectory)

        if not trajectories_buffer and not completed_trajectories_buffer:
            return None

        # Randomly sample `n_trajectories` if needed.
        n_trajectories = None if not n_trajectories else int(n_trajectories)
        if n_trajectories and n_trajectories > 0:
            trajectories_buffer = list(
                np.random.choice(trajectories_buffer, n_trajectories))

        # Construct and return a new BatchTrajectory object.
        return BatchTrajectory(
            batch_size=len(trajectories_buffer),
            trajectories=trajectories_buffer,
            completed_trajectories=completed_trajectories_buffer)
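A minimal, self-contained sketch of the subsampling step above, with plain strings standing in for Trajectory objects (replace=False is an optional tweak for sampling without replacement, not part of the original):

import numpy as np

# Toy buffer; in the function above these would be Trajectory instances.
trajectories_buffer = ['traj_%d' % i for i in range(10)]
n_trajectories = 4

# np.random.choice samples with replacement by default.
sampled = list(
    np.random.choice(trajectories_buffer, n_trajectories, replace=False))
print(sampled)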
Example #2
def maybe_restore_opt_state(output_dir, policy_and_value_opt_state,
                            policy_and_value_state):
    """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (restored (bool), opt_state, state, epoch (int),
    opt_step (int)) where epoch is the epoch from which we restored the
    optimization state (0 if restored is False), and opt_step is the total
    optimization step (sum of all optimization steps made up to the current
    epoch).
  """
    restored = False
    epoch = 0
    total_opt_step = 0
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    for model_file in reversed(sorted(model_files)):
        logging.info("Trying to restore model from %s", model_file)
        try:
            with gfile.GFile(model_file, "rb") as f:
                policy_and_value_opt_state, policy_and_value_state, total_opt_step = (
                    pickle.load(f))
            model_file_basename = os.path.basename(
                model_file)  # model-??????.pkl
            restored = True
            epoch = int("".join(filter(str.isdigit, model_file_basename)))
            break
        except EOFError as e:
            logging.error("Unable to load model from: %s with %s", model_file,
                          e)
            # Try an older version.
            continue
    return (
        restored,
        policy_and_value_opt_state,
        policy_and_value_state,
        epoch,
        total_opt_step,
    )
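For reference, a self-contained sketch of the epoch parsing used above; the checkpoint basename is a made-up example, and both forms below work under Python 3 (a bare int(filter(...)) only worked under Python 2):

import re

basename = 'model-000013.pkl'
# Keep only the digits of the basename and parse them as the epoch number.
epoch = int(''.join(filter(str.isdigit, basename)))
# Equivalent regex-based parse.
epoch_re = int(re.search(r'model-(\d+)\.pkl', basename).group(1))
assert epoch == epoch_re == 13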
Example #3
def save_opt_state(output_dir,
                   policy_and_value_opt_state,
                   policy_and_value_state,
                   epoch,
                   total_opt_step):
  """Saves the policy and value network optimization state etc."""
  pkl_module = utils.get_pickle_module()
  old_model_files = get_policy_model_files(output_dir)
  params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
  with gfile.GFile(params_file, "wb") as f:
    pkl_module.dump(
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step), f)
  # Remove the old model files, leave the latest one (it might be in the
  # process of getting read async) -- this will get cleaned up later.
  for path in old_model_files[1:]:
    if path != params_file:
      gfile.remove(path)
def main(argv):
  del argv  # Unused.

  smu_proto = dataset_pb2.MultipleMolecules()
  with gfile.GFile(FLAGS.input_file) as f:
    raw_proto = f.read()
  text_format.Parse(raw_proto, smu_proto)
  smu_writer = SmuWriter(FLAGS.annotate)
  contents = ''.join(
      smu_writer.process_stage2_proto(molecule)
      for molecule in smu_proto.molecules)
  if FLAGS.output_file:
    logging.info('Writing smu7 molecules to .dat file %s.', FLAGS.output_file)
    with open(FLAGS.output_file, 'w') as f:
      f.write(contents)
  else:
    print(contents, end='')
Example #5
 def process(
     self, question_answer: QuestionAnswer
 ) -> Generator[QuestionAnswerEvidence, None, None]:
     for info in question_answer.evidence_info:
         if info.source == 'EntityPages':
             evidence_path = os.path.join(self._wikipedia_dir, info.id)
         elif info.source == 'SearchResult':
             evidence_path = os.path.join(self._web_dir, info.id)
         else:
             raise ValueError(f'Unknown evidence source: {info.source}.')
         with gfile.GFile(evidence_path, 'rb') as f:
             text = f.read().decode('utf-8')
         metrics.Metrics.counter('_', 'documents').inc()
         yield QuestionAnswerEvidence(question=question_answer.question,
                                      evidence=Evidence(info=info,
                                                        text=text),
                                      answer=question_answer.answer)
 def _write_episodes(self):
     """Write recorded episodes to pickle."""
     # Inclusive first and last.
     episode_idx = self._get_episode_idx()
     if self._eval_mode:
         counts = self._counter.get_counts()
         train_steps = counts['steps'] if 'steps' in counts else 0
         filename = f'evaluation_{train_steps}.pkl'
     else:
         first_ep = episode_idx - len(self._episodes_to_record) + 1
         last_ep = episode_idx
         filename = f'episodes_{first_ep}-{last_ep}.pkl'
     log_path = os.path.join(self._logdir, filename)
     print('Episode', episode_idx, ': flushing to', log_path)
     with gfile.GFile(log_path, 'wb') as f:
         pickle.dump(self._episodes_to_record, f)
     self._episodes_to_record = []
Example #7
def set_eval_paths(ckpt_dir, ckpt, custom_eval_id):
    """Set paths for evaluation and TensorBoard summaries."""
    eval_path, summary_key, best_epoch = None, None, None
    if ckpt is not None:
        best_epoch_path = (ckpt.replace('ckpt_best_of_',
                                        'best_epoch_of_').replace(
                                            'ckpt', 'best_epoch'))
        best_epoch_path = os.path.join(ckpt_dir, best_epoch_path)
        if gfile.exists(best_epoch_path):
            with gfile.GFile(best_epoch_path) as f:
                best_epoch = int(f.read())
        else:
            best_epoch = int(ckpt.replace('ckpt_', ''))
        print('best epoch:', best_epoch)
        eval_path = set_eval_path(ckpt_dir, custom_eval_id, ckpt)
        summary_key = set_summary_key(custom_eval_id, ckpt)
    return eval_path, summary_key, best_epoch
Example #8
def load_trajectories(trajectory_dir, eval_frac):
    """Loads trajectories from a possibly nested directory of pickles."""
    pkl_module = utils.get_pickle_module()
    train_trajectories = []
    eval_trajectories = []
    # Search the entire directory subtree for trajectories.
    for (subdir, _, filenames) in gfile.walk(trajectory_dir):
        for filename in filenames:
            shard_path = os.path.join(subdir, filename)
            with gfile.GFile(shard_path, "rb") as f:
                trajectories = pkl_module.load(f)
                pivot = int(len(trajectories) * (1 - eval_frac))
                train_trajectories.extend(trajectories[:pivot])
                eval_trajectories.extend(trajectories[pivot:])
    assert train_trajectories, "Haven't found any training data."
    assert eval_trajectories, "Haven't found any evaluation data."
    return (train_trajectories, eval_trajectories)
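The per-shard train/eval split above reduces to a single pivot index; a toy illustration:

trajectories = list(range(10))
eval_frac = 0.2
pivot = int(len(trajectories) * (1 - eval_frac))
train, evaluation = trajectories[:pivot], trajectories[pivot:]
assert (len(train), len(evaluation)) == (8, 2)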
Example #9
def save_data(df, path, out_data, mode, bucket):
    if mode == 'cloud':
        out_csv_gcs = f'{bucket}/{path}/{out_data}'
        logging.info(f'Writing {out_csv_gcs} file...')
        with gfile.GFile(name=out_csv_gcs, mode='w') as file:
            df.to_csv(file, index=False)
        logging.info(f'{out_csv_gcs} successfully loaded!')
        return out_csv_gcs
    else:
        p = Path(path)
        if not p.exists():
            os.mkdir(path)
        out_csv = f'{path}/{out_data}'
        logging.info(f'Writing {out_csv} file...')
        df.to_csv(out_csv, index=False)
        logging.info(f'{out_csv} successfully loaded!')
        return out_csv
Example #10
def log_pytree_shape_and_statistics(pytree, json_path=None):
    """Logs the shape and norm of every array in the pytree."""
    if not pytree:
        absl_logging.info('Empty pytree')
        return

    if json_path:
        shape_dict = jax.tree_map(lambda x: x.shape, pytree).pretty_repr()
        with gfile.GFile(json_path, 'w') as json_file:
            json_file.write(shape_dict)

    absl_logging.info('Printing model param shapes.')
    shape_dict = jax.tree_map(_summary_str, pytree)
    absl_logging.info(shape_dict.pretty_repr())
    total_params = jax.tree_util.tree_reduce(
        operator.add, jax.tree_map(lambda x: x.size, pytree))
    absl_logging.info('Total params: %d', total_params)
def load_dataset_metadata(metadata_filename):
    """Helper function to load dataset metadata.

  Args:
    metadata_filename: Filename containing dataset metadata.

  Returns:
    Padding configuration and edge types for the dataset.
  """
    with gfile.GFile(metadata_filename, "r") as fp:
        metadata = json.load(fp)

    edge_types = metadata["edge_types"]
    padding_config = flax.serialization.from_state_dict(
        target=jax_util.synthesize_dataclass(graph_bundle.PaddingConfig),
        state=metadata["padding_config"])
    return padding_config, edge_types
def pipeline(root):
  """Beam pipeline.

  Args:
    root: the root of the pipeline.
  """
  _ = (
      root
      | 'CreateTopologies' >> beam.Create(
          smu_utils_lib.generate_bond_topologies_from_csv(
              gfile.GFile(FLAGS.input_bond_topology_csv, 'r')))
      | 'Reshuffle1' >> beam.Reshuffle()
      | 'CheckInvariance' >> beam.FlatMap(check_smiles_permutation_invariance)
      | 'Reshuffle2' >> beam.Reshuffle()
      | 'CSVFormat' >> beam.Map(lambda vals: ','.join(str(x) for x in vals))
      | 'WriteOutput' >> beam.io.WriteToText(
          FLAGS.output_csv, header='bt_id,smiles0,smiles1', num_shards=1))
Example #13
def restore_state(output_dir):
    """Restore State."""
    params_file = os.path.join(output_dir, "model.pkl")
    if not gfile.exists(params_file):
        return State(step=None,
                     opt_state=None,
                     history=trax_history.History(),
                     model_state=None)

    with gfile.GFile(params_file, "rb") as f:
        (opt_state, step, history, model_state) = pickle.load(f)
    log("Model loaded from %s at step %d" % (params_file, step))
    logging.debug("From loaded model : history = %s", history)
    return State(step=step,
                 opt_state=OptState(*opt_state),
                 history=history,
                 model_state=model_state)
 def generate_val_timestep(self):
     """Generator to iterate over the validation split."""
     demo_idx = 0
     val_pointer = gfile.GFile(self.path, 'rb')
     while True:
         try:
             if not self.in_memory:
                 demo = pickle.load(val_pointer)
             if self.demo_in_val(demo_idx):
                 # Preprocess only if (at least some time step) in val split.
                 if self.in_memory:
                     observations = self.observations[demo_idx]
                     actions = self.actions[demo_idx]
                 else:
                     observations = demo['observations']
                     actions = demo['actions']
                     if self.max_demo_length is not None:
                         observations = observations[:self.max_demo_length]
                         actions = actions[:self.max_demo_length]
                 if not self.in_memory or not self._decompress_once:
                     # Decompress images of all viewpoints (if applicable) to keep
                     # viewpoint constant across stacked frames.
                     observations = self.decompress_all_images(observations)
                 signals = [None for _ in observations]
                 if FLAGS.clip_actions:
                     actions = list(np.clip(actions, -1, 1))
                 if self.agent is not None:
                     observations, signals, actions = self.agent.normalize_demo(
                         observations,
                         actions,
                         augment_frames=self.augment_frames,
                         randomize_camera=True)
                 for t in range(len(actions)):
                     if self.split_by_demo:
                         if not self.episode_train_split[demo_idx]:
                             yield observations[t], signals[t], actions[
                                 t], demo_idx, t
                     elif not self.episode_train_split[demo_idx][t]:
                         yield observations[t], signals[t], actions[
                             t], demo_idx, t
             demo_idx += 1
             if self.max_to_load is not None and demo_idx >= self.max_to_load:
                 return
         except EOFError:
             return
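The generator above (and several snippets below) relies on calling pickle.load repeatedly on one file handle until EOFError. A minimal standalone sketch of that pattern; iter_pickled_records is a hypothetical helper, and the built-in open stands in for gfile.GFile, which behaves the same way in 'rb' mode:

import pickle

def iter_pickled_records(path):
    # Yield successive pickled objects from one file until end of file.
    with open(path, 'rb') as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                return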
Example #15
 def test_save_restore_checkpoints_w_float_steps(self):
     tmp_dir = self.create_tempdir().full_path
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     # Create leftover temporary checkpoint, which should be ignored.
     gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0.0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0.0', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 2.0,
                                 prefix='test_',
                                 keep=1)
     with self.assertRaises(errors.InvalidCheckpointError):
         checkpoints.save_checkpoint(tmp_dir,
                                     test_object2,
                                     1.0,
                                     prefix='test_',
                                     keep=1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3.0,
                                 prefix='test_',
                                 keep=2)
     self.assertIn('test_3.0', os.listdir(tmp_dir))
     self.assertIn('test_2.0', os.listdir(tmp_dir))
     jtu.check_eq(new_object, test_object1)
Example #16
 def dump_tabular(self):
     df = pd.DataFrame(self.unsaved, index=self.unsaved[self.index_key])
     if self.rank == 0:
         if self.print_count % 10 == 0:
             print(df.round(3), flush=True)
         else:
             print(df.round(3).to_string().split("\n")[1], flush=True)
     self.print_count += 1
     if self.args.backend == 'local':
         df.to_csv(self.filename,
                   sep='\t',
                   mode='a',
                   header=self.use_headers)
     else:
         with gfile.GFile(self.filename, 'a') as f:
             df.to_csv(f, sep='\t', mode='a', header=self.use_headers)
     self.unsaved = {}
     self.use_headers = False
Example #17
def load_trainer_state(output_dir):
    """Returns a TrainerState instance loaded from the given `output_dir`."""
    params_file = os.path.join(output_dir, 'model.pkl')
    if not gfile.exists(params_file):
        return TrainerState(step=None,
                            opt_state=None,
                            history=trax_history.History(),
                            model_state=None)

    pkl_module = utils.get_pickle_module()
    with gfile.GFile(params_file, 'rb') as f:
        (opt_state, step, history, model_state) = pkl_module.load(f)
    log('Model loaded from %s at step %d' % (params_file, step))
    logging.debug('From loaded model : history = %s', history)
    return TrainerState(step=step,
                        opt_state=OptState(*opt_state),
                        history=history,
                        model_state=model_state)
Example #18
def log_pytree_shape_and_statistics(pytree, json_path=None):
    """Logs the shape and norm of every array in the pytree."""
    if not pytree:
        absl_logging.info('Empty pytree')
        return

    if json_path:
        shape_dict = json.dumps(jax.tree_map(lambda x: x.shape, pytree))
        with gfile.GFile(json_path, 'w') as json_file:
            json_file.write(shape_dict)

    absl_logging.info('Printing model param shapes.')
    shape_dict = jax.tree_map(_summary_str, pytree)
    # We use json.dumps for pretty printing nested dicts.
    absl_logging.info(json.dumps(shape_dict, sort_keys=True, indent=4))
    total_params = jax.tree_util.tree_reduce(
        operator.add, jax.tree_map(lambda x: x.size, pytree))
    absl_logging.info('Total params: %d', total_params)
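Both variants of log_pytree_shape_and_statistics boil down to a tree_map over leaf shapes plus a tree_reduce for the parameter count. A small runnable sketch with a toy pytree, using jax.tree_util explicitly:

import json
import operator

import jax
import jax.numpy as jnp

pytree = {'dense': {'kernel': jnp.zeros((3, 4)), 'bias': jnp.zeros((4,))}}

# Shapes of every leaf, pretty-printed as JSON (tuples serialize as lists).
shape_dict = jax.tree_util.tree_map(lambda x: x.shape, pytree)
print(json.dumps(shape_dict, sort_keys=True, indent=4))

# Total parameter count across the pytree.
total_params = jax.tree_util.tree_reduce(
    operator.add, jax.tree_util.tree_map(lambda x: x.size, pytree))
assert total_params == 16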
 def add_demos(self, path, max_to_load=None):
     """Add additional demos to training set without affecting dataset stats."""
     # TODO(minttu): not-in-memory option
     self._stats_fixed = True
     if self.max_to_load is not None and FLAGS.stats_from_large_dataset:
         self.observations = self.observations[:self.max_to_load]
         self.actions = self.actions[:self.max_to_load]
     length_before = len(self.observations)
     observations = self.observations
     actions = self.actions
     lengths = self.episode_lengths
     num_new_demos = 0
     print('Before adding demos')
     print(len(self.observations), len(self.actions),
           len(self.episode_train_split))
     with gfile.GFile(path, 'rb') as f:
         while True:
             try:
                 demo = pickle.load(f)
                 num_new_demos += 1
                 obs_demo = demo['observations']
                 act_demo = demo['actions']
                 if self.max_demo_length is not None:
                     obs_demo = obs_demo[:self.max_demo_length]
                     act_demo = act_demo[:self.max_demo_length]
                 observations.append(obs_demo)
                 actions.append(np.stack(act_demo))
                 lengths.append(len(obs_demo))
                 if self.split_by_demo:
                     self.episode_train_split = np.concatenate(
                         [self.episode_train_split,
                          np.ones(1)])
                 else:
                     self.episode_train_split = np.concatenate(
                         [self.episode_train_split,
                          np.ones(len(obs_demo))])
                 if max_to_load is not None and num_new_demos >= max_to_load:
                     break
             except EOFError:
                 break
     assert len(self.observations) == length_before + num_new_demos
      print('Added', num_new_demos, 'demos')
     print(len(self.observations), len(self.actions),
           len(self.episode_train_split))
  def load_ihdp(self):
    """Loads semi-synthetic data.

    It updates the object DataSimulation.

    Args:
      self
    Returns:
      None
    """
    self.data_path = self.param_data['data_path'] + 'IHDP/'
    # Reference: https://github.com/AMLab-Amsterdam/CEVAE
    # Each iteration, it randomly picks one of the 10 existing repetitions.
    np.random.seed(self.seed)

    i = np.random.randint(1, 10, 1)[0]
    path = self.data_path + '/ihdp_npci_' + str(i) + '.csv.txt'
    with gfile.GFile(path, 'r') as f:
      data = np.loadtxt(f, delimiter=',')

    self.outcome, y_cf = data[:, 1][:, np.newaxis], data[:, 2][:, np.newaxis]
    self.outcome = self.outcome.ravel()
    self.treatment = data[:, 0].ravel()
    self.covariates = data[:, 5:]
    scaler = StandardScaler()
    self.covariates = scaler.fit_transform(self.covariates)

    self.sample_size, self.num_covariates = self.covariates.shape
    self.linear, self.noise = False, False
    self.var_covariates = None
    self.treatment_prop = self.treatment.sum()/len(self.treatment)

    y1, y0 = self.outcome, self.outcome
    y1 = [
        y_cf[j][0] if item == 0 else self.outcome[j]
        for j, item in enumerate(self.treatment)
    ]
    y0 = [
        y_cf[j][0] if item == 1 else self.outcome[j]
        for j, item in enumerate(self.treatment)
    ]
    y1 = np.array(y1)
    y0 = np.array(y0)
    self.tau = (y1-y0).mean()
Example #21
def parse_duplicates_file(filename):
    """Parses duplciate file into a pandas dataframe.

  The duplciate file supplied by our collaborators (called
  list.equivalent_{isomers,conformers.dat) is a two column, space separated
  file of composite names like x07_n4o3h4.091404.073
  which we parse the names into columns
  * nameX: original composiite name from file
  * stoichX: string for the stoichiometry
  * btidX: bond topology id
  * shortconfidX: 3 digit conformer id
  * confidX: full conformer id that we use (btid * 1000 + shortconfid)
  (for X = 1 or 2)

  Args:
    filename: file to read (usually list.equivalent_isomers.dat)

  Returns:
    pd.DataFrame
  """
    with gfile.GFile(filename) as f:
        df_dups = pd.read_csv(f,
                              delim_whitespace=True,
                              names=['name1', 'name2'],
                              header=None)

    for idx in ['1', '2']:
        df_dups = pd.concat([
            df_dups,
            df_dups['name' +
                    idx].str.extract(r'x07_([\w\d]+)\.(\d+).(\d+)').rename(
                        columns={
                            0: 'stoich' + idx,
                            1: 'btid' + idx,
                            2: 'shortconfid' + idx
                        })
        ],
                            axis=1)
        df_dups['btid' + idx] = df_dups['btid' + idx].astype(int)
        df_dups['shortconfid' + idx] = df_dups['shortconfid' + idx].astype(int)
        df_dups['confid' + idx] = (df_dups['btid' + idx] * 1000 +
                                   df_dups['shortconfid' + idx])

    return df_dups
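A quick check of the extraction pattern above, using the composite name quoted in the docstring:

import pandas as pd

names = pd.Series(['x07_n4o3h4.091404.073'])
parts = names.str.extract(r'x07_([\w\d]+)\.(\d+).(\d+)').rename(
    columns={0: 'stoich', 1: 'btid', 2: 'shortconfid'})
# Full conformer id: bond topology id * 1000 + short conformer id.
confid = parts['btid'].astype(int) * 1000 + parts['shortconfid'].astype(int)
assert confid[0] == 91404073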
Example #22
def generate_examples(file_path: str):
  """Provides a common generate_examples method for D4RL datasets."""
  with gfile.GFile(file_path, 'rb') as f:
    dataset_file = h5py.File(f, 'r')
    dataset_dict = {}
    for k in _get_dataset_keys(dataset_file):
      try:
        # first try loading as an array
        dataset_dict[k] = dataset_file[k][:]
      except ValueError as e:  # try loading as a scalar
        dataset_dict[k] = dataset_file[k][()]
    dataset_file.close()
  if 'timeouts' not in dataset_dict:
    raise ValueError('Only datasets with explicit timeouts are supported.')

  done = [
      terminal or timeout
      for (terminal,
           timeout) in zip(dataset_dict['terminals'], dataset_dict['timeouts'])
  ]
  # is_first corresponds to the done flag delayed by one step.
  dataset_dict['is_first'] = [True] + done[:-1]

  # TODO(sabela): Add extra keys for metadata (qpos, qval, goal) that is only
  # present in some datasets.
  dataset_dict = {
      'observation': dataset_dict['observations'],
      'action': dataset_dict['actions'],
      'reward': dataset_dict['rewards'],
      'discount': np.ones_like(dataset_dict['rewards']),
      'is_terminal': dataset_dict['terminals'],
      'is_first': dataset_dict['is_first'],
  }
  num_steps = len(dataset_dict['is_first'])
  prev = 0
  counter = 0
  for pos in range(num_steps):
    if dataset_dict['is_first'][pos] and pos > prev:
      yield counter, _get_episode(dataset_dict, prev, pos)
      prev = pos
      counter += 1
  if prev < num_steps:
    yield counter, _get_episode(dataset_dict, prev, num_steps)
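The episode bookkeeping in generate_examples reduces to two list operations; a toy check of the done/is_first derivation:

terminals = [False, False, True, False, False]
timeouts = [False, False, False, False, True]
# done is the OR of terminal and timeout flags.
done = [t or to for t, to in zip(terminals, timeouts)]
# is_first is the done flag delayed by one step, with True at position 0.
is_first = [True] + done[:-1]
assert is_first == [True, False, False, True, False]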
def gen_csv_from_annotations(
    input_dir: str,
    output_file=constants.DEFAULT_CSV_FILENAME,
    out_path_prefix='',
    dataset_type=constants.DEFAULT_DATASET_TYPE):
  """Generates AutoML dataset CSV from annotation files.

  Args:
    input_dir: Directory of annotation files.
    output_file: Output CSV filename.
    out_path_prefix: Filepath prefix to prepend to the image files.
      e.g.
      src_image_filename = '/tmp/path/to/image.jpg'
      out_path_prefix = 'gs://bucket/images'
      output_image_filename = 'gs://bucket/images/image.jpg'
    dataset_type: Dataset type (TRAIN, VAL, TEST, UNSPECIFIED)
      to use for all the parsed images.
  """

  if not gfile.exists(input_dir):
    raise ValueError('Input directory not found.')

  with gfile.GFile(os.path.expanduser(output_file), 'w') as outf:
    writer = csv.writer(outf, delimiter=',')
    for filename in gfile.listdir(os.path.expanduser(input_dir)):
      filepath = os.path.join(input_dir, filename)
      image_filename, boxes = annotation.read(filepath)
      out_image_filename = os.path.join(out_path_prefix, image_filename)
      for b in boxes:
        row = [
            dataset_type,
            out_image_filename,
            b.label,
            b.xmin,
            b.ymin,
            '',
            '',
            b.xmax,
            b.ymax,
            '',
            '',
        ]
        writer.writerow(row)
Example #24
def smiles_to_id(bond_topology_filename):
    """DoFn for creating the smiles to id mapping.

  Reads the same merged_bond_topology file as bond_topology_summaries_from_csv
  and outputs the smiles to bond topology id mapping. We could of course
  produce them both at the same time, but this is simpler.

  Args:
    bond_topology_filename: see FLAGS.input_bond_topology_csv

  Yields:
    smiles, bond_topology_id
  """
    with gfile.GFile(bond_topology_filename, 'r') as infile:
        reader = csv.reader(iter(infile))
        next(reader)  # skip the header line
        for row in reader:
            bt_id, _, _, _, _, smiles = row
            yield smiles, int(bt_id)
Example #25
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
    """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int), history) where epoch
    is the epoch from which we restored the optimization state, 0 if no
    checkpoint was found, and opt_step is the total optimization step (sum of
    all optimization steps made up to the current epoch).
  """
    pkl_module = utils.get_pickle_module()
    epoch = 0
    total_opt_step = 0
    history = trax_history.History()
    for model_file in get_policy_model_files(output_dir):
        logging.info('Trying to restore model from %s', model_file)
        try:
            with gfile.GFile(model_file, 'rb') as f:
                (policy_and_value_opt_state, policy_and_value_state,
                 total_opt_step, history) = pkl_module.load(f)
            epoch = get_epoch_from_policy_model_file(model_file)
            break
        except EOFError as e:
            logging.error('Unable to load model from: %s with %s', model_file,
                          e)
            # Try an older version.
            continue
    return (
        policy_and_value_opt_state,
        policy_and_value_state,
        epoch,
        total_opt_step,
        history,
    )
Example #26
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files.
  base_path = os.path.join(ckpt_dir, f'{prefix}')
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
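The temporary-file-then-rename idiom in save_checkpoint can be distilled into a tiny helper. A sketch assuming the tf.io.gfile API; write_atomically is a hypothetical name, not taken from the source:

import os
import tensorflow as tf

def write_atomically(path, data):
    # Write bytes via a temporary file plus rename, so readers never observe
    # a partially written checkpoint.
    tmp_path = path + '.tmp'
    tf.io.gfile.makedirs(os.path.dirname(path))
    with tf.io.gfile.GFile(tmp_path, 'wb') as f:
        f.write(data)
    tf.io.gfile.rename(tmp_path, path, overwrite=True)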
  def from_file(cls, filename, right_tail_mass):
    """Load the distribution from a file.

    The file should be a CSV with two (unnamed) columns. The first is a length
    bucket and the second is a count. Each line represents the number of
    observed lengths between the length on that line and the next line. The
    buckets should be equally spaced.

    Args:
      filename: file to read
      right_tail_mass: probability mass added past the right side of the buckets
        (see class documentation)

    Returns:
      EmpiricalLengthDistribution
    """
    with gfile.GFile(filename) as f:
      df = pd.read_csv(f, header=None, names=['length', 'count'], dtype=float)

    return EmpiricalLengthDistribution(df, right_tail_mass)
 def keep_latest_trajectories(self, demos_file, num_to_keep):
      # Keep the latest num_to_keep trajectories in the dataset at demos_file.
     print(demos_file)
     all_demos_file = (demos_file.replace(f'e{num_to_keep}',
                                          '').replace('.pkl', 'all.pkl'))
     print(all_demos_file)
     gfile.Rename(demos_file, all_demos_file)
     last_demos = []
     with gfile.GFile(all_demos_file, 'rb') as f:
         while True:
             try:
                 demo = pickle.load(f)
                 last_demos.append(demo)
                  last_demos = last_demos[-num_to_keep:]
             except EOFError:
                 break
     new_demo_writer = pickle_dataset.DemoWriter(demos_file)
     for demo in last_demos:
         new_demo_writer.write_episode(demo['observations'],
                                       demo['actions'])
def main(unused_argv):
    config = json_utils.json_file_to_dict(FLAGS.config)
    wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                    FLAGS.target_grammar,
                                                    FLAGS.verbose)
    _ = inference_utils.get_checkpoint(wrapper, FLAGS.model_dir,
                                       FLAGS.checkpoint)
    examples = tsv_utils.read_tsv(FLAGS.input)

    num_predictions_match = 0
    predictions = []
    for idx, example in enumerate(examples):
        if FLAGS.offset and idx < FLAGS.offset:
            continue
        if FLAGS.limit and idx >= FLAGS.limit:
            break

        if FLAGS.verbose:
            print("Processing example %s: (%s, %s)" %
                  (idx, example[0], example[1]))

        source = example[0]
        original_target = example[1]

        predicted_target = inference_parser.get_top_output(source, wrapper)
        if FLAGS.verbose:
            print("predicted_target: %s" % predicted_target)

        if predicted_target == original_target:
            num_predictions_match += 1
        else:
            if FLAGS.verbose:
                print("predictions do not match.")

        predictions.append(predicted_target)

    print("num_predictions_match: %s" % num_predictions_match)

    with gfile.GFile(FLAGS.output, "w") as txt_file:
        for prediction in predictions:
            txt_file.write("%s\n" % prediction)
Example #30
def maybe_restore_params(output_dir, policy_and_value_net_params):
    """Maybe restore the params from the checkpoint dir.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_net_params: Default params, returned if model isn't found.

  Returns:
    triple (restore (bool), params, iter (int)) where iter is the epoch from
    which we restored the params (0 if restore is False).
  """
    model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
    if not model_files:
        return False, policy_and_value_net_params, 0

    model_file = sorted(model_files)[-1]
    model_file_basename = os.path.basename(model_file)  # model-??????.pkl
    i = int("".join(filter(str.isdigit, model_file_basename)))
    with gfile.GFile(model_file, "rb") as f:
        policy_and_value_net_params = pickle.load(f)
    return True, policy_and_value_net_params, i