コード例 #1
0
  def get_latest_from_index(self, ckpt_dir):
    """
    Read the checkpoint index in ``ckpt_dir`` and return its "latest" entry.

    Side effect: sets ``self.save_counter`` from the index's
    ``"save_counter"`` entry.

    Args:
      ckpt_dir: checkpoint directory expected to contain a ``.index`` JSON file.

    Returns:
      the value stored under the index's ``"latest"`` key.

    Raises:
      IndexError: if ``ckpt_dir`` contains no ``.index`` file.
      KeyError: if the index has no ``"save_counter"`` or ``"latest"`` entry.
    """
    index_files = files_with_extension(ckpt_dir, "index")
    if not index_files:
      # Same exception type the bare ``[0]`` used to raise (so callers that
      # catch IndexError keep working), but with a message naming the cause.
      raise IndexError(
          f"no .index file found in checkpoint directory {ckpt_dir!r}")
    # If several index files exist, the first one returned is used.
    with open(index_files[0], "r") as f:
      cfg_dict = json.load(f)

    self.save_counter = cfg_dict["save_counter"]
    return cfg_dict["latest"]
コード例 #2
0
  def get_latest_from_index(self, ckpt_dir):
    """
    Return the latest checkpoint recorded in the index file of ``ckpt_dir``.

    Args:
      ckpt_dir: checkpoint directory expected to contain a ``.index`` JSON file.

    Returns:
      the value stored under the index's ``"latest"`` key (the path to the
      latest checkpoint) -- not the whole index dict.

    Raises:
      IndexError: if ``ckpt_dir`` contains no ``.index`` file.
      KeyError: if the index has no ``"latest"`` entry.
    """
    index_files = files_with_extension(ckpt_dir, "index")
    if not index_files:
      # Same exception type the bare ``[0]`` used to raise (so callers that
      # catch IndexError keep working), but with a message naming the cause.
      raise IndexError(
          f"no .index file found in checkpoint directory {ckpt_dir!r}")
    # If several index files exist, the first one returned is used.
    with open(index_files[0], "r") as f:
      cfg_dict = json.load(f)
    return cfg_dict["latest"]
コード例 #3
0
def create_datasets(data_dir: str, dest_dir: str):
    """Generate all derived datasets from the raw JSON-lines files in ``data_dir``.

    Every file returned by ``files_with_extension(data_dir, "json")`` is read
    line by line. Each line is parsed with ``json.loads``, and the resulting
    datapoint is passed to ``process_dp`` of every dataset creator. Each
    creator is given its own output file, opened for writing in ``dest_dir``.
    All of those files are closed before this function returns.

    Args:
        data_dir: directory holding the input ``.json`` files. Must exist.
        dest_dir: directory for the generated dataset files; created (with
            parents) if it does not exist.

    Raises:
        Exception: if ``data_dir`` does not exist.

    Failures on a single line (bad JSON, or a creator raising) and failures
    on a whole file (e.g. it cannot be opened or read) are printed and
    skipped, so one bad input does not abort the run.
    """
    from contextlib import ExitStack

    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable this validation.
    if not os.path.exists(data_dir):
        raise Exception(
            f"[create_datasets] ERROR: DATA_DIR {data_dir} MUST EXIST")

    os.makedirs(dest_dir, exist_ok=True)

    # (creator class, output file name) for every dataset that is produced.
    creator_specs = [
        (ProofStepClassificationDatasetCreator,
         "proof_step_classification.json"),
        (PremiseClassificationDatasetCreator, "premise_classification.json"),
        (TheoremNamePredictionDatasetCreator, "theorem_name_prediction.json"),
        (NextLemmaPredictionDatasetCreator, "next_lemma_prediction.json"),
        (ProofTermPredictionDatasetCreator, "proof_term_prediction.json"),
        (SkipProofDatasetCreator, "skip_proof.json"),
        (TypePredictionDatasetCreator, "type_prediction.json"),
        (TSElabDatasetCreator, "ts_elab.json"),
        (ProofTermElabDatasetCreator, "proof_term_elab.json"),
        (ResultElabDatasetCreator, "result_elab.json"),
    ]

    # ExitStack closes (and therefore flushes) every output file once all
    # input has been processed, even if an exception escapes. Previously the
    # handles were never closed explicitly.
    with ExitStack() as stack:
        dataset_creators = [
            creator_cls(
                stack.enter_context(
                    open(os.path.join(dest_dir, filename), "w")))
            for creator_cls, filename in creator_specs
        ]

        json_files = files_with_extension(data_dir, "json")
        print("JSON FILES: ", json_files)
        for json_file in tqdm(json_files):
            line = 0  # last line number read; reported if the file fails
            try:
                with open(json_file, "r") as json_file_handle:
                    for line, json_line in enumerate(json_file_handle,
                                                     start=1):
                        try:
                            dp = json.loads(json_line)
                            for dc in dataset_creators:
                                dc.process_dp(dp)
                        except Exception as e:
                            print(f"BAD LINE IN FILE: {json_file} EXCEPTION: {e}")
            except Exception as e:
                print(f"BAD FILE: {json_file} LINE: {line}: EXCEPTION: {e}")
コード例 #4
0
 def update_index(self, ckpt_path):
   """
   Dump a JSON to a `.index` file, pointing to the most recent checkpoint.

   The index lives in the checkpoint's directory
   (``os.path.dirname(ckpt_path)``). An existing ``.index`` file there is
   reused; otherwise ``latest.index`` is created. The JSON is written to a
   temporary sibling file and moved into place with ``os.replace``, so an
   interrupted write cannot leave a truncated index behind.

   Args:
     ckpt_path: path of the checkpoint to record under the ``"latest"`` key.
   """
   ckpt_dir = os.path.dirname(ckpt_path)
   index_files = files_with_extension(ckpt_dir, "index")
   if len(index_files) == 0:
     index = os.path.join(ckpt_dir, "latest.index")
   else:
     # More than one index file would make "latest" ambiguous.
     assert len(index_files) == 1
     index = index_files[0]
   cfg_dict = {"latest": ckpt_path}
   tmp_index = index + ".tmp"
   with open(tmp_index, "w") as f:
     f.write(json.dumps(cfg_dict, indent=2))
   # Atomic rename: readers see either the old index or the new one.
   os.replace(tmp_index, index)
コード例 #5
0
 def update_index(self, ckpt_path, save_counter):
   """
   Record ``ckpt_path`` as the latest checkpoint, then prune an old one.

   Writes ``{"latest": ckpt_path, "save_counter": save_counter}`` as JSON to
   the ``.index`` file in the checkpoint's directory (``latest.index`` if none
   exists yet). The write goes to a temporary sibling file that is moved into
   place with ``os.replace``, so an interrupted write cannot leave a truncated
   index behind.

   Afterwards it tries to delete
   ``ckpt_{save_counter - self.max_to_keep}.pth`` from the same directory.
   Failure to delete it (e.g. the file does not exist) is ignored.

   Args:
     ckpt_path: path of the checkpoint just saved.
     save_counter: counter value stored in the index; also used to pick which
       old checkpoint file to delete.
   """
   ckpt_dir = os.path.dirname(ckpt_path)
   index_files = files_with_extension(ckpt_dir, "index")
   if len(index_files) == 0:
     index = os.path.join(ckpt_dir, "latest.index")
   else:
     # More than one index file would make "latest" ambiguous.
     assert len(index_files) == 1
     index = index_files[0]
   cfg_dict = {"latest": ckpt_path, "save_counter": save_counter}
   tmp_index = index + ".tmp"
   with open(tmp_index, "w") as f:
     f.write(json.dumps(cfg_dict, indent=2))
   # Atomic rename: readers see either the old index or the new one.
   os.replace(tmp_index, index)

   # The old bare ``except:`` also skipped pruning when ``max_to_keep`` was
   # missing or None; keep that explicitly instead of swallowing every error.
   # NOTE(review): confirm None/absent is the intended "keep everything".
   max_to_keep = getattr(self, "max_to_keep", None)
   if max_to_keep is None:
     return
   stale_ckpt = os.path.join(ckpt_dir, f"ckpt_{save_counter - max_to_keep}.pth")
   try:
     os.remove(stale_ckpt)
   except OSError:
     # Best-effort: the stale checkpoint may not exist (e.g. early saves).
     pass
コード例 #6
0
 def __init__(self, data_dir):
     """Remember ``data_dir`` and collect its ``cnf`` files.

     The file list comes from ``util.files_with_extension(data_dir, "cnf")``.
     ``file_index`` starts at 0 -- presumably a cursor into ``self.files``;
     confirm against the rest of the class.
     """
     self.file_index = 0
     self.data_dir = data_dir
     self.files = util.files_with_extension(self.data_dir, "cnf")