def convert_file(fname):
    """Serialize every JSON record in `fname` into a sibling .tfrecord file."""
    out_path = os.path.splitext(fname)[0] + ".tfrecord"
    log.info("Processing file: %s", fname)
    records = utils.load_json_data(fname)
    log.info(" saving to %s", out_path)
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for record in tqdm(records):
            # Convert and serialize in one pass; the writer is closed by the
            # context manager.
            writer.write(convert_to_example(record).SerializeToString())
def count_labels(fname: str) -> collections.Counter:
    """Count labels across all targets in a file of edge probing examples.

    Args:
        fname: path to a JSON-lines file of edge probing records; each record
            has a "targets" list whose entries carry a "label".

    Returns:
        Counter mapping each label string to its total occurrence count.
        (The original annotation `Type[collections.Counter]` was wrong: this
        returns a Counter instance, not the Counter class.)
    """
    label_ctr = collections.Counter()
    record_iter = utils.load_json_data(fname)
    for record in tqdm(record_iter):
        for target in record["targets"]:
            label = target["label"]
            # A target's label may be a bare string or a list of strings;
            # normalize to a list so Counter.update counts label tokens,
            # not characters.
            if isinstance(label, str):
                label = [label]
            label_ctr.update(label)
    return label_ctr
def from_run(cls, run_dir: str, task_name: str, split_name: str):
    """Build an instance from a run directory's vocabulary and predictions.

    Args:
        run_dir: run directory; its parent (the experiment dir) holds "vocab".
        task_name: task name, used for the label namespace and the file name.
        split_name: data split name (e.g. "val"), used for the file name.

    Returns:
        cls instance wrapping the vocabulary and a stream of predictions.
    """
    # Load vocabulary. The vocab lives in the experiment dir, one level above
    # the run dir; rstrip("/") so dirname() sees the run dir, not itself.
    exp_dir = os.path.dirname(run_dir.rstrip("/"))
    vocab_path = os.path.join(exp_dir, "vocab")
    # Lazy %-args (not eager "%" interpolation) for consistency with the
    # logging style used elsewhere in this file.
    log.info("Loading vocabulary from %s", vocab_path)
    vocab = Vocabulary.from_files(vocab_path)
    label_namespace = f"{task_name}_labels"

    # Load predictions.
    preds_file = os.path.join(run_dir, f"{task_name}_{split_name}.json")
    log.info("Loading predictions from %s", preds_file)
    return cls(vocab, utils.load_json_data(preds_file), label_namespace=label_namespace)
def _stream_records(cls, filename):
    """Yield records from `filename`, dropping any with no targets."""
    num_seen = 0
    num_skipped = 0
    for record in utils.load_json_data(filename):
        num_seen += 1
        # Skip records with empty targets.
        # TODO(ian): don't do this if generating negatives!
        if record.get("targets", None):
            yield record
        else:
            num_skipped += 1
    # NOTE: this line only runs once the generator is fully exhausted.
    log.info(
        "Read=%d, Skip=%d, Total=%d from %s",
        num_seen - num_skipped,
        num_skipped,
        num_seen,
        filename,
    )
def split_file(fname):
    """Split an edge-probing file into "pos" and "nonterminal" variants.

    Writes <dir>/pos/<base> and <dir>/nonterminal/<base>, one JSON record per
    line, using split_record() to divide each input record in two.

    Args:
        fname: path to a JSON-lines file of edge probing records.
    """
    dirname, base = os.path.split(fname)
    pos_dir = os.path.join(dirname, "pos")
    os.makedirs(pos_dir, exist_ok=True)
    new_pos_name = os.path.join(pos_dir, base)
    non_dir = os.path.join(dirname, "nonterminal")
    os.makedirs(non_dir, exist_ok=True)
    new_non_name = os.path.join(non_dir, base)

    log.info("Processing file: %s", fname)
    # Materialize so tqdm can show a total; assumes the file fits in memory.
    record_iter = list(utils.load_json_data(fname))
    log.info(" saving to %s and %s", new_pos_name, new_non_name)
    # Context managers ensure both output files are flushed and closed even
    # if split_record() raises; the original left them open forever.
    with open(new_pos_name, "w") as pos_fd, open(new_non_name, "w") as non_fd:
        for record in tqdm(record_iter):
            pos_record, non_record = split_record(record)
            pos_fd.write(json.dumps(pos_record))
            pos_fd.write("\n")
            non_fd.write(json.dumps(non_record))
            non_fd.write("\n")