def load_from_directory(trajectory_dir, epoch=None, n_trajectories=None):
  """Load trajectories from specified dir and epoch."""
  trajectory_file_glob = TRAJECTORY_FILE_GLOB
  # If there is a desired epoch, modify the glob to get that instead.
  if epoch:
    trajectory_file_glob = trajectory_file_glob.replace(
        "epoch_*", "epoch_%06d" % epoch)
  trajectory_files = gfile.glob(
      os.path.join(trajectory_dir, trajectory_file_glob))
  if not trajectory_files:
    return None
  # We read and load all the files, revisit if this becomes a problem.
  trajectories_buffer = []
  completed_trajectories_buffer = []
  for trajectory_file in trajectory_files:
    with gfile.GFile(trajectory_file, "rb") as f:
      list_trajectories = pickle.load(f)
      assert isinstance(list_trajectories, list)
      if not list_trajectories:
        continue
      assert isinstance(list_trajectories[0], Trajectory)
      for trajectory in list_trajectories:
        if trajectory.done:
          completed_trajectories_buffer.append(trajectory)
        else:
          trajectories_buffer.append(trajectory)
  if not trajectories_buffer and not completed_trajectories_buffer:
    return None
  # Randomly sample `n_trajectories` if needed.
  n_trajectories = None if not n_trajectories else int(n_trajectories)
  if n_trajectories and n_trajectories > 0:
    trajectories_buffer = list(
        np.random.choice(trajectories_buffer, n_trajectories))
  # Construct and return a new BatchTrajectory object.
  return BatchTrajectory(
      batch_size=len(trajectories_buffer),
      trajectories=trajectories_buffer,
      completed_trajectories=completed_trajectories_buffer)
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state,
                            policy_and_value_state):
  """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (restored (bool), opt_state, state, epoch (int), opt_step (int))
    where epoch is the epoch from which we restored the optimization state,
    0 if restored is False, and opt_step is the total optimization step (sum
    of all optimization steps made up to the current epoch).
  """
  restored = False
  epoch = 0
  total_opt_step = 0
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  for model_file in reversed(sorted(model_files)):
    logging.info("Trying to restore model from %s", model_file)
    try:
      with gfile.GFile(model_file, "rb") as f:
        (policy_and_value_opt_state, policy_and_value_state,
         total_opt_step) = pickle.load(f)
      model_file_basename = os.path.basename(model_file)  # model-??????.pkl
      restored = True
      # Join the digits before converting; `filter` returns an iterator in
      # Python 3.
      epoch = int("".join(filter(str.isdigit, model_file_basename)))
      break
    except EOFError as e:
      logging.error("Unable to load model from: %s with %s", model_file, e)
      # Try an older version.
      continue
  return (
      restored,
      policy_and_value_opt_state,
      policy_and_value_state,
      epoch,
      total_opt_step,
  )
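# A minimal, self-contained sketch of the epoch-parsing fix used above. In
# Python 2, filter() on a string returned a string, so
# int(filter(str.isdigit, name)) worked; in Python 3 it returns an iterator
# and the digits must be joined first. The filename below is a hypothetical
# example.
def _epoch_from_filename_example():
  name = 'model-000123.pkl'  # Hypothetical checkpoint basename.
  # int(filter(str.isdigit, name)) raises TypeError under Python 3.
  return int(''.join(filter(str.isdigit, name)))  # -> 123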
def save_opt_state(output_dir,
                   policy_and_value_opt_state,
                   policy_and_value_state,
                   epoch,
                   total_opt_step):
  """Saves the policy and value network optimization state etc."""
  pkl_module = utils.get_pickle_module()
  old_model_files = get_policy_model_files(output_dir)
  params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
  with gfile.GFile(params_file, "wb") as f:
    pkl_module.dump(
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step),
        f)
  # Remove the old model files, leave the latest one (it might be in the
  # process of getting read async) -- this will get cleaned up later.
  for path in old_model_files[1:]:
    if path != params_file:
      gfile.remove(path)
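# A hedged usage sketch for the save/restore pair above, assuming the
# surrounding module's get_policy_model_files and utils.get_pickle_module
# produce pickle-compatible files. The directory, epoch, and step values are
# hypothetical placeholders.
def _checkpoint_round_trip_example(output_dir, opt_state, net_state):
  save_opt_state(output_dir, opt_state, net_state, epoch=3,
                 total_opt_step=1500)
  restored, opt_state, net_state, epoch, opt_step = maybe_restore_opt_state(
      output_dir, opt_state, net_state)
  assert restored and epoch == 3 and opt_step == 1500
  return opt_state, net_state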
def main(argv):
  del argv  # Unused.
  smu_proto = dataset_pb2.MultipleMolecules()
  with gfile.GFile(FLAGS.input_file) as f:
    raw_proto = f.read()
  text_format.Parse(raw_proto, smu_proto)
  smu_writer = SmuWriter(FLAGS.annotate)
  contents = ''.join(
      smu_writer.process_stage2_proto(molecule)
      for molecule in smu_proto.molecules)
  if FLAGS.output_file:
    logging.info('Writing smu7 molecules to .dat file %s.', FLAGS.output_file)
    with open(FLAGS.output_file, 'w') as f:
      f.write(contents)
  else:
    print(contents, end='')
def process(
    self, question_answer: QuestionAnswer
) -> Generator[QuestionAnswerEvidence, None, None]:
  for info in question_answer.evidence_info:
    if info.source == 'EntityPages':
      evidence_path = os.path.join(self._wikipedia_dir, info.id)
    elif info.source == 'SearchResult':
      evidence_path = os.path.join(self._web_dir, info.id)
    else:
      raise ValueError(f'Unknown evidence source: {info.source}.')
    with gfile.GFile(evidence_path, 'rb') as f:
      text = f.read().decode('utf-8')
    metrics.Metrics.counter('_', 'documents').inc()
    yield QuestionAnswerEvidence(
        question=question_answer.question,
        evidence=Evidence(info=info, text=text),
        answer=question_answer.answer)
def _write_episodes(self):
  """Write recorded episodes to pickle."""
  episode_idx = self._get_episode_idx()
  if self._eval_mode:
    counts = self._counter.get_counts()
    train_steps = counts.get('steps', 0)
    filename = f'evaluation_{train_steps}.pkl'
  else:
    # Inclusive first and last.
    first_ep = episode_idx - len(self._episodes_to_record) + 1
    last_ep = episode_idx
    filename = f'episodes_{first_ep}-{last_ep}.pkl'
  log_path = os.path.join(self._logdir, filename)
  print('Episode', episode_idx, ': flushing to', log_path)
  with gfile.GFile(log_path, 'wb') as f:
    pickle.dump(self._episodes_to_record, f)
  self._episodes_to_record = []
def set_eval_paths(ckpt_dir, ckpt, custom_eval_id):
  """Set paths for evaluation and TensorBoard summaries."""
  eval_path, summary_key, best_epoch = None, None, None
  if ckpt is not None:
    best_epoch_path = (
        ckpt.replace('ckpt_best_of_', 'best_epoch_of_').replace(
            'ckpt', 'best_epoch'))
    best_epoch_path = os.path.join(ckpt_dir, best_epoch_path)
    if gfile.exists(best_epoch_path):
      with gfile.GFile(best_epoch_path) as f:
        best_epoch = int(f.read())
    else:
      best_epoch = int(ckpt.replace('ckpt_', ''))
    print('best epoch:', best_epoch)
    eval_path = set_eval_path(ckpt_dir, custom_eval_id, ckpt)
    summary_key = set_summary_key(custom_eval_id, ckpt)
  return eval_path, summary_key, best_epoch
def load_trajectories(trajectory_dir, eval_frac):
  """Loads trajectories from a possibly nested directory of pickles."""
  pkl_module = utils.get_pickle_module()
  train_trajectories = []
  eval_trajectories = []
  # Search the entire directory subtree for trajectories.
  for (subdir, _, filenames) in gfile.walk(trajectory_dir):
    for filename in filenames:
      shard_path = os.path.join(subdir, filename)
      with gfile.GFile(shard_path, "rb") as f:
        trajectories = pkl_module.load(f)
        pivot = int(len(trajectories) * (1 - eval_frac))
        train_trajectories.extend(trajectories[:pivot])
        eval_trajectories.extend(trajectories[pivot:])
  assert train_trajectories, "Haven't found any training data."
  assert eval_trajectories, "Haven't found any evaluation data."
  return (train_trajectories, eval_trajectories)
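# A minimal, self-contained sketch of the per-shard split rule above: each
# shard contributes its first (1 - eval_frac) fraction to training and the
# remainder to evaluation. The toy shard below is hypothetical.
def _split_example():
  shard = list(range(10))  # Stands in for one shard of trajectories.
  eval_frac = 0.2
  pivot = int(len(shard) * (1 - eval_frac))
  train, evaluation = shard[:pivot], shard[pivot:]
  assert (len(train), len(evaluation)) == (8, 2)
  return train, evaluation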
def save_data(df, path, out_data, mode, bucket):
  if mode == 'cloud':
    out_csv_gcs = f'{bucket}/{path}/{out_data}'
    logging.info(f'Writing {out_csv_gcs} file...')
    with gfile.GFile(name=out_csv_gcs, mode='w') as file:
      df.to_csv(file, index=False)
    logging.info(f'{out_csv_gcs} successfully written!')
    return out_csv_gcs
  else:
    p = Path(path)
    if not p.exists():
      os.mkdir(path)
    out_csv = f'{path}/{out_data}'
    logging.info(f'Writing {out_csv} file...')
    df.to_csv(out_csv, index=False)
    logging.info(f'{out_csv} successfully written!')
    return out_csv
def log_pytree_shape_and_statistics(pytree, json_path=None):
  """Logs the shape and norm of every array in the pytree."""
  if not pytree:
    absl_logging.info('Empty pytree')
    return
  if json_path:
    shape_dict = jax.tree_map(lambda x: x.shape, pytree).pretty_repr()
    with gfile.GFile(json_path, 'w') as json_file:
      json_file.write(shape_dict)
  absl_logging.info('Printing model param shapes.')
  shape_dict = jax.tree_map(_summary_str, pytree)
  absl_logging.info(shape_dict.pretty_repr())
  total_params = jax.tree_util.tree_reduce(
      operator.add, jax.tree_map(lambda x: x.size, pytree))
  absl_logging.info('Total params: %d', total_params)
def load_dataset_metadata(metadata_filename):
  """Helper function to load dataset metadata.

  Args:
    metadata_filename: Filename containing dataset metadata.

  Returns:
    Padding configuration and edge types for the dataset.
  """
  with gfile.GFile(metadata_filename, "r") as fp:
    metadata = json.load(fp)
  edge_types = metadata["edge_types"]
  padding_config = flax.serialization.from_state_dict(
      target=jax_util.synthesize_dataclass(graph_bundle.PaddingConfig),
      state=metadata["padding_config"])
  return padding_config, edge_types
def pipeline(root):
  """Beam pipeline.

  Args:
    root: the root of the pipeline.
  """
  _ = (
      root
      | 'CreateTopologies' >> beam.Create(
          smu_utils_lib.generate_bond_topologies_from_csv(
              gfile.GFile(FLAGS.input_bond_topology_csv, 'r')))
      | 'Reshuffle1' >> beam.Reshuffle()
      | 'CheckInvariance' >> beam.FlatMap(check_smiles_permutation_invariance)
      | 'Reshuffle2' >> beam.Reshuffle()
      | 'CSVFormat' >> beam.Map(lambda vals: ','.join(str(x) for x in vals))
      | 'WriteOutput' >> beam.io.WriteToText(
          FLAGS.output_csv, header='bt_id,smiles0,smiles1', num_shards=1))
def restore_state(output_dir):
  """Restore State."""
  params_file = os.path.join(output_dir, "model.pkl")
  if not gfile.exists(params_file):
    return State(step=None, opt_state=None,
                 history=trax_history.History(), model_state=None)
  with gfile.GFile(params_file, "rb") as f:
    (opt_state, step, history, model_state) = pickle.load(f)
  log("Model loaded from %s at step %d" % (params_file, step))
  logging.debug("From loaded model : history = %s", history)
  return State(step=step, opt_state=OptState(*opt_state),
               history=history, model_state=model_state)
def generate_val_timestep(self):
  """Generator to iterate over the validation split."""
  demo_idx = 0
  val_pointer = gfile.GFile(self.path, 'rb')
  while True:
    try:
      if not self.in_memory:
        demo = pickle.load(val_pointer)
      if self.demo_in_val(demo_idx):
        # Preprocess only if (at least some time step) in val split.
        if self.in_memory:
          observations = self.observations[demo_idx]
          actions = self.actions[demo_idx]
        else:
          observations = demo['observations']
          actions = demo['actions']
        if self.max_demo_length is not None:
          observations = observations[:self.max_demo_length]
          actions = actions[:self.max_demo_length]
        if not self.in_memory or not self._decompress_once:
          # Decompress images of all viewpoints (if applicable) to keep
          # viewpoint constant across stacked frames.
          observations = self.decompress_all_images(observations)
        signals = [None for _ in observations]
        if FLAGS.clip_actions:
          actions = list(np.clip(actions, -1, 1))
        if self.agent is not None:
          observations, signals, actions = self.agent.normalize_demo(
              observations, actions, augment_frames=self.augment_frames,
              randomize_camera=True)
        for t in range(len(actions)):
          if self.split_by_demo:
            if not self.episode_train_split[demo_idx]:
              yield observations[t], signals[t], actions[t], demo_idx, t
          elif not self.episode_train_split[demo_idx][t]:
            yield observations[t], signals[t], actions[t], demo_idx, t
      demo_idx += 1
      if self.max_to_load is not None and demo_idx >= self.max_to_load:
        return
    except EOFError:
      return
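# The reader above relies on the "many pickles per file" convention: demos
# are appended to one file with repeated pickle.dump calls and read back with
# repeated pickle.load calls until EOFError. A minimal, self-contained
# demonstration of that pattern, with hypothetical toy episodes:
def _stream_pickles_example():
  import pickle
  import tempfile
  episodes = [{'observations': [0, 1], 'actions': [1]},
              {'observations': [2, 3], 'actions': [0]}]
  with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
    for episode in episodes:
      pickle.dump(episode, f)
    path = f.name
  loaded = []
  with open(path, 'rb') as f:
    while True:
      try:
        loaded.append(pickle.load(f))
      except EOFError:
        break
  assert loaded == episodes
  return loaded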
def test_save_restore_checkpoints_w_float_steps(self):
  tmp_dir = self.create_tempdir().full_path
  test_object0 = {
      'a': np.array([0, 0, 0], np.int32),
      'b': np.array([0, 0, 0], np.int32)
  }
  test_object1 = {
      'a': np.array([1, 2, 3], np.int32),
      'b': np.array([1, 1, 1], np.int32)
  }
  test_object2 = {
      'a': np.array([4, 5, 6], np.int32),
      'b': np.array([2, 2, 2], np.int32)
  }
  # Create leftover temporary checkpoint, which should be ignored.
  gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 0.0, prefix='test_', keep=1)
  self.assertIn('test_0.0', os.listdir(tmp_dir))
  new_object = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(new_object, test_object1)
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 2.0, prefix='test_', keep=1)
  with self.assertRaises(errors.InvalidCheckpointError):
    checkpoints.save_checkpoint(
        tmp_dir, test_object2, 1.0, prefix='test_', keep=1)
  checkpoints.save_checkpoint(
      tmp_dir, test_object2, 3.0, prefix='test_', keep=2)
  self.assertIn('test_3.0', os.listdir(tmp_dir))
  self.assertIn('test_2.0', os.listdir(tmp_dir))
  jtu.check_eq(new_object, test_object1)
def dump_tabular(self):
  df = pd.DataFrame(self.unsaved, index=self.unsaved[self.index_key])
  if self.rank == 0:
    if self.print_count % 10 == 0:
      print(df.round(3), flush=True)
    else:
      print(df.round(3).to_string().split("\n")[1], flush=True)
    self.print_count += 1
  if self.args.backend == 'local':
    df.to_csv(self.filename, sep='\t', mode='a', header=self.use_headers)
  else:
    with gfile.GFile(self.filename, 'a') as f:
      df.to_csv(f, sep='\t', mode='a', header=self.use_headers)
  self.unsaved = {}
  self.use_headers = False
def load_trainer_state(output_dir):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  params_file = os.path.join(output_dir, 'model.pkl')
  if not gfile.exists(params_file):
    return TrainerState(step=None, opt_state=None,
                        history=trax_history.History(), model_state=None)
  pkl_module = utils.get_pickle_module()
  with gfile.GFile(params_file, 'rb') as f:
    (opt_state, step, history, model_state) = pkl_module.load(f)
  log('Model loaded from %s at step %d' % (params_file, step))
  logging.debug('From loaded model : history = %s', history)
  return TrainerState(step=step, opt_state=OptState(*opt_state),
                      history=history, model_state=model_state)
def log_pytree_shape_and_statistics(pytree, json_path=None):
  """Logs the shape and norm of every array in the pytree."""
  if not pytree:
    absl_logging.info('Empty pytree')
    return
  if json_path:
    shape_dict = json.dumps(jax.tree_map(lambda x: x.shape, pytree))
    with gfile.GFile(json_path, 'w') as json_file:
      json_file.write(shape_dict)
  absl_logging.info('Printing model param shapes.')
  shape_dict = jax.tree_map(_summary_str, pytree)
  # We use json.dumps for pretty printing nested dicts.
  absl_logging.info(json.dumps(shape_dict, sort_keys=True, indent=4))
  total_params = jax.tree_util.tree_reduce(
      operator.add, jax.tree_map(lambda x: x.size, pytree))
  absl_logging.info('Total params: %d', total_params)
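# A minimal, self-contained sketch of the shape and parameter-count
# bookkeeping the loggers above perform, on a hypothetical one-layer
# parameter pytree.
def _pytree_shapes_example():
  import operator
  import jax
  import numpy as np
  params = {'dense': {'kernel': np.zeros((4, 8)), 'bias': np.zeros((8,))}}
  shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
  total_params = jax.tree_util.tree_reduce(
      operator.add, jax.tree_util.tree_map(lambda x: x.size, params))
  assert total_params == 4 * 8 + 8
  return shapes, total_params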
def add_demos(self, path, max_to_load=None):
  """Add additional demos to training set without affecting dataset stats."""
  # TODO(minttu): not-in-memory option
  self._stats_fixed = True
  if self.max_to_load is not None and FLAGS.stats_from_large_dataset:
    self.observations = self.observations[:self.max_to_load]
    self.actions = self.actions[:self.max_to_load]
  length_before = len(self.observations)
  observations = self.observations
  actions = self.actions
  lengths = self.episode_lengths
  num_new_demos = 0
  print('Before adding demos')
  print(len(self.observations), len(self.actions),
        len(self.episode_train_split))
  with gfile.GFile(path, 'rb') as f:
    while True:
      try:
        demo = pickle.load(f)
        num_new_demos += 1
        obs_demo = demo['observations']
        act_demo = demo['actions']
        if self.max_demo_length is not None:
          obs_demo = obs_demo[:self.max_demo_length]
          act_demo = act_demo[:self.max_demo_length]
        observations.append(obs_demo)
        actions.append(np.stack(act_demo))
        lengths.append(len(obs_demo))
        if self.split_by_demo:
          self.episode_train_split = np.concatenate(
              [self.episode_train_split, np.ones(1)])
        else:
          self.episode_train_split = np.concatenate(
              [self.episode_train_split, np.ones(len(obs_demo))])
        if max_to_load is not None and num_new_demos >= max_to_load:
          break
      except EOFError:
        break
  assert len(self.observations) == length_before + num_new_demos
  print('Added', num_new_demos, 'demos')
  print(len(self.observations), len(self.actions),
        len(self.episode_train_split))
def load_ihdp(self):
  """Loads semi-synthetic IHDP data and updates the DataSimulation object.

  Reference: https://github.com/AMLab-Amsterdam/CEVAE
  """
  self.data_path = self.param_data['data_path'] + 'IHDP/'
  # On each call, randomly pick one of the existing repetitions.
  np.random.seed(self.seed)
  i = np.random.randint(1, 10, 1)[0]
  path = self.data_path + '/ihdp_npci_' + str(i) + '.csv.txt'
  with gfile.GFile(path, 'r') as f:
    data = np.loadtxt(f, delimiter=',')
  self.outcome, y_cf = data[:, 1][:, np.newaxis], data[:, 2][:, np.newaxis]
  self.outcome = self.outcome.ravel()
  self.treatment = data[:, 0].ravel()
  self.covariates = data[:, 5:]
  scaler = StandardScaler()
  self.covariates = scaler.fit_transform(self.covariates)
  self.sample_size, self.num_covariates = self.covariates.shape
  self.linear, self.noise = False, False
  self.var_covariates = None
  self.treatment_prop = self.treatment.sum() / len(self.treatment)
  # Recover both potential outcomes from the factual and counterfactual
  # columns, then compute the average treatment effect.
  y1 = np.array([
      y_cf[j][0] if item == 0 else self.outcome[j]
      for j, item in enumerate(self.treatment)
  ])
  y0 = np.array([
      y_cf[j][0] if item == 1 else self.outcome[j]
      for j, item in enumerate(self.treatment)
  ])
  self.tau = (y1 - y0).mean()
def parse_duplicates_file(filename):
  """Parses a duplicates file into a pandas DataFrame.

  The duplicates file supplied by our collaborators (called
  list.equivalent_{isomers,conformers}.dat) is a two column, space separated
  file of composite names like
    x07_n4o3h4.091404.073
  which we parse into columns
  * nameX: original composite name from file
  * stoichX: string for the stoichiometry
  * btidX: bond topology id
  * shortconfidX: 3 digit conformer id
  * confidX: full conformer id that we use (btid * 1000 + shortconfid)
  (for X = 1 or 2)

  Args:
    filename: file to read (usually list.equivalent_isomers.dat)

  Returns:
    pd.DataFrame
  """
  with gfile.GFile(filename) as f:
    df_dups = pd.read_csv(
        f, delim_whitespace=True, names=['name1', 'name2'], header=None)
  for idx in ['1', '2']:
    df_dups = pd.concat([
        df_dups,
        df_dups['name' + idx].str.extract(
            r'x07_([\w\d]+)\.(\d+)\.(\d+)').rename(
                columns={
                    0: 'stoich' + idx,
                    1: 'btid' + idx,
                    2: 'shortconfid' + idx
                })
    ], axis=1)
    df_dups['btid' + idx] = df_dups['btid' + idx].astype(int)
    df_dups['shortconfid' + idx] = df_dups['shortconfid' + idx].astype(int)
    df_dups['confid' + idx] = (
        df_dups['btid' + idx] * 1000 + df_dups['shortconfid' + idx])
  return df_dups
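# A minimal, self-contained sketch of the composite-name parsing above, on a
# hypothetical one-row frame of x07_* names.
def _composite_name_example():
  import pandas as pd
  df = pd.DataFrame({'name1': ['x07_n4o3h4.091404.073']})
  parts = df['name1'].str.extract(r'x07_([\w\d]+)\.(\d+)\.(\d+)').rename(
      columns={0: 'stoich1', 1: 'btid1', 2: 'shortconfid1'})
  confid = int(parts['btid1'][0]) * 1000 + int(parts['shortconfid1'][0])
  assert confid == 91404073
  return parts, confid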
def generate_examples(file_path: str):
  """Provides a common generate_examples method for D4RL datasets."""
  with gfile.GFile(file_path, 'rb') as f:
    dataset_file = h5py.File(f, 'r')
    dataset_dict = {}
    for k in _get_dataset_keys(dataset_file):
      try:
        # First try loading as an array.
        dataset_dict[k] = dataset_file[k][:]
      except ValueError:
        # Then try loading as a scalar.
        dataset_dict[k] = dataset_file[k][()]
    dataset_file.close()
  if 'timeouts' not in dataset_dict:
    raise ValueError('Only datasets with explicit timeouts are supported.')
  done = [
      terminal or timeout
      for (terminal, timeout) in zip(dataset_dict['terminals'],
                                     dataset_dict['timeouts'])
  ]
  # is_first corresponds to the done flag delayed by one step.
  dataset_dict['is_first'] = [True] + done[:-1]
  # TODO(sabela): Add extra keys for metadata (qpos, qval, goal) that is only
  # present in some datasets.
  dataset_dict = {
      'observation': dataset_dict['observations'],
      'action': dataset_dict['actions'],
      'reward': dataset_dict['rewards'],
      'discount': np.ones_like(dataset_dict['rewards']),
      'is_terminal': dataset_dict['terminals'],
      'is_first': dataset_dict['is_first'],
  }
  num_steps = len(dataset_dict['is_first'])
  prev = 0
  counter = 0
  for pos in range(num_steps):
    if dataset_dict['is_first'][pos] and pos > prev:
      yield counter, _get_episode(dataset_dict, prev, pos)
      prev = pos
      counter += 1
  if prev < num_steps:
    yield counter, _get_episode(dataset_dict, prev, num_steps)
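# A minimal, self-contained sketch of the episode segmentation above: steps
# are cut into [prev, pos) ranges wherever an is_first flag opens a new
# episode, with a final range for the trailing episode. The toy flags are
# hypothetical.
def _episode_ranges_example():
  is_first = [True, False, False, True, False, True]
  ranges, prev = [], 0
  for pos, first in enumerate(is_first):
    if first and pos > prev:
      ranges.append((prev, pos))
      prev = pos
  if prev < len(is_first):
    ranges.append((prev, len(is_first)))
  assert ranges == [(0, 3), (3, 5), (5, 6)]
  return ranges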
def gen_csv_from_annotations(input_dir: str,
                             output_file=constants.DEFAULT_CSV_FILENAME,
                             out_path_prefix='',
                             dataset_type=constants.DEFAULT_DATASET_TYPE):
  """Generates AutoML dataset CSV from annotation files.

  Args:
    input_dir: Directory of annotation files.
    output_file: Output CSV filename.
    out_path_prefix: Filepath prefix to prepend to the image files. e.g.
      src_image_filename = '/tmp/path/to/image.jpg'
      out_path_prefix = 'gs://bucket/images'
      output_image_filename = 'gs://bucket/images/image.jpg'
    dataset_type: Dataset type (TRAIN, VAL, TEST, UNSPECIFIED) to use for all
      the parsed images.
  """
  if not gfile.exists(input_dir):
    raise ValueError('Input directory not found.')
  with gfile.GFile(os.path.expanduser(output_file), 'w') as outf:
    writer = csv.writer(outf, delimiter=',')
    for filename in gfile.listdir(os.path.expanduser(input_dir)):
      filepath = os.path.join(input_dir, filename)
      image_filename, boxes = annotation.read(filepath)
      out_image_filename = os.path.join(out_path_prefix, image_filename)
      for b in boxes:
        row = [
            dataset_type,
            out_image_filename,
            b.label,
            b.xmin,
            b.ymin,
            '',
            '',
            b.xmax,
            b.ymax,
            '',
            '',
        ]
        writer.writerow(row)
def smiles_to_id(bond_topology_filename):
  """DoFn for creating the smiles to id mapping.

  Reads the same merged_bond_topology file as
  bond_topology_summaries_from_csv. We could of course produce both outputs
  at the same time, but this is simpler.

  Args:
    bond_topology_filename: see FLAGS.input_bond_topology_csv

  Yields:
    smiles, bond_topology_id
  """
  with gfile.GFile(bond_topology_filename, 'r') as infile:
    reader = csv.reader(iter(infile))
    next(reader)  # Skip the header line.
    for row in reader:
      bt_id, _, _, _, _, smiles = row
      yield smiles, int(bt_id)
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
  """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int), history) where
    epoch is the epoch from which we restored the optimization state, 0 if no
    checkpoint was found, and opt_step is the total optimization step (sum of
    all optimization steps made up to the current epoch).
  """
  pkl_module = utils.get_pickle_module()
  epoch = 0
  total_opt_step = 0
  history = trax_history.History()
  for model_file in get_policy_model_files(output_dir):
    logging.info('Trying to restore model from %s', model_file)
    try:
      with gfile.GFile(model_file, 'rb') as f:
        (policy_and_value_opt_state, policy_and_value_state, total_opt_step,
         history) = pkl_module.load(f)
      epoch = get_epoch_from_policy_model_file(model_file)
      break
    except EOFError as e:
      logging.error('Unable to load model from: %s with %s', model_file, e)
      # Try an older version.
      continue
  return (
      policy_and_value_opt_state,
      policy_and_value_state,
      epoch,
      total_opt_step,
      history,
  )
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before a
  final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.

  Returns:
    Filename of saved checkpoint.
  """
  # Write temporary checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))
  # Rename once serialization and writing finished.
  gfile.rename(ckpt_tmp_path, ckpt_path)
  logging.info('Saved checkpoint at %s', ckpt_path)
  # Remove old checkpoint files.
  base_path = os.path.join(ckpt_dir, f'{prefix}')
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)
  return ckpt_path
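# A minimal, self-contained sketch of the write-then-rename pattern above,
# using the standard library in place of gfile. os.replace is an atomic
# rename on POSIX filesystems, so readers never observe a partially written
# checkpoint; gfile.rename plays the same role for Cloud paths. The payload
# is a hypothetical placeholder.
def _atomic_write_example(ckpt_path, payload):
  import os
  tmp_path = ckpt_path + '.tmp'
  with open(tmp_path, 'wb') as fp:
    fp.write(payload)
  os.replace(tmp_path, ckpt_path)  # Rename once writing has finished.
  return ckpt_path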
def from_file(cls, filename, right_tail_mass):
  """Load the distribution from a file.

  The file should be a CSV with two (unnamed) columns. The first is a length
  bucket and the second is a count. Each line represents the number of
  observed lengths between the length on that line and the next line. The
  buckets should be equally spaced.

  Args:
    filename: file to read
    right_tail_mass: probability mass added past the right side of the
      buckets (see class documentation)

  Returns:
    EmpiricalLengthDistribution
  """
  with gfile.GFile(filename) as f:
    df = pd.read_csv(f, header=None, names=['length', 'count'], dtype=float)
  return EmpiricalLengthDistribution(df, right_tail_mass)
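# A minimal, self-contained sketch of the expected input format: two unnamed
# columns of equally spaced length buckets and their counts. The values are
# hypothetical.
def _length_csv_example():
  import io
  import pandas as pd
  csv_text = '1.0,10\n2.0,5\n3.0,1\n'
  df = pd.read_csv(io.StringIO(csv_text), header=None,
                   names=['length', 'count'], dtype=float)
  assert list(df['count']) == [10.0, 5.0, 1.0]
  return df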
def keep_latest_trajectories(self, demos_file, num_to_keep):
  # Keep the num_to_keep latest trajectories in the dataset at demos_file.
  print(demos_file)
  all_demos_file = (
      demos_file.replace(f'e{num_to_keep}', '').replace('.pkl', 'all.pkl'))
  print(all_demos_file)
  gfile.Rename(demos_file, all_demos_file)
  last_demos = []
  with gfile.GFile(all_demos_file, 'rb') as f:
    while True:
      try:
        demo = pickle.load(f)
        last_demos.append(demo)
        # Keep only the most recently read num_to_keep demos.
        last_demos = last_demos[-num_to_keep:]
      except EOFError:
        break
  new_demo_writer = pickle_dataset.DemoWriter(demos_file)
  for demo in last_demos:
    new_demo_writer.write_episode(demo['observations'], demo['actions'])
def main(unused_argv):
  config = json_utils.json_file_to_dict(FLAGS.config)
  wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                  FLAGS.target_grammar,
                                                  FLAGS.verbose)
  _ = inference_utils.get_checkpoint(wrapper, FLAGS.model_dir,
                                     FLAGS.checkpoint)
  examples = tsv_utils.read_tsv(FLAGS.input)
  num_predictions_match = 0
  predictions = []
  for idx, example in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    if FLAGS.verbose:
      print("Processing example %s: (%s, %s)" % (idx, example[0], example[1]))
    source = example[0]
    original_target = example[1]
    predicted_target = inference_parser.get_top_output(source, wrapper)
    if FLAGS.verbose:
      print("predicted_target: %s" % predicted_target)
    if predicted_target == original_target:
      num_predictions_match += 1
    elif FLAGS.verbose:
      print("predictions do not match.")
    predictions.append(predicted_target)
  print("num_predictions_match: %s" % num_predictions_match)
  with gfile.GFile(FLAGS.output, "w") as txt_file:
    for prediction in predictions:
      txt_file.write("%s\n" % prediction)
def maybe_restore_params(output_dir, policy_and_value_net_params):
  """Maybe restore the params from the checkpoint dir.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_net_params: Default params, returned if model isn't
      found.

  Returns:
    triple (restore (bool), params, iter (int)) where iter is the epoch from
    which we restored the params, 0 if restore is False.
  """
  model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
  if not model_files:
    return False, policy_and_value_net_params, 0
  model_file = sorted(model_files)[-1]
  model_file_basename = os.path.basename(model_file)  # model-??????.pkl
  # Join the digits before converting; `filter` returns an iterator in
  # Python 3.
  i = int("".join(filter(str.isdigit, model_file_basename)))
  with gfile.GFile(model_file, "rb") as f:
    policy_and_value_net_params = pickle.load(f)
  return True, policy_and_value_net_params, i