def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
  """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict.
    step: int: step number to load or None to load latest.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified
    and no checkpoint files present, returns the passed-in `target` unchanged.
  """
  if step:
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    if not gfile.exists(ckpt_path):
      raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
  else:
    glob_path = os.path.join(ckpt_dir, f'{prefix}*')
    checkpoint_files = natural_sort(gfile.glob(glob_path))
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
    if not checkpoint_files:
      return target
    ckpt_path = checkpoint_files[-1]
  logging.info('Restoring checkpoint from %s', ckpt_path)
  with gfile.GFile(ckpt_path, 'rb') as fp:
    return serialization.from_bytes(target, fp.read())
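# `natural_sort` is referenced above but not defined in this snippet. Below is
# a minimal sketch of what such a helper might look like (an assumption, not
# the actual implementation): names are split into numeric and non-numeric
# chunks so numeric parts compare by value rather than lexicographically,
# matching the docstring examples (ckpt_1e5 sorts above ckpt_1.0).
import re


def natural_sort(file_list):
  """Sorts names so embedded numbers compare by value (sketch)."""
  float_re = r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?'

  def sort_key(name):
    return [float(chunk) if re.fullmatch(float_re, chunk) else chunk
            for chunk in re.split(f'({float_re})', name)]

  return sorted(file_list, key=sort_key)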
def __init__(self, path, dataset, *args, **kwargs):
  self.dataset = dataset
  self.vocab = Vocab(*args, **kwargs)

  if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
    self.vocab.count_file(os.path.join(path, "train.txt"))
    self.vocab.count_file(os.path.join(path, "valid.txt"))
    self.vocab.count_file(os.path.join(path, "test.txt"))
  elif self.dataset == "wt103":
    self.vocab.count_file(os.path.join(path, "train.txt"))
  elif self.dataset == "lm1b":
    train_path_pattern = os.path.join(
        path, "1-billion-word-language-modeling-benchmark-r13output",
        "training-monolingual.tokenized.shuffled", "news.en-*")
    train_paths = glob(train_path_pattern)
    # the vocab will load from file when build_vocab() is called
    # for train_path in sorted(train_paths):
    #   self.vocab.count_file(train_path, verbose=True)

  self.vocab.build_vocab()

  if self.dataset in ["ptb", "wt2", "wt103"]:
    self.train = self.vocab.encode_file(
        os.path.join(path, "train.txt"), ordered=True)
    self.valid = self.vocab.encode_file(
        os.path.join(path, "valid.txt"), ordered=True)
    self.test = self.vocab.encode_file(
        os.path.join(path, "test.txt"), ordered=True)
  elif self.dataset in ["enwik8", "text8"]:
    self.train = self.vocab.encode_file(
        os.path.join(path, "train.txt"), ordered=True, add_eos=False)
    self.valid = self.vocab.encode_file(
        os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
    self.test = self.vocab.encode_file(
        os.path.join(path, "test.txt"), ordered=True, add_eos=False)
  elif self.dataset == "lm1b":
    self.train = train_paths
    valid_path = os.path.join(path, "valid.txt")
    test_path = valid_path
    self.valid = self.vocab.encode_file(
        valid_path, ordered=True, add_double_eos=True)
    self.test = self.vocab.encode_file(
        test_path, ordered=True, add_double_eos=True)

  if self.dataset == "wt103":
    self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]
  elif self.dataset == "lm1b":
    self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)]
  else:
    self.cutoffs = []
def maybe_restore_params(output_dir, policy_and_value_net_params, state):
  """Maybe restore the params from the checkpoint dir.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_net_params: Default params, returned if no model is found.
    state: policy state.

  Returns:
    tuple (restore (bool), params, state, iter (int), opt_step (int)) where
    iter is the epoch from which we restored the params, 0 if restore = False,
    and opt_step is the total optimization step (sum of all optimization steps
    made up to the current epoch).
  """
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  for model_file in reversed(sorted(model_files)):
    logging.info("Trying to restore model from %s", model_file)
    try:
      with gfile.GFile(model_file, "rb") as f:
        loaded_policy_and_value_net_params, loaded_state, total_opt_step = (
            pickle.load(f))
      policy_and_value_net_params = loaded_policy_and_value_net_params
      state = loaded_state
      model_file_basename = os.path.basename(model_file)  # model-??????.pkl
      # `filter` returns an iterator on Python 3; join the digits before int().
      i = int("".join(filter(str.isdigit, model_file_basename)))
      return True, policy_and_value_net_params, state, i, total_opt_step
    except EOFError as e:
      logging.error("Unable to load model from: %s with %s", model_file, e)
      # Try an older version.
      continue
  return False, policy_and_value_net_params, state, 0, 0
def load_weights(self, base_fn):
  """Find the latest checkpoint matching base_fn, and load the weights."""
  matcher = base_fn + "_*.npy"
  filenames = sorted(gfile.glob(matcher), reverse=True)
  assert len(filenames) > 0, "No files matching {}".format(matcher)
  filename = filenames[0]

  # load array
  with gfile.GFile(filename, "rb") as fin:
    serialized_weights = np.load(fin)
  print(serialized_weights.shape, self.all_weights_flat_sizes)
  all_weights_flat_split = tf.split(serialized_weights,
                                    self.all_weights_flat_sizes)
  all_weights_flat = [
      tf.reshape(t, s)
      for t, s in zip(all_weights_flat_split, self.all_weights_flat_shapes)
  ]
  all_weights = tf.nest.pack_sequence_as(self.all_weights_structure,
                                         all_weights_flat)

  all_layers = self.layers + [self.loss_layer]
  if self.shared_params is not None:
    all_layers += list(self.shared_params.values())
  for l, lw in zip(all_layers, all_weights):
    l.load_weights(lw)
def maybe_restore_params(output_dir, policy_and_value_net_params):
  """Maybe restore the params from the checkpoint dir.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_net_params: Default params, returned if the model isn't
      found.

  Returns:
    triple (restore (bool), params, iter (int)) where iter is the epoch from
    which we restored the params, 0 if restore = False.
  """
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  for model_file in reversed(sorted(model_files)):
    logging.info("Trying to restore model from %s", model_file)
    try:
      with gfile.GFile(model_file, "rb") as f:
        loaded_policy_and_value_net_params = pickle.load(f)
      policy_and_value_net_params = loaded_policy_and_value_net_params
      model_file_basename = os.path.basename(model_file)  # model-??????.pkl
      # `filter` returns an iterator on Python 3; join the digits before int().
      i = int("".join(filter(str.isdigit, model_file_basename)))
      return True, policy_and_value_net_params, i
    except EOFError as e:
      logging.error("Unable to load model from: %s with %s", model_file, e)
      # Try an older version.
      continue
  return False, policy_and_value_net_params, 0
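# Illustration (not part of the original module): extracting the epoch index
# from a checkpoint basename, as done above. The file name is hypothetical.
model_file_basename = "model-000023.pkl"
epoch = int("".join(filter(str.isdigit, model_file_basename)))
assert epoch == 23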
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.get_absl_handler().use_absl_log_file()

  logging.info('Opening %s', FLAGS.output_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.output_sqlite, 'c')

  if FLAGS.bond_topology_csv:
    logging.info('Starting smiles to btid inserts')
    smiles_id_dict = smu_utils_lib.smiles_id_dict_from_csv(
        open(FLAGS.bond_topology_csv))
    db.bulk_insert_smiles(smiles_id_dict.items())
    logging.info('Finished smiles to btid inserts')
  else:
    logging.info('Skipping smiles inserts')

  logging.info('Starting main inserts')
  dataset = tf.data.TFRecordDataset(gfile.glob(FLAGS.input_tfrecord))
  db.bulk_insert((raw.numpy() for raw in dataset), batch_size=10000)

  logging.info('Starting vacuuming')
  db.vacuum()
  logging.info('Vacuuming finished')
def iterate_checkpoints(checkpoint_dir, min_global_step, max_global_step):
  """Yields checkpoints with step in [min_global_step, max_global_step)."""
  for checkpoint_path in gfile.glob(os.path.join(checkpoint_dir, 'ckpt_*')):
    step = int(checkpoint_path.split('_')[-1])
    if min_global_step is None or (min_global_step <= step < max_global_step):
      full_path = os.path.join(checkpoint_dir, checkpoint_path)
      yield full_path, step
def merge_shared_tsvs(filename):
  """Merge multiple tsv files into one."""
  output_files = gfile.glob("%s-*-of-*" % filename)
  all_examples = []
  for output_file in output_files:
    examples = read_tsv(output_file)
    all_examples.extend(examples)
  write_tsv(all_examples, filename)
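# `read_tsv` and `write_tsv` are referenced above but not defined in this
# snippet. A minimal sketch of what such helpers might look like (an
# assumption, not the project's actual implementation), treating each line as
# one tab-separated example:
def read_tsv(filename):
  with gfile.GFile(filename, "r") as f:
    return [tuple(line.rstrip("\n").split("\t")) for line in f if line.strip()]


def write_tsv(examples, filename):
  with gfile.GFile(filename, "w") as f:
    for example in examples:
      f.write("\t".join(str(field) for field in example) + "\n")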
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.info('Opening %s', FLAGS.output_sqlite)
  db = smu_sqlite.SMUSQLite(FLAGS.output_sqlite, 'c')
  dataset = tf.data.TFRecordDataset(gfile.glob(FLAGS.input_tfrecord))
  db.bulk_insert((raw.numpy() for raw in dataset), batch_size=10000)
def get_checkpoints_in_range(checkpoint_dir, lower_bound, upper_bound):
  """Get checkpoint paths in step range [lower_bound, upper_bound]."""
  checkpoint_paths = []
  for checkpoint_path in gfile.glob(os.path.join(checkpoint_dir, 'ckpt_*')):
    ckpt_step = int(checkpoint_path.split('_')[-1])
    if lower_bound <= ckpt_step <= upper_bound:
      checkpoint_paths.append(os.path.join(checkpoint_dir, checkpoint_path))
  return checkpoint_paths
def test_delete_old_checkpoints(self):
  """Test that old checkpoints are deleted."""
  state1 = dict(params=self.params, global_step=5, completed_epochs=4)
  checkpoint.save_checkpoint(self.test_dir, 0, state1, max_to_keep=1)
  state2 = dict(params=self.params, global_step=10, completed_epochs=8)
  checkpoint.save_checkpoint(self.test_dir, 1, state2, max_to_keep=1)
  dir_contents = gfile.glob(os.path.join(self.test_dir, '*'))
  self.assertLen(dir_contents, 1)
def write_backgrounds_csv(new_dataset_dir, backgrounds_dir):
  """Write CSV mapping background ints to bg filenames."""
  bg_filenames = gfile.glob(path.join(backgrounds_dir, '*'))
  bg_filenames = [fname.split('/')[-1] for fname in bg_filenames]
  csv_filepath = path.join(new_dataset_dir, 'backgrounds.csv')
  with open(csv_filepath, 'w') as f:
    writer = csv.writer(f)
    writer.writerow(['int', 'label'])
    for i, fname in enumerate(bg_filenames):
      writer.writerow([i, fname])
def merge_predictions(examples, filename):
  """Merge multiple prediction files into one."""
  source_to_prediction = {}
  output_files = gfile.glob("%s-*-of-*" % filename)
  for output_file in output_files:
    predictions = tsv_utils.read_tsv(output_file)
    for prediction in predictions:
      source, predicted_target = prediction
      source_to_prediction[source] = predicted_target
  new_predictions = []
  for example in examples:
    new_predictions.append(source_to_prediction[example[0]])
  txt_utils.write_txt(new_predictions, filename)
def save(self):
  """Save the agent parameters."""
  logging.vlog(1, "Epoch [% 6d] saving model.", self._epoch)
  old_model_files = gfile.glob(
      os.path.join(self._output_dir, "model-??????.pkl"))
  params_file = os.path.join(self._output_dir, "model-%06d.pkl" % self._epoch)
  with gfile.GFile(params_file, "wb") as f:
    pickle.dump(self._policy_and_value_net_params, f)
  # Remove the old model files.
  for path in old_model_files:
    gfile.remove(path)
  # Reset this number.
  self._n_trajectories_done = 0
  self._last_saved_at = self._epoch
def _load_foregrounds(self, foregrounds_dir):
  """Loads foregrounds from a directory.

  Args:
    foregrounds_dir: path to directory containing foregrounds. Directory of
      the form `foregrounds_dir`/$OBJECT_CLASS/$FILE_NAME.

  Produces:
    self.fg_classes: a list of names of foreground object classes, e.g.
      ['ambulance', 'bagel', ...]
    self.num_fgs_per_class: a dict of the form
      {foreground_obj_class_name: num_fgs_in_that_class}
    self.fgs: a list of the form [fg0, fg1, ...] where the foregrounds are
      `PIL.PngImagePlugin.PngImageFile`s.
    self.fgs_dict: a dict of the form {fg_class_name: [img0, img1, ...]} where
      the images are `PIL.PngImagePlugin.PngImageFile`s.
  """
  if not gfile.exists(foregrounds_dir):
    raise ValueError(
        f'Foregrounds directory {foregrounds_dir} does not exist.')

  fg_fnames = gfile.glob(path.join(foregrounds_dir, '*/*'))
  fg_labels = [x.split('/')[-2] for x in fg_fnames]  # e.g. 'car', 'cow'
  self.fg_classes = sorted(list(set(fg_labels)))
  self.num_fgs_per_class = {
      fg_class: len(gfile.glob(path.join(foregrounds_dir, fg_class, '*')))
      for fg_class in self.fg_classes
  }
  self.num_fgs_per_class_list = [
      self.num_fgs_per_class[fg_class] for fg_class in self.fg_classes
  ]
  self.fgs = self._thread_pool.map(load_image, fg_fnames)
  self.fgs_dict = {fg_class: [] for fg_class in self.fg_classes}
  for i, label in enumerate(fg_labels):
    self.fgs_dict[label].append(self.fgs[i])
  print('Foregrounds loaded.')
def iterate_checkpoints(self):
  """Iterates over all checkpoints."""
  if self.ckpt_to_evaluate:
    step = int(self.ckpt_to_evaluate.split('_')[-1])
    full_path = os.path.join(self.checkpoint_dir, self.ckpt_to_evaluate)
    yield full_path, step
  else:
    for checkpoint_path in gfile.glob(
        os.path.join(self.checkpoint_dir, 'ckpt_*')):
      step = int(checkpoint_path.split('_')[-1])
      if self.min_step and step < int(self.min_step):
        continue
      full_path = os.path.join(self.checkpoint_dir, checkpoint_path)
      yield full_path, step
def read_dir():
  if not os.path.isdir(SOUNDS_DIR):
    raise Exception('Sound directory with name \'' + SOUNDS_DIR +
                    '\' not found!')
  data = []
  for word in WANTED_WORDS:
    word_dir = SOUNDS_DIR + word
    if not os.path.isdir(word_dir):
      raise Exception('Sounds directory for \'' + word + '\' not found at ' +
                      word_dir + '!')
    search_path = os.path.join(word_dir, '*.wav')
    for wav_path in gfile.glob(search_path):
      data.append({'word': word, 'file': wav_path})
  return data
def __init__(self, path, add_eos, add_beos, *args, **kwargs):
  self.vocab = Vocab(*args, **kwargs)
  self.add_eos = add_eos
  self.add_beos = add_beos
  self.vocab.build_vocab()

  pattern = os.path.join(path, "train", "train.??")
  self.train = glob(pattern)
  self.valid = self.vocab.encode_file(
      os.path.join(path, "valid.txt"), ordered=True, add_eos=self.add_eos,
      add_beos=self.add_beos)
  self.test = self.vocab.encode_file(
      os.path.join(path, "test.txt"), ordered=True, add_eos=self.add_eos,
      add_beos=self.add_beos)
def load_from_directory(trajectory_dir, epoch=None, n_trajectories=None):
  """Load trajectories from specified dir and epoch."""
  trajectory_file_glob = TRAJECTORY_FILE_GLOB

  # If there is a desired epoch, modify the glob to get that instead.
  if epoch:
    trajectory_file_glob = trajectory_file_glob.replace(
        "epoch_*", "epoch_%06d" % epoch)

  trajectory_files = gfile.glob(
      os.path.join(trajectory_dir, trajectory_file_glob))
  if not trajectory_files:
    return None

  # We read and load all the files, revisit if this becomes a problem.
  trajectories_buffer = []
  completed_trajectories_buffer = []
  for trajectory_file in trajectory_files:
    with gfile.GFile(trajectory_file, "rb") as f:
      list_trajectories = pickle.load(f)
      assert isinstance(list_trajectories, list)
      if not list_trajectories:
        continue
      assert isinstance(list_trajectories[0], Trajectory)
      for trajectory in list_trajectories:
        if trajectory.done:
          completed_trajectories_buffer.append(trajectory)
        else:
          trajectories_buffer.append(trajectory)

  if not trajectories_buffer and not completed_trajectories_buffer:
    return None

  # Randomly sample `n_trajectories` if needed.
  n_trajectories = None if not n_trajectories else int(n_trajectories)
  if n_trajectories and n_trajectories > 0:
    trajectories_buffer = list(
        np.random.choice(trajectories_buffer, n_trajectories))

  # Construct and return a new BatchTrajectory object.
  return BatchTrajectory(
      batch_size=len(trajectories_buffer),
      trajectories=trajectories_buffer,
      completed_trajectories=completed_trajectories_buffer)
def _get_tfrecords(self, name):
  paths = self.params.data_dir.split(':')
  data_dir = None
  for path in paths:
    if gfile.exists(join(path, name)):
      data_dir = path
      break
  assert data_dir is not None, "data_dir not found"
  paths = list(
      map(lambda x: join(data_dir, name, x),
          self.params.data_pattern.split(',')))
  files = gfile.glob(paths)
  if not files:
    raise IOError("Unable to find files. data_pattern='{}'.".format(
        self.params.data_pattern))
  logging.info("Number of TFRecord files: {}.".format(len(files)))
  return files
def maybe_restore_opt_state(output_dir, policy_and_value_opt_state,
                            policy_and_value_state):
  """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if the
      model isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (restored (bool), opt_state, state, epoch (int), opt_step (int))
    where epoch is the epoch from which we restored the optimization state,
    0 if restored = False, and opt_step is the total optimization step (sum of
    all optimization steps made up to the current epoch).
  """
  restored = False
  epoch = 0
  total_opt_step = 0
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  for model_file in reversed(sorted(model_files)):
    logging.info("Trying to restore model from %s", model_file)
    try:
      with gfile.GFile(model_file, "rb") as f:
        policy_and_value_opt_state, policy_and_value_state, total_opt_step = (
            pickle.load(f))
      model_file_basename = os.path.basename(model_file)  # model-??????.pkl
      restored = True
      # `filter` returns an iterator on Python 3; join the digits before int().
      epoch = int("".join(filter(str.isdigit, model_file_basename)))
      break
    except EOFError as e:
      logging.error("Unable to load model from: %s with %s", model_file, e)
      # Try an older version.
      continue
  return (
      restored,
      policy_and_value_opt_state,
      policy_and_value_state,
      epoch,
      total_opt_step,
  )
def latest_checkpoint(ckpt_dir, prefix='checkpoint_'):
  """Retrieve the path of the latest checkpoint in a directory.

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    prefix: str: name prefix of checkpoint files.

  Returns:
    The latest checkpoint path or None if no checkpoints were found.
  """
  glob_path = os.path.join(ckpt_dir, f'{prefix}*')
  checkpoint_files = natural_sort(gfile.glob(glob_path))
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
  if checkpoint_files:
    return checkpoint_files[-1]
  else:
    return None
def to_dataframe(self, split: str, max_rows: int = 100) -> pd.DataFrame:
  """Returns dataframe representation of the artifact.

  Args:
    split: The name of the datasplit to be returned.
    max_rows: The maximum number of rows to be returned. If None, all rows in
      the split will be returned.
  """
  self._validate_payload()
  if max_rows and max_rows < 0:
    raise ValueError('`max_rows` must not be negative. Got: %d' % max_rows)
  filepaths = gfile.glob(os.path.join(self.uri, split, '*'))
  ds = tf.data.TFRecordDataset(
      filepaths, compression_type='GZIP').take(max_rows)
  return self._load_table(ds)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  atomic_writer = smu_writer_lib.AtomicInputWriter()

  file_count = 0
  conformer_count = 0
  mismatches = 0

  for filepath in gfile.glob(FLAGS.input_glob):
    logging.info('Processing file %s', filepath)
    file_count += 1
    smu_parser = smu_parser_lib.SmuParser(filepath)
    for conformer, _ in smu_parser.process_stage2():
      conformer_count += 1

      actual_contents = atomic_writer.process(conformer)

      expected_fn = atomic_writer.get_filename_for_atomic_input(conformer)
      with gfile.GFile(os.path.join(FLAGS.atomic_input_dir,
                                    expected_fn)) as expected_f:
        expected_contents = expected_f.readlines()

      try:
        smu_writer_lib.check_dat_formats_match(expected_contents,
                                               actual_contents.splitlines())
      except smu_writer_lib.DatFormatMismatchError as e:
        mismatches += 1
        print(e)
        if FLAGS.output_dir:
          with gfile.GFile(
              os.path.join(
                  FLAGS.output_dir,
                  atomic_writer.get_filename_for_atomic_input(conformer)),
              'w') as f:
            f.write(actual_contents)

  status_str = ('COMPLETE: Read %d files, %d conformers, %d mismatches\n' %
                (file_count, conformer_count, mismatches))
  logging.info(status_str)
  print(status_str)
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))

  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files.
  base_path = os.path.join(ckpt_dir, f'{prefix}')
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
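# Hedged usage sketch tying together save_checkpoint, latest_checkpoint and
# restore_checkpoint from this file. The directory and the `target` dict are
# hypothetical; any object the `serialization` module can round-trip (e.g. a
# nested dict of numbers or arrays) would work.
ckpt_dir = '/tmp/ckpt_demo'
target = {'step': 0, 'params': {'w': [1.0, 2.0]}}
saved_path = save_checkpoint(ckpt_dir, target, step=1, keep=2)
assert latest_checkpoint(ckpt_dir) == saved_path
restored = restore_checkpoint(ckpt_dir, target)  # loads the newest checkpoint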
def tensorboard_event_to_dataframe(path: str) -> pd.DataFrame:
  """Helper to get events written by tests.

  Args:
    path: Path where the tensorboard records were saved.

  Returns:
    The metric saved by tensorboard, as a dataframe.
  """
  records = []
  all_tb_path = gfile.glob(os.path.join(path, 'events.*.v2'))
  for tb_event_path in all_tb_path:
    for e in tf.compat.v1.train.summary_iterator(tb_event_path):
      if e.step:
        for v in e.summary.value:
          records.append(
              dict(
                  step=e.step,
                  metric=v.tag,
                  value=float(tf.make_ndarray(v.tensor))))
  df = pd.DataFrame.from_records(records)
  return df
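# Hedged usage sketch (the log directory is hypothetical): write a few scalars
# with the TF2 summary API, then read them back as a dataframe. The glob above
# matches TF2 event files, whose names typically end in '.v2'.
logdir = '/tmp/tb_demo'
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
  for step in range(1, 4):
    tf.summary.scalar('loss', 1.0 / step, step=step)
writer.flush()

df = tensorboard_event_to_dataframe(logdir)
print(df)  # columns: step, metric, value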
def _load_backgrounds(self, backgrounds_dir):
  """Loads backgrounds from a directory.

  Args:
    backgrounds_dir: path to directory containing backgrounds. Dir of the form
      `backgrounds_dir`/$BACKGROUND_TYPE/$FILE_NAME.

  Produces:
    self.bgs: a list of the form [bg0, bg1, ...] where the backgrounds are
      `PIL.Image.Image`s.
    self.num_bgs: int, number of backgrounds.
  """
  if not gfile.exists(backgrounds_dir):
    raise ValueError(
        f'Backgrounds directory {backgrounds_dir} does not exist.')

  bg_fnames = gfile.glob(path.join(backgrounds_dir, '*'))
  self.bgs = self._thread_pool.map(load_image, bg_fnames)
  self.bgs = self._thread_pool.map(self._preprocess_background, self.bgs)
  self.num_bgs = len(self.bgs)
  print('Backgrounds loaded.')
def maybe_restore_params(output_dir, policy_and_value_net_params):
  """Maybe restore the params from the checkpoint dir.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_net_params: Default params, returned if the model isn't
      found.

  Returns:
    triple (restore (bool), params, iter (int)) where iter is the epoch from
    which we restored the params, 0 if restore = False.
  """
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  if not model_files:
    return False, policy_and_value_net_params, 0

  model_file = sorted(model_files)[-1]
  model_file_basename = os.path.basename(model_file)  # model-??????.pkl
  # `filter` returns an iterator on Python 3; join the digits before int().
  i = int("".join(filter(str.isdigit, model_file_basename)))
  with gfile.GFile(model_file, "rb") as f:
    policy_and_value_net_params = pickle.load(f)
  return True, policy_and_value_net_params, i
def test_whole_pipeline(self):
  test_subdirectory = self.create_tempdir()
  output_stem = os.path.join(test_subdirectory, 'testout')
  input_stage1_dat_glob = os.path.join(TESTDATA_PATH,
                                       'pipeline_input_stage1.dat')
  input_stage2_dat_glob = os.path.join(TESTDATA_PATH,
                                       'pipeline_input_stage2.dat')
  input_equivalent_glob = os.path.join(TESTDATA_PATH,
                                       'pipeline_equivalent.dat')
  input_bond_topology_csv = os.path.join(TESTDATA_PATH,
                                         'pipeline_bond_topology.csv')
  with flagsaver.flagsaver(
      input_stage1_dat_glob=input_stage1_dat_glob,
      input_stage2_dat_glob=input_stage2_dat_glob,
      input_equivalent_glob=input_equivalent_glob,
      input_bond_topology_csv=input_bond_topology_csv,
      output_stem=output_stem,
      output_shards=1):
    # If you have custom beam options, add them here.
    beam_options = None
    with beam.Pipeline(beam_options) as root:
      pipeline.pipeline(root)

  metrics = root.result.metrics().query()
  counters_dict = {
      m.key.metric.name: m.committed for m in metrics['counters']
  }

  self.assertEqual(counters_dict['attempted_topology_matches'], 3)
  # Conformer 620517 will not match because bond lengths are not extracted
  # from conformers with serious errors like this.
  self.assertEqual(counters_dict['no_topology_matches'], 1)
  self.assertNotIn('topology_match_smiles_failure', counters_dict)

  logging.info('Files in output: %s',
               '\n'.join(gfile.glob(os.path.join(test_subdirectory, '*'))))
  for stage in ['stage1', 'stage2']:
    self.assertTrue(
        gfile.exists(output_stem + '_' + stage +
                     '_original_known_error-00000-of-00001.dat'))
    self.assertTrue(
        gfile.exists(output_stem + '_' + stage +
                     '_original_unknown_error-00000-of-00001.dat'))
    self.assertTrue(
        gfile.exists(output_stem + '_' + stage +
                     '_mismatched_original-00000-of-00001.dat'))
    self.assertTrue(
        gfile.exists(output_stem + '_' + stage +
                     '_mismatched_regen-00000-of-00001.dat'))

  # Check the merge conflicts file
  with gfile.GFile(output_stem + '_conflicts-00000-of-00001.csv') as f:
    conflicts_lines = f.readlines()
    self.assertIn('conformer_id,', conflicts_lines[0])
    self.assertEqual(
        conflicts_lines[1], '618451001,1,1,1,1,'
        '-406.51179,9.999999,-406.522079,9.999999,True,True,'
        '-406.51179,0.052254,-406.522079,2.5e-05,True,True\n')

  # Check a couple of the stats.
  with gfile.GFile(output_stem + '_stats-00000-of-00001.csv') as f:
    stats_lines = f.readlines()
    self.assertIn('errors.status,0,2\n', stats_lines)
    self.assertIn('errors.warn_t1,0,4\n', stats_lines)
    self.assertIn('fate,FATE_SUCCESS,2\n', stats_lines)
    self.assertIn('fate,FATE_DUPLICATE_DIFFERENT_TOPOLOGY,1\n', stats_lines)
    self.assertIn('num_initial_geometries,1,4\n', stats_lines)
    self.assertIn('num_duplicates,1,1\n', stats_lines)
    self.assertIn('zero_field,single_point_energy_pbe0d3_6_311gd,1\n',
                  stats_lines)

  # Check the smiles comparison output
  with gfile.GFile(output_stem + '_smiles_compare-00000-of-00001.csv') as f:
    smiles_lines = f.readlines()
    self.assertIn(
        '620517002,MISMATCH,NotAValidSmilesString,'
        '[H]C1=C2OC2=C(F)O1,FC1=C2OC2=CO1\n', smiles_lines)
    # Make sure that a bond topology with a matching smiles doesn't show
    for line in smiles_lines:
      self.assertNotIn('618451001', line)

  # Check the bond topology summary
  with gfile.GFile(output_stem + '_bt_summary-00000-of-00001.csv') as f:
    bt_summary_lines = f.readlines()
    # Check part of the header line
    self.assertIn('bt_id', bt_summary_lines[0])
    self.assertIn('count_attempted_conformers', bt_summary_lines[0])
    # This is the bond topology that has no conformer
    self.assertIn('10,0,0,0,0,0,0,0,0,0,0,0,0,0\n', bt_summary_lines)
    # This is a bond topology with 1 conformer
    self.assertIn('620517,1,0,0,0,1,0,1,0,0,0,0,0,0\n', bt_summary_lines)
    # This is a bond topology with 2 conformers
    self.assertIn('618451,2,0,0,0,2,0,0,0,2,0,0,0,0\n', bt_summary_lines)

  # Check the bond lengths file
  with gfile.GFile(output_stem + '_bond_lengths.csv') as f:
    bond_length_lines = f.readlines()
    self.assertEqual('atom_char_0,atom_char_1,bond_type,length_str,count\n',
                     bond_length_lines[0])
    self.assertIn('c,c,2,1.336,1\n', bond_length_lines)
    self.assertIn('c,o,1,1.422,2\n', bond_length_lines)

  # For the gzip files below, we check >100 because even an empty gzip file
  # has non-zero length. 100 is kind of arbitrary to be bigger than the
  # expected header of 20.

  # Check that the generated TFRecord files contain some expected outputs
  standard_dataset = tf.data.TFRecordDataset(
      output_stem + '_standard_tfrecord-00000-of-00001')
  standard_output = [
      dataset_pb2.Conformer.FromString(raw)
      for raw in standard_dataset.as_numpy_iterator()
  ]
  self.assertCountEqual([c.conformer_id for c in standard_output],
                        [618451001, 618451123])
  # Check that fields are filtered the way we expect
  self.assertFalse(
      standard_output[0].properties.HasField('compute_cluster_info'))
  self.assertFalse(
      standard_output[0].properties.HasField('homo_pbe0_aug_pc_1'))
  self.assertTrue(
      standard_output[0].properties.HasField('rotational_constants'))

  complete_dataset = tf.data.TFRecordDataset(
      output_stem + '_complete_tfrecord-00000-of-00001')
  complete_output = [
      dataset_pb2.Conformer.FromString(raw)
      for raw in complete_dataset.as_numpy_iterator()
  ]
  self.assertCountEqual([c.conformer_id for c in complete_output],
                        [618451001, 618451123, 620517002, 79593005])
  # Check that fields are filtered the way we expect
  # The DirectRunner randomizes the order of output so we need to make sure
  # that we get a full record.
  complete_entry = [
      c for c in complete_output if c.conformer_id == 618451001
  ][0]
  self.assertFalse(complete_entry.properties.HasField('compute_cluster_info'))
  self.assertTrue(complete_entry.properties.HasField('homo_pbe0_aug_pc_1'))
  self.assertTrue(complete_entry.properties.HasField('rotational_constants'))

  complete_entry_for_smiles = [
      c for c in complete_output if c.conformer_id == 620517002
  ][0]
  self.assertEqual(complete_entry_for_smiles.properties.smiles_openbabel,
                   'NotAValidSmilesString')
def get_policy_model_files(output_dir):
  return list(
      reversed(
          sorted(gfile.glob(os.path.join(output_dir, "model-??????.pkl")))))
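# Hedged usage note (the directory is hypothetical): the returned list is
# sorted newest first, so the most recent policy checkpoint, if any, is the
# first element.
model_files = get_policy_model_files("/tmp/ppo_output")
newest_model_file = model_files[0] if model_files else None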