def prepare_code(self):
    IOUtils.rm_dir(self.code_dir)
    IOUtils.mk_dir(self.code_dir.parent)
    with IOUtils.cd(self.code_dir.parent):
        BashUtils.run(f"git clone {self.REPO_URL} {self.code_dir.name}", expected_return_code=0)
    # end with
    with IOUtils.cd(self.code_dir):
        BashUtils.run(f"git checkout {self.REPO_SHA}", expected_return_code=0)
    # end with
    return

def prepare_code(self):
    IOUtils.rm_dir(self.code_dir)
    IOUtils.mk_dir(self.code_dir.parent)
    with IOUtils.cd(self.code_dir.parent):
        BashUtils.run(f"git clone {self.REPO_URL} {self.code_dir.name}", expected_return_code=0)
    # end with
    with IOUtils.cd(self.code_dir):
        BashUtils.run(f"git checkout {self.REPO_SHA}", expected_return_code=0)
    # end with

    # copy eval code
    BashUtils.run(f"cp {Macros.this_dir}/eval/eval_utils.py {self.code_dir}/")
    return

def download_global_model(self, force_yes: bool = False):
    """
    Downloads a global Roosterize model.
    """
    global_model_dir = RoosterizeDirUtils.get_global_model_dir()
    if global_model_dir.exists():
        ans = self.ask_for_confirmation(
            f"A Roosterize model already exists at {global_model_dir}. "
            f"Do you want to delete it and download again?")
        if force_yes:
            ans = True
        if ans != True:
            return
        IOUtils.rm_dir(global_model_dir)

    self.show_message("Downloading Roosterize model...")

    # Download and unpack
    temp_model_dir = Path(tempfile.mkdtemp(prefix="roosterize"))
    urllib.request.urlretrieve(self.model_url, str(temp_model_dir / "model.tgz"))
    with IOUtils.cd(temp_model_dir):
        BashUtils.run("tar xzf model.tgz", expected_return_code=0)

        # Move the stuff to global model place
        shutil.move(str(Path.cwd() / "model"), global_model_dir)

    # Delete temp dir
    IOUtils.rm_dir(temp_model_dir)

    self.show_message("Finish downloading Roosterize model.")

def suggest_naming(self, file_path: Path, prj_root: Optional[Path] = None):
    """
    Processes a file to get its lemmas and runs the model to get predictions.
    """
    # Figure out which project we're at, and then load configs
    if prj_root is None:
        prj_root = RoosterizeDirUtils.auto_infer_project_root(file_path)
    self.load_configs(prj_root)

    # Infer SerAPI options
    serapi_options = self.infer_serapi_options(prj_root)

    # If user provided compile_cmd, first compile the project
    if self.compile_cmd is not None:
        with IOUtils.cd(prj_root):
            BashUtils.run(self.compile_cmd, expected_return_code=0)

    # Parse file
    data = self.parse_file(file_path, prj_root, serapi_options)

    # Load model
    self.load_local_model(prj_root)
    model = self.get_model()

    # Use the model to make predictions
    # Temp dirs for processed data and results
    temp_data_dir = Path(tempfile.mkdtemp(prefix="roosterize"))

    # Dump lemmas & definitions
    temp_raw_data_dir = temp_data_dir / "raw"
    temp_raw_data_dir.mkdir()
    IOUtils.dump(
        temp_raw_data_dir / "lemmas.json",
        IOUtils.jsonfy(data.lemmas),
        IOUtils.Format.json,
    )
    IOUtils.dump(
        temp_raw_data_dir / "definitions.json",
        IOUtils.jsonfy(data.definitions),
        IOUtils.Format.json,
    )

    # Model-specific process
    temp_processed_data_dir = temp_data_dir / "processed"
    temp_processed_data_dir.mkdir()
    model.process_data_impl(temp_raw_data_dir, temp_processed_data_dir)

    # Invoke eval
    candidates_logprobs = model.eval_impl(
        temp_processed_data_dir,
        beam_search_size=self.beam_search_size,
        k=self.k,
    )

    # Save predictions
    IOUtils.rm_dir(temp_data_dir)

    # Report predictions
    self.report_predictions(data, candidates_logprobs)
    return

def parse_file(self, file_path: Path, prj_root: Path, serapi_options: str):
    source_code = IOUtils.load(file_path, IOUtils.Format.txt)
    unicode_offsets = ParserUtils.get_unicode_offsets(source_code)

    with IOUtils.cd(prj_root):
        rel_path = file_path.relative_to(prj_root)
        ast_sexp_str = BashUtils.run(
            f"sercomp {serapi_options} --mode=sexp -- {rel_path}",
            expected_return_code=0).stdout
        tok_sexp_str = BashUtils.run(
            f"sertok {serapi_options} -- {rel_path}",
            expected_return_code=0).stdout

        ast_sexp_list: List[SexpNode] = SexpParser.parse_list(ast_sexp_str)
        tok_sexp_list: List[SexpNode] = SexpParser.parse_list(tok_sexp_str)

        doc = CoqParser.parse_document(
            source_code,
            ast_sexp_list,
            tok_sexp_list,
            unicode_offsets=unicode_offsets,
        )
        doc.file_name = str(rel_path)

        # Collect lemmas & definitions
        lemmas: List[Lemma] = DataMiner.collect_lemmas_doc(doc, ast_sexp_list, serapi_options)
        definitions: List[Definition] = DataMiner.collect_definitions_doc(doc, ast_sexp_list)

    return ProcessedFile(file_path, source_code, doc, ast_sexp_list, tok_sexp_list, unicode_offsets, lemmas, definitions)

def eval_impl(self,
              processed_data_dir: Path,
              model_dir: Path,
              beam_search_size: int,
              k: int
              ) -> List[List[Tuple[str, float]]]:
    from roosterize.ml.onmt.CustomTranslator import CustomTranslator
    from onmt.utils.misc import split_corpus
    from onmt.utils.parse import ArgumentParser
    from translate import _get_parser as translate_get_parser

    src_path = processed_data_dir / "src.txt"
    tgt_path = processed_data_dir / "tgt.txt"

    best_step = IOUtils.load(model_dir / "best-step.json", IOUtils.Format.json)
    self.logger.info(f"Taking best step at {best_step}")

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    with IOUtils.cd(self.open_nmt_path):
        parser = translate_get_parser()
        opt = parser.parse_args(
            f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
            f" -src {src_path}"
            f" -tgt {tgt_path}"
        )
        opt.output = f"{model_dir}/last-pred.txt"
        opt.beam_size = beam_search_size
        opt.gpu = 0 if torch.cuda.is_available() else -1
        opt.n_best = k
        opt.block_ngram_repeat = 1
        opt.ignore_when_blocking = ["_"]

        # translate.main
        ArgumentParser.validate_translate_opts(opt)
        translator = CustomTranslator.build_translator(opt, report_score=False)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            self.logger.info("Translating shard %d." % i)
            _, _, candidates_logprobs_shard = translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
            candidates_logprobs.extend(candidates_logprobs_shard)
        # end for
    # end with

    # Reformat candidates
    candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

    return candidates_logprobs

def require_special_repo(cls, directory: Path, branch: str):
    cls.logger.info(f"Updating {directory} to {branch} branch")
    if directory.exists():
        if not directory.is_dir() or not (directory / ".git").is_dir():
            LoggingUtils.log_and_raise(
                cls.logger,
                f"Path {directory} already exists but is not a proper git repository!",
                Exception)
        # end if

        with IOUtils.cd(directory):
            BashUtils.run(f"git pull", expected_return_code=0)
        # end with
    else:
        IOUtils.mk_dir(directory)
        with IOUtils.cd(directory):
            BashUtils.run(
                f"git clone --single-branch -b {branch} -- {cls.get_git_url()} .",
                expected_return_code=0)

def install_coq_project(cls, project: Project, names_projects: Dict[str, Project]) -> None:
    """
    :requires: the project is cloned and checked-out to the desired version.
    """
    if not project.is_cloned:
        project.clone()
        project.checkout(project.data["sha"], is_forced=True)
    # end if

    # Check if the project is already compiled
    confirmation_file = "lpc-installed.txt"
    confirmation_content = project.revision + " " + BashUtils.run("opam list coq -s", expected_return_code=0).stdout.strip()
    if (project.checkout_dir / confirmation_file).is_file() and IOUtils.load(project.checkout_dir / confirmation_file, "txt") == confirmation_content:
        cls.logger.debug(f"Project {project.full_name} already installed")
        return
    # end if

    project.clean()

    # Install dependencies
    for dependency in project.data.get("dependencies", []):
        dependency_project = names_projects.get(dependency)
        if dependency_project is None:
            raise Exception(f"Cannot find dependency {dependency}")
        cls.logger.info(f"For Project {project.full_name}, installing dependency {dependency}")
        cls.install_coq_project(dependency_project, names_projects)
    # end for

    if "build_cmd" not in project.data:
        raise Exception(f"Project {project.full_name} does not have build_cmd")
    if "install_cmd" not in project.data:
        raise Exception(f"Project {project.full_name} does not have install_cmd")

    with IOUtils.cd(project.checkout_dir):
        # Build
        cls.logger.info(f"Project {project.full_name}: Building with {project.data['build_cmd']}")
        r = BashUtils.run(project.data["build_cmd"])
        if r.return_code != 0:
            raise Exception(f"Compilation failed! Return code is {r.return_code}! stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
        else:
            cls.logger.debug(f"Compilation finished. Return code is {r.return_code}. stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
        # end if

        # Install
        cls.logger.info(f"Project {project.full_name}: Installing with {project.data['install_cmd']}")
        r = BashUtils.run(project.data["install_cmd"])
        if r.return_code != 0:
            raise Exception(f"Installation failed! Return code is {r.return_code}! stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
        else:
            cls.logger.debug(f"Installation finished. Return code is {r.return_code}. stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
        # end if

        IOUtils.dump(project.checkout_dir / confirmation_file, confirmation_content, "txt")
    # end with
    return

def require_collector(cls):
    if cls.is_parallel:
        return

    if not cls.collector_installed:
        cls.logger.info("Require collector, installing ...")
        with IOUtils.cd(Macros.collector_dir):
            BashUtils.run(f"mvn clean install -DskipTests", expected_return_code=0)
        # end with
        cls.collector_installed = True
    else:
        cls.logger.debug("Require collector, and already installed")
    # end if
    return

def test_cd(self):
    with TestSupport.get_playground_path():
        oldpath = Path.cwd()
        testpath = Path("./aaa").resolve()
        testpath.mkdir()

        with IOUtils.cd(testpath):
            # Checks if changed directory successfully
            self.assertEqual(testpath, Path.cwd())
        # end with

        # Checks if returned to old directory successfully
        self.assertEqual(oldpath, Path.cwd())
    # end with
    return

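# The snippets above and below all rely on a `cd` context manager. As a point of
# reference, here is a minimal sketch of what a helper like `IOUtils.cd` presumably
# does (this is NOT the actual seutil implementation; names here are illustrative):
# change into the target directory on enter, and always restore the previous working
# directory on exit, even if the body raises.
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Union


@contextmanager
def cd_sketch(path: Union[str, Path]):
    """Temporarily change the working directory; restore it on exit."""
    prev_cwd = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Restore the original working directory no matter what happened inside.
        os.chdir(prev_cwd)
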
def preprocess(self,
               train_processed_data_dir: Path,
               val_processed_data_dir: Path,
               output_model_dir: Path
               ) -> NoReturn:
    # Call OpenNMT preprocess
    with IOUtils.cd(self.open_nmt_path):
        from preprocess import _get_parser as preprocess_get_parser
        from preprocess import main as preprocess_main

        parser = preprocess_get_parser()
        opt = parser.parse_args(
            f" -train_src {train_processed_data_dir}/src.txt"
            f" -train_tgt {train_processed_data_dir}/tgt.txt"
            f" -valid_src {val_processed_data_dir}/src.txt"
            f" -valid_tgt {val_processed_data_dir}/tgt.txt"
            f" -save_data {output_model_dir}/processed-data"
        )
        opt.src_seq_length = self.config.input_max
        opt.src_words_min_frequency = self.config.vocab_input_frequency_threshold
        if self.config.use_copy:
            opt.dynamic_dict = True

        preprocess_main(opt)
    # end with
    return

def collect_coq_documents_project(
        cls,
        data_mgr: FilesManager,
        project: Project,
        names_projects: Dict[str, Project],
        files: List[str] = None,
        is_verifying_tokenizer: bool = False,
) -> List[CoqDocument]:
    coq_documents: List[CoqDocument] = list()

    # Clone and checkout repo
    project.clone()
    project.checkout(project.data["sha"], is_forced=True)

    # Build the project
    cls.install_coq_project(project, names_projects)

    # For each file, parse code to tokens
    with IOUtils.cd(project.checkout_dir):
        coq_files: List[str] = BashUtils.run(f"find -name '*.v' -type f").stdout.split("\n")[:-1]
        if files is not None:
            coq_files = [f for f in coq_files if f[2:] in files]  # [2:] is to remove the ./
        # end if

        re_ignore_path = re.compile(project.data["ignore_path_regex"]) if "ignore_path_regex" in project.data else None

        for i, coq_file in enumerate(coq_files):
            try:
                coq_file = coq_file[2:]
                cls.logger.debug(f"File {i + 1}/{len(coq_files)}: {coq_file}")

                # Check if file is ignored
                if re_ignore_path is not None and re_ignore_path.fullmatch(coq_file):
                    cls.logger.info(f"Ignoring file {coq_file}")
                    continue
                # end if

                # Read file
                with open(coq_file, "r", newline="") as f:
                    source_code = f.read()
                # end with

                # Get unicode offsets
                unicode_offsets = ParserUtils.get_unicode_offsets(source_code)

                # Save original file to original_files
                data_mgr.dump_data([FilesManager.ORIGINAL_FILES, project.full_name, coq_file], source_code, IOUtils.Format.txt)

                # Call SerAPI
                serapi_options = project.data.get("serapi_options", "")
                ast_sexp_str: str = BashUtils.run(f"sercomp {serapi_options} --mode=sexp -- {coq_file}", expected_return_code=0).stdout
                tok_sexp_str: str = BashUtils.run(f"sertok {serapi_options} -- {coq_file}", expected_return_code=0).stdout

                # Save ast sexp to dataset (.ast.sexp)
                data_mgr.dump_data([FilesManager.RAW_FILES, project.full_name, coq_file[:-2] + ".ast.sexp"], ast_sexp_str, IOUtils.Format.txt)

                # Save tok sexp to dataset (.tok.sexp)
                data_mgr.dump_data([FilesManager.RAW_FILES, project.full_name, coq_file[:-2] + ".tok.sexp"], tok_sexp_str, IOUtils.Format.txt)

                # Parse ast sexp
                ast_sexp_list: List[SexpNode] = SexpParser.parse_list(ast_sexp_str)
                tok_sexp_list: List[SexpNode] = SexpParser.parse_list(tok_sexp_str)

                # Verify the tokenizer if requested
                if is_verifying_tokenizer:
                    if not cls.verify_tokenizer(tok_sexp_list, source_code, unicode_offsets):
                        LoggingUtils.log_and_raise(cls.logger, "Tokenized content doesn't match original file!", Exception)
                    # end if
                # end if

                # Parse the document
                coq_document = CoqParser.parse_document(source_code, ast_sexp_list, tok_sexp_list, unicode_offsets=unicode_offsets)

                # Save the parsed document (printed format) to raw_files
                data_mgr.dump_data([FilesManager.RAW_FILES, project.full_name, coq_file], coq_document.str_with_space(), IOUtils.Format.txt)

                # Set meta data
                coq_document.file_name = coq_file
                coq_document.project_name = project.full_name
                coq_document.revision = project.revision

                coq_documents.append(coq_document)
            except KeyboardInterrupt:
                cls.logger.warning("Keyboard interrupt!")
                raise
            except:
                cls.logger.warning(f"File {coq_file} failed! Exception was: {traceback.format_exc()}")
                continue
            # end try
        # end for
    # end with

    return coq_documents

def get_git_url(cls):
    with IOUtils.cd(Macros.project_dir):
        return BashUtils.run(f"git config --get remote.origin.url", expected_return_code=0).stdout.strip()

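# The other helper these snippets lean on is `BashUtils.run(cmd, expected_return_code=...)`.
# As a rough sketch only (NOT the actual seutil implementation; `RunResult` and
# `run_sketch` are illustrative names), such a helper presumably runs the command
# through a shell, captures stdout/stderr, and optionally checks the return code.
import subprocess
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunResult:
    return_code: int
    stdout: str
    stderr: str


def run_sketch(cmd: str, expected_return_code: Optional[int] = None) -> RunResult:
    """Run a shell command, capture its output, and optionally assert on the return code."""
    completed = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    result = RunResult(completed.returncode, completed.stdout, completed.stderr)
    if expected_return_code is not None and result.return_code != expected_return_code:
        raise RuntimeError(
            f"Command `{cmd}` returned {result.return_code}, "
            f"expected {expected_return_code}; stderr:\n{result.stderr}")
    return result
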
def collect_project(self, project_name: str, project_url: str):
    Environment.require_collector()

    # 0. Download repo
    downloads_dir = self.repos_downloads_dir / project_name
    results_dir = self.repos_results_dir / project_name

    # Remove previous results if any
    IOUtils.rm_dir(results_dir)
    IOUtils.mk_dir(results_dir)

    # Clone the repo if not exists
    if not downloads_dir.exists():
        with IOUtils.cd(self.repos_downloads_dir):
            with TimeUtils.time_limit(300):
                BashUtils.run(f"git clone {project_url} {project_name}", expected_return_code=0)
            # end with
        # end with
    # end if

    project_data = ProjectData.create()
    project_data.name = project_name
    project_data.url = project_url

    # 1. Get list of revisions
    with IOUtils.cd(downloads_dir):
        git_log_out = BashUtils.run(f"git log --pretty=format:'%H %P'", expected_return_code=0).stdout
        for line in git_log_out.splitlines()[:self.MAX_REVISIONS]:
            shas = line.split()
            project_data.revisions.append(shas[0])
            project_data.parent_revisions[shas[0]] = shas[1:]
        # end for
    # end with

    # 2. Get revisions in different year
    with IOUtils.cd(downloads_dir):
        for year in self.YEARS:
            git_log_out = BashUtils.run(f"git rev-list -1 --before=\"Jan 1 {year}\" origin", expected_return_code=0).stdout
            project_data.year_revisions[str(year) + "_Jan_1"] = git_log_out.rstrip()
        # end for
    # end with

    project_data_file = results_dir / "project.json"
    IOUtils.dump(project_data_file, IOUtils.jsonfy(project_data), IOUtils.Format.jsonPretty)

    # 2. Start java collector
    # Prepare config
    log_file = results_dir / "collector-log.txt"
    output_dir = results_dir / "collector"

    config = {
        "collect": True,
        "projectDir": str(downloads_dir),
        "projectDataFile": str(project_data_file),
        "logFile": str(log_file),
        "outputDir": str(output_dir),
        "year": True  # To indicate whether to collect all evo data or yearly data
    }
    config_file = results_dir / "collector-config.json"
    IOUtils.dump(config_file, config, IOUtils.Format.jsonPretty)

    self.logger.info(f"Starting the Java collector. Check log at {log_file} and outputs at {output_dir}")
    rr = BashUtils.run(f"java -jar {Environment.collector_jar} {config_file}", expected_return_code=0)
    if rr.stderr:
        self.logger.warning(f"Stderr of collector:\n{rr.stderr}")
    # end if

    # 3. In some cases, save collected data to appropriate location or database
    # TODO private info
    # On luzhou server for user pynie, move it to a dedicated location at /user/disk2
    if BashUtils.run(f"hostname").stdout.strip() == "luzhou" and BashUtils.run(f"echo $USER").stdout.strip() == "pynie":
        alter_results_dir = Path("/home/disk2/pynie/csevo-results") / project_name
        IOUtils.rm_dir(alter_results_dir)
        IOUtils.mk_dir(alter_results_dir.parent)
        BashUtils.run(f"mv {results_dir} {alter_results_dir}")
        self.logger.info(f"Results moved to {alter_results_dir}")
    # end if

    # -1. Remove repo
    IOUtils.rm_dir(downloads_dir)
    return

def make_plot_draft_learning_curve(self,
                                   training_log_path: Path,
                                   output_name: str,
                                   ):
    special_plots_dir = self.plots_dir / "draft-learning-curve"
    IOUtils.mk_dir(special_plots_dir)

    fig: plt.Figure = plt.figure(figsize=(12, 9))

    # TODO: these metrics may be specific to Code2Seq only
    x_field = "batch"
    yl_field = "training_loss"
    yr_field = "eval F1"

    x_min = 0
    x_max = -np.Inf
    yl_min = np.Inf
    yl_max = -np.Inf
    yr_min = np.Inf
    yr_max = -np.Inf

    # First, get ranges for all metrics (we want to use same ranges in all subplots)
    tvt_2_training_log = dict()
    tvt_2_x = dict()
    tvt_2_yl = dict()
    tvt_2_yr = dict()
    for tvt in [Macros.lat_lat, Macros.evo_lat, Macros.lat_evo, Macros.evo_evo]:
        # TODO: this path is hardcoded and works for Code2Seq 1 trial
        training_log = IOUtils.load(training_log_path / tvt / "trial-0" / "logs" / "train_log.json", IOUtils.Format.json)
        x = [d[x_field] for d in training_log]
        yl = [d[yl_field] for d in training_log]
        yr = [d[yr_field] for d in training_log]

        tvt_2_training_log[tvt] = training_log
        tvt_2_x[tvt] = x
        tvt_2_yl[tvt] = yl
        tvt_2_yr[tvt] = yr

        x_min = min(x_min, min(x))
        x_max = max(x_max, max(x))
        yl_min = min(yl_min, min(yl))
        yl_max = max(yl_max, max(yl))
        yr_min = min(yr_min, min(yr))
        yr_max = max(yr_max, max(yr))
    # end for

    x_lim = (x_min - (x_max - x_min) / 30, x_max + (x_max - x_min) / 30)
    yl_lim = (np.exp(np.log(yl_min) - (np.log(yl_max) - np.log(yl_min)) / 30), np.exp(np.log(yl_max) + (np.log(yl_max) - np.log(yl_min)) / 30))
    yr_lim = (yr_min - (yr_max - yr_min) / 30, yr_max + (yr_max - yr_min) / 30)

    for t_i, t in enumerate([Macros.lat, Macros.evo]):
        for vt_i, vt in enumerate([Macros.lat, Macros.evo]):
            tvt = f"{t}-{vt}"
            tvt_i = (t_i) * 2 + (vt_i) + 1
            x = tvt_2_x[tvt]
            yl = tvt_2_yl[tvt]
            yr = tvt_2_yr[tvt]

            axl: plt.Axes = fig.add_subplot(2, 2, tvt_i)
            axr = axl.twinx()

            colorl = "tab:red"
            colorr = "tab:blue"

            axl.plot(x, yl, color=colorl)
            axr.plot(x, yr, color=colorr)

            axl.set_xlabel(x_field)
            axl.set_xlim(x_lim[0], x_lim[1])
            axl.set_ylabel(yl_field, color=colorl)
            axl.set_yscale("log")
            axl.set_ylim(yl_lim[0], yl_lim[1])
            axr.set_ylabel(yr_field, color=colorr)
            axr.set_ylim(yr_lim[0], yr_lim[1])
            axl.set_title(tvt)
        # end for
    # end for

    fig.tight_layout()
    with IOUtils.cd(special_plots_dir):
        fig.savefig(f"{output_name}.eps")
    # end with
    return

def eval_impl(self,
              processed_data_dir: Path,
              model_dir: Path,
              beam_search_size: int,
              k: int) -> List[List[Tuple[str, float]]]:
    from roosterize.ml.onmt.MultiSourceTranslator import MultiSourceTranslator
    from onmt.utils.misc import split_corpus
    from onmt.utils.parse import ArgumentParser
    from translate import _get_parser as translate_get_parser

    src_path = processed_data_dir / "src.txt"
    tgt_path = processed_data_dir / "tgt.txt"

    best_step = IOUtils.load(model_dir / "best-step.json", IOUtils.Format.json)
    self.logger.info(f"Taking best step at {best_step}")

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    with IOUtils.cd(self.open_nmt_path):
        parser = translate_get_parser()
        opt = parser.parse_args(
            f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
            f" -src {src_path}"
            f" -tgt {tgt_path}")
        opt.output = f"{model_dir}/last-pred.txt"
        opt.beam_size = beam_search_size
        opt.gpu = 0 if torch.cuda.is_available() else -1
        opt.n_best = k
        opt.block_ngram_repeat = 1
        opt.ignore_when_blocking = ["_"]

        # translate.main
        ArgumentParser.validate_translate_opts(opt)
        translator = MultiSourceTranslator.build_translator(self.config.get_src_types(), opt, report_score=False)

        has_target = True
        raw_data_keys = [f"src.{src_type}" for src_type in self.config.get_src_types()] + (["tgt"] if has_target else [])
        raw_data_paths: Dict[str, str] = {k: f"{processed_data_dir}/{k}.txt" for k in raw_data_keys}
        raw_data_shards: Dict[str, list] = {k: list(split_corpus(p, opt.shard_size)) for k, p in raw_data_paths.items()}
        # src_shards = split_corpus(opt.src, opt.shard_size)
        # tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
        # shard_pairs = zip(src_shards, tgt_shards)

        for i in range(len(list(raw_data_shards.values())[0])):
            self.logger.info("Translating shard %d." % i)
            _, _, candidates_logprobs_shard = translator.translate(
                {k: v[i] for k, v in raw_data_shards.items()},
                has_target,
                src_dir=None,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug)
            candidates_logprobs.extend(candidates_logprobs_shard)
        # end for
    # end with

    # Reformat candidates
    candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

    return candidates_logprobs

def train_impl(
        self,
        train_processed_data_dir: Path,
        val_processed_data_dir: Path,
        output_model_dir: Path,
) -> NoReturn:
    self.preprocess(train_processed_data_dir, val_processed_data_dir, output_model_dir)

    from train import _get_parser as train_get_parser
    from train import ErrorHandler, batch_producer
    from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
    from onmt.inputters.inputter import old_style_vocab, load_old_vocab
    import onmt.utils.distributed
    from onmt.utils.parse import ArgumentParser

    with IOUtils.cd(self.open_nmt_path):
        parser = train_get_parser()
        opt = parser.parse_args(
            f" -data {output_model_dir}/processed-data"
            f" -save_model {output_model_dir}/models/ckpt")
        opt.gpu_ranks = [0]
        opt.early_stopping = self.config.early_stopping_threshold
        opt.report_every = 200
        opt.valid_steps = 200
        opt.save_checkpoint_steps = 200
        opt.keep_checkpoint_max = self.config.ckpt_keep_max
        opt.optim = "adam"
        opt.learning_rate = self.config.learning_rate
        opt.max_grad_norm = self.config.max_grad_norm
        opt.batch_size = self.config.batch_size
        opt.encoder_type = self.config.encoder
        opt.decoder_type = self.config.decoder
        opt.dropout = [self.config.dropout]
        opt.src_word_vec_size = self.config.dim_embed
        opt.tgt_word_vec_size = self.config.dim_embed
        opt.layers = self.config.rnn_num_layers
        opt.enc_rnn_size = self.config.dim_encoder_hidden
        opt.dec_rnn_size = self.config.dim_decoder_hidden
        opt.__setattr__("num_srcs", len(self.config.get_src_types()))

        if self.config.use_attn:
            opt.global_attention = "general"
        else:
            opt.global_attention = "none"
        # end if
        if self.config.use_copy:
            opt.copy_attn = True
            opt.copy_attn_type = "general"
        # end if

        # train.main
        ArgumentParser.validate_train_opts(opt)
        ArgumentParser.update_model_opts(opt)
        ArgumentParser.validate_model_opts(opt)

        # Load checkpoint if we resume from a previous training.
        if opt.train_from:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage)
            self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
        # end if

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
        # end if

        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            # end for
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(self.config.get_src_types(), train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            # end if
            train_iter = MultiSourceInputter.build_dataset_iter(self.config.get_src_types(), shard_base, fields, opt)
        # end if

        nb_gpu = len(opt.gpu_ranks)

        if opt.world_size > 1:
            queues = []
            mp = torch.multiprocessing.get_context('spawn')
            semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
            # Create a thread to listen for errors in the child processes.
            error_queue = mp.SimpleQueue()
            error_handler = ErrorHandler(error_queue)
            # Train with multiprocessing.
            procs = []
            for device_id in range(nb_gpu):
                q = mp.Queue(opt.queue_size)
                queues += [q]

                def run(opt, device_id, error_queue, batch_queue, semaphore):
                    """ run process """
                    try:
                        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
                        if gpu_rank != opt.gpu_ranks[device_id]:
                            raise AssertionError("An error occurred in Distributed initialization")
                        self.train_single(opt, device_id, batch_queue, semaphore)
                    except KeyboardInterrupt:
                        pass  # killed by parent, do nothing
                    except Exception:
                        # propagate exception to parent process, keeping original traceback
                        import traceback
                        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
                    # end try
                # end def

                procs.append(mp.Process(target=run, args=(opt, device_id, error_queue, q, semaphore), daemon=True))
                procs[device_id].start()
                self.logger.info(" Starting process pid: %d " % procs[device_id].pid)
                error_handler.add_child(procs[device_id].pid)
            # end for

            producer = mp.Process(target=batch_producer, args=(train_iter, queues, semaphore, opt,), daemon=True)
            producer.start()
            error_handler.add_child(producer.pid)

            for p in procs:
                p.join()
            producer.terminate()
        elif nb_gpu == 1:  # case 1 GPU only
            self.train_single(output_model_dir, opt, 0)
        else:  # case only CPU
            self.train_single(output_model_dir, opt, -1)
        # end if
    # end with
    return

def extract_data_project(
        cls,
        project_path: Path,
        files: Optional[List[str]],
        exclude_files: Optional[List[str]],
        exclude_pattern: Optional[str],
        serapi_options: str,
        output_path: Path,
):
    # 1. Prepare output path
    if output_path.is_dir():
        cls.logger.warning(f"{output_path} already exists, will overwrite the files.")
    elif output_path.is_file():
        LoggingUtils.log_and_raise(cls.logger, f"{output_path} already exists as a file. Aborting.", Exception)
    else:
        IOUtils.mk_dir(output_path)
    # end if

    # 2. Extract documents, tok.sexp and ast.sexp
    coq_documents: Dict[str, CoqDocument] = collections.OrderedDict()
    ast_sexp_lists: Dict[str, List[SexpNode]] = dict()
    tok_sexp_lists: Dict[str, List[SexpNode]] = dict()

    with IOUtils.cd(project_path):
        coq_files: List[str] = BashUtils.run(f"find -name '*.v' -type f").stdout.split("\n")[:-1]
        coq_files = [coq_file[2:] for coq_file in coq_files]

        if files is not None:
            coq_files = [f for f in coq_files if f in files]
        # end if

        if exclude_files is not None:
            coq_files = [f for f in coq_files if f not in exclude_files]
        # end if

        if exclude_pattern is not None:
            re_exclude_pattern = re.compile(exclude_pattern)
            coq_files = [f for f in coq_files if not re_exclude_pattern.fullmatch(f)]
        # end if

        for i, coq_file in enumerate(tqdm(coq_files)):
            try:
                # Read file
                with open(coq_file, "r", newline="") as f:
                    source_code = f.read()
                # end with

                # Get unicode offsets
                unicode_offsets = ParserUtils.get_unicode_offsets(source_code)

                # Call SerAPI
                ast_sexp_str: str = BashUtils.run(f"sercomp {serapi_options} --mode=sexp -- {coq_file}", expected_return_code=0).stdout
                tok_sexp_str: str = BashUtils.run(f"sertok {serapi_options} -- {coq_file}", expected_return_code=0).stdout

                # Parse ast sexp
                ast_sexp_list: List[SexpNode] = SexpParser.parse_list(ast_sexp_str)
                tok_sexp_list: List[SexpNode] = SexpParser.parse_list(tok_sexp_str)

                # Parse the document
                coq_document = CoqParser.parse_document(source_code, ast_sexp_list, tok_sexp_list, unicode_offsets=unicode_offsets)

                # Set meta data
                coq_document.file_name = coq_file
                coq_document.project_name = project_path.name

                coq_documents[coq_file] = coq_document
                ast_sexp_lists[coq_file] = ast_sexp_list
                tok_sexp_lists[coq_file] = tok_sexp_list
            except KeyboardInterrupt:
                cls.logger.warning("Keyboard interrupt!")
                raise
            except:
                cls.logger.warning(f"File {coq_file} failed! Exception was: {traceback.format_exc()}")
                continue
            # end try
        # end for

        # 3. Extract and save lemmas and definitions
        lemmas: List[Lemma] = list()
        definitions: List[Definition] = list()

        # Increase recursion limit because the backend sexps are CRAZZZZY deep
        sys.setrecursionlimit(10000)

        for file_path, doc in tqdm(coq_documents.items()):
            ast_sexp_list = ast_sexp_lists[file_path]
            lemmas_doc = cls.collect_lemmas_doc(doc, ast_sexp_list, serapi_options)
            lemmas.extend(lemmas_doc)

            definitions_doc = cls.collect_definitions_doc(doc, ast_sexp_list)
            definitions.extend(definitions_doc)
        # end for

        IOUtils.dump(output_path / "lemmas.json", IOUtils.jsonfy(lemmas), IOUtils.Format.json)
        IOUtils.dump(output_path / "definitions.json", IOUtils.jsonfy(definitions), IOUtils.Format.json)
    # end with
    return