Example 1
    def prepare_code(self):
        IOUtils.rm_dir(self.code_dir)
        IOUtils.mk_dir(self.code_dir.parent)
        with IOUtils.cd(self.code_dir.parent):
            BashUtils.run(f"git clone {self.REPO_URL} {self.code_dir.name}", expected_return_code=0)
        # end with

        with IOUtils.cd(self.code_dir):
            BashUtils.run(f"git checkout {self.REPO_SHA}", expected_return_code=0)
        # end with
        return
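
All of these examples use IOUtils.cd as a context manager that switches the working directory for the duration of a with block and restores it afterwards (see also the test_cd example further down). A minimal sketch of a compatible helper, assuming only that save/restore behavior; the real IOUtils.cd may differ:

import contextlib
import os
from pathlib import Path

@contextlib.contextmanager
def cd(path: Path):
    """Temporarily change the working directory, restoring it on exit (sketch of IOUtils.cd's apparent behavior)."""
    old_cwd = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always go back to the original directory, even if the with-block body raises.
        os.chdir(old_cwd)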
Example 2
    def prepare_code(self):
        IOUtils.rm_dir(self.code_dir)
        IOUtils.mk_dir(self.code_dir.parent)
        with IOUtils.cd(self.code_dir.parent):
            BashUtils.run(f"git clone {self.REPO_URL} {self.code_dir.name}", expected_return_code=0)
        # end with

        with IOUtils.cd(self.code_dir):
            BashUtils.run(f"git checkout {self.REPO_SHA}", expected_return_code=0)
        # end with

        # copy eval code
        BashUtils.run(f"cp {Macros.this_dir}/eval/eval_utils.py {self.code_dir}/")
        return
Example 3
    def download_global_model(self, force_yes: bool = False):
        """
        Downloads a global Roosterize model.
        """
        global_model_dir = RoosterizeDirUtils.get_global_model_dir()
        if global_model_dir.exists():
            ans = self.ask_for_confirmation(
                f"A Roosterize model already exists at {global_model_dir}. "
                f"Do you want to delete it and download again?")
            if force_yes:
                ans = True
            if ans != True:
                return
            IOUtils.rm_dir(global_model_dir)

        self.show_message("Downloading Roosterize model...")

        # Download and unpack
        temp_model_dir = Path(tempfile.mkdtemp(prefix="roosterize"))

        urllib.request.urlretrieve(self.model_url,
                                   str(temp_model_dir / "model.tgz"))
        with IOUtils.cd(temp_model_dir):
            BashUtils.run("tar xzf model.tgz", expected_return_code=0)

            # Move the stuff to global model place
            shutil.move(str(Path.cwd() / "model"), global_model_dir)

        # Delete temp dir
        IOUtils.rm_dir(temp_model_dir)

        self.show_message("Finish downloading Roosterize model.")
Example 4
    def suggest_naming(self, file_path: Path, prj_root: Optional[Path] = None):
        """
        Processes a file to get its lemmas and runs the model to get predictions.
        """
        # Figure out which project we're at, and then load configs
        if prj_root is None:
            prj_root = RoosterizeDirUtils.auto_infer_project_root(file_path)
        self.load_configs(prj_root)

        # Infer SerAPI options
        serapi_options = self.infer_serapi_options(prj_root)

        # If user provided compile_cmd, first compile the project
        if self.compile_cmd is not None:
            with IOUtils.cd(prj_root):
                BashUtils.run(self.compile_cmd, expected_return_code=0)

        # Parse file
        data = self.parse_file(file_path, prj_root, serapi_options)

        # Load model
        self.load_local_model(prj_root)
        model = self.get_model()

        # Use the model to make predictions
        # Temp dirs for processed data and results
        temp_data_dir = Path(tempfile.mkdtemp(prefix="roosterize"))

        # Dump lemmas & definitions
        temp_raw_data_dir = temp_data_dir / "raw"
        temp_raw_data_dir.mkdir()
        IOUtils.dump(
            temp_raw_data_dir / "lemmas.json",
            IOUtils.jsonfy(data.lemmas),
            IOUtils.Format.json,
        )
        IOUtils.dump(
            temp_raw_data_dir / "definitions.json",
            IOUtils.jsonfy(data.definitions),
            IOUtils.Format.json,
        )

        # Model-specific process
        temp_processed_data_dir = temp_data_dir / "processed"
        temp_processed_data_dir.mkdir()
        model.process_data_impl(temp_raw_data_dir, temp_processed_data_dir)

        # Invoke eval
        candidates_logprobs = model.eval_impl(
            temp_processed_data_dir,
            beam_search_size=self.beam_search_size,
            k=self.k,
        )

        # Save predictions
        IOUtils.rm_dir(temp_data_dir)

        # Report predictions
        self.report_predictions(data, candidates_logprobs)
        return
Example 5
    def parse_file(self, file_path: Path, prj_root: Path, serapi_options: str):
        source_code = IOUtils.load(file_path, IOUtils.Format.txt)
        unicode_offsets = ParserUtils.get_unicode_offsets(source_code)

        with IOUtils.cd(prj_root):
            rel_path = file_path.relative_to(prj_root)
            ast_sexp_str = BashUtils.run(
                f"sercomp {serapi_options} --mode=sexp -- {rel_path}",
                expected_return_code=0).stdout
            tok_sexp_str = BashUtils.run(
                f"sertok {serapi_options} -- {rel_path}",
                expected_return_code=0).stdout

            ast_sexp_list: List[SexpNode] = SexpParser.parse_list(ast_sexp_str)
            tok_sexp_list: List[SexpNode] = SexpParser.parse_list(tok_sexp_str)

            doc = CoqParser.parse_document(
                source_code,
                ast_sexp_list,
                tok_sexp_list,
                unicode_offsets=unicode_offsets,
            )
            doc.file_name = str(rel_path)

            # Collect lemmas & definitions
            lemmas: List[Lemma] = DataMiner.collect_lemmas_doc(
                doc, ast_sexp_list, serapi_options)
            definitions: List[Definition] = DataMiner.collect_definitions_doc(
                doc, ast_sexp_list)

        return ProcessedFile(file_path, source_code, doc, ast_sexp_list,
                             tok_sexp_list, unicode_offsets, lemmas,
                             definitions)
Example 6
    def eval_impl(self,
            processed_data_dir: Path,
            model_dir: Path,
            beam_search_size: int,
            k: int
    ) -> List[List[Tuple[str, float]]]:
        from roosterize.ml.onmt.CustomTranslator import CustomTranslator
        from onmt.utils.misc import split_corpus
        from onmt.utils.parse import ArgumentParser
        from translate import _get_parser as translate_get_parser

        src_path = processed_data_dir/"src.txt"
        tgt_path = processed_data_dir/"tgt.txt"

        best_step = IOUtils.load(model_dir/"best-step.json", IOUtils.Format.json)
        self.logger.info(f"Taking best step at {best_step}")

        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        with IOUtils.cd(self.open_nmt_path):
            parser = translate_get_parser()
            opt = parser.parse_args(
                f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
                f" -src {src_path}"
                f" -tgt {tgt_path}"
            )
            opt.output = f"{model_dir}/last-pred.txt"
            opt.beam_size = beam_search_size
            opt.gpu = 0 if torch.cuda.is_available() else -1
            opt.n_best = k
            opt.block_ngram_repeat = 1
            opt.ignore_when_blocking = ["_"]

            # translate.main
            ArgumentParser.validate_translate_opts(opt)

            translator = CustomTranslator.build_translator(opt, report_score=False)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                self.logger.info("Translating shard %d." % i)
                _, _, candidates_logprobs_shard = translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug
                )
                candidates_logprobs.extend(candidates_logprobs_shard)
            # end for
        # end with

        # Reformat candidates
        candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

        return candidates_logprobs
Example 7
    def require_special_repo(cls, directory: Path, branch: str):
        cls.logger.info(f"Updating {directory} to {branch} branch")
        if directory.exists():
            if not directory.is_dir() or not (directory / ".git").is_dir():
                LoggingUtils.log_and_raise(
                    cls.logger,
                    f"Path {directory} already exists but is not a proper git repository!",
                    Exception)
            # end if

            with IOUtils.cd(directory):
                BashUtils.run(f"git pull", expected_return_code=0)
            # end with
        else:
            IOUtils.mk_dir(directory)
            with IOUtils.cd(directory):
                BashUtils.run(
                    f"git clone --single-branch -b {branch} -- {cls.get_git_url()} .",
                    expected_return_code=0)
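
Example 7 is an update-or-clone routine: pull if the directory already holds a git repository, otherwise create it and clone a single branch directly into it (the trailing "." makes git clone populate the current directory). A standalone sketch of the same pattern with plain subprocess, where the URL and branch are caller-supplied placeholders and the not-a-repo error handling of the original is omitted:

import subprocess
from pathlib import Path

def require_repo(directory: Path, branch: str, git_url: str) -> None:
    # Hypothetical standalone variant of the update-or-clone pattern above.
    if (directory / ".git").is_dir():
        subprocess.run(["git", "pull"], cwd=directory, check=True)
    else:
        directory.mkdir(parents=True, exist_ok=True)
        # "--single-branch -b <branch> -- <url> ." clones only that branch
        # into the (empty) current directory.
        subprocess.run(
            ["git", "clone", "--single-branch", "-b", branch, "--", git_url, "."],
            cwd=directory, check=True)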
Example 8
    def install_coq_project(cls, project: Project, names_projects: Dict[str, Project]) -> None:
        """
        :requires: the project is cloned and checked-out to the desired version.
        """
        if not project.is_cloned:
            project.clone()
            project.checkout(project.data["sha"], is_forced=True)
        # end if

        # Check if the project is already compiled
        confirmation_file = "lpc-installed.txt"
        confirmation_content = project.revision + " " + BashUtils.run("opam list coq -s", expected_return_code=0).stdout.strip()
        if (project.checkout_dir/confirmation_file).is_file() and IOUtils.load(project.checkout_dir/confirmation_file, "txt") == confirmation_content:
            cls.logger.debug(f"Project {project.full_name} already installed")
            return
        # end if

        project.clean()

        # Install dependencies
        for dependency in project.data.get("dependencies", []):
            dependency_project = names_projects.get(dependency)
            if dependency_project is None:  raise Exception(f"Cannot find dependency {dependency}")
            cls.logger.info(f"For Project {project.full_name}, installing dependency {dependency}")
            cls.install_coq_project(dependency_project, names_projects)
        # end for

        if "build_cmd" not in project.data:  raise Exception(f"Project {project.full_name} does not have build_cmd")
        if "install_cmd" not in project.data:  raise Exception(f"Project {project.full_name} does not have install_cmd")

        with IOUtils.cd(project.checkout_dir):
            # Build
            cls.logger.info(f"Project {project.full_name}: Building with {project.data['build_cmd']}")
            r = BashUtils.run(project.data["build_cmd"])
            if r.return_code != 0:
                raise Exception(f"Compilation failed! Return code is {r.return_code}! stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
            else:
                cls.logger.debug(f"Compilation finished. Return code is {r.return_code}. stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
            # end if

            # Install
            cls.logger.info(f"Project {project.full_name}: Installing with {project.data['install_cmd']}")
            r = BashUtils.run(project.data["install_cmd"])
            if r.return_code != 0:
                raise Exception(f"Installation failed! Return code is {r.return_code}! stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
            else:
                cls.logger.debug(f"Installation finished. Return code is {r.return_code}. stdout:\n{r.stdout}\n; stderr:\n{r.stderr}")
            # end if

            IOUtils.dump(project.checkout_dir / confirmation_file, confirmation_content, "txt")
        # end with
        return
Example 9
 def require_collector(cls):
     if cls.is_parallel: return
     if not cls.collector_installed:
         cls.logger.info("Require collector, installing ...")
         with IOUtils.cd(Macros.collector_dir):
             BashUtils.run(f"mvn clean install -DskipTests",
                           expected_return_code=0)
         # end with
         cls.collector_installed = True
     else:
         cls.logger.debug("Require collector, and already installed")
     # end if
     return
Example 10
 def test_cd(self):
     with TestSupport.get_playground_path():
         oldpath = Path.cwd()
         testpath = Path("./aaa").resolve()
         testpath.mkdir()
         with IOUtils.cd(testpath):
             # Checks if changed directory successfully
             self.assertEqual(testpath, Path.cwd())
         # end with
         # Checks if returned to old directory successfully
         self.assertEqual(oldpath, Path.cwd())
     # end with
     return
Example 11
 def preprocess(self,
         train_processed_data_dir: Path,
         val_processed_data_dir: Path,
         output_model_dir: Path
 ) -> NoReturn:
     # Call OpenNMT preprocess
     with IOUtils.cd(self.open_nmt_path):
         from preprocess import _get_parser as preprocess_get_parser
         from preprocess import main as preprocess_main
         parser = preprocess_get_parser()
         opt = parser.parse_args(
             f" -train_src {train_processed_data_dir}/src.txt"
             f" -train_tgt {train_processed_data_dir}/tgt.txt"
             f" -valid_src {val_processed_data_dir}/src.txt"
             f" -valid_tgt {val_processed_data_dir}/tgt.txt"
             f" -save_data {output_model_dir}/processed-data"
         )
         opt.src_seq_length = self.config.input_max
         opt.src_words_min_frequency = self.config.vocab_input_frequency_threshold
         if self.config.use_copy:  opt.dynamic_dict = True
         preprocess_main(opt)
     # end with
     return
Example 12
    def collect_coq_documents_project(
        cls,
        data_mgr: FilesManager,
        project: Project,
        names_projects: Dict[str, Project],
        files: List[str] = None,
        is_verifying_tokenizer: bool = False,
    ) -> List[CoqDocument]:
        coq_documents: List[CoqDocument] = list()

        # Clone and checkout repo
        project.clone()
        project.checkout(project.data["sha"], is_forced=True)

        # Build the project
        cls.install_coq_project(project, names_projects)

        # For each file, parse code to tokens
        with IOUtils.cd(project.checkout_dir):
            coq_files: List[str] = BashUtils.run(
                f"find -name '*.v' -type f").stdout.split("\n")[:-1]
            if files is not None:
                coq_files = [f for f in coq_files
                             if f[2:] in files]  # [2:] is to remove the ./
            # end if
            re_ignore_path = re.compile(
                project.data["ignore_path_regex"]
            ) if "ignore_path_regex" in project.data else None
            for i, coq_file in enumerate(coq_files):
                try:
                    coq_file = coq_file[2:]
                    cls.logger.debug(
                        f"File {i + 1}/{len(coq_files)}: {coq_file}")

                    # Check if file is ignored
                    if re_ignore_path is not None and re_ignore_path.fullmatch(
                            coq_file):
                        cls.logger.info(f"Ignoring file {coq_file}")
                        continue
                    # end if

                    # Read file
                    with open(coq_file, "r", newline="") as f:
                        source_code = f.read()
                    # end with

                    # Get unicode offsets
                    unicode_offsets = ParserUtils.get_unicode_offsets(
                        source_code)

                    # Save original file to original_files
                    data_mgr.dump_data([
                        FilesManager.ORIGINAL_FILES, project.full_name,
                        coq_file
                    ], source_code, IOUtils.Format.txt)

                    # Call SerAPI
                    serapi_options = project.data.get("serapi_options", "")
                    ast_sexp_str: str = BashUtils.run(
                        f"sercomp {serapi_options} --mode=sexp -- {coq_file}",
                        expected_return_code=0).stdout
                    tok_sexp_str: str = BashUtils.run(
                        f"sertok {serapi_options} -- {coq_file}",
                        expected_return_code=0).stdout

                    # Save ast sexp to dataset (.ast.sexp)
                    data_mgr.dump_data([
                        FilesManager.RAW_FILES, project.full_name,
                        coq_file[:-2] + ".ast.sexp"
                    ], ast_sexp_str, IOUtils.Format.txt)

                    # Save tok sexp to dataset (.tok.sexp)
                    data_mgr.dump_data([
                        FilesManager.RAW_FILES, project.full_name,
                        coq_file[:-2] + ".tok.sexp"
                    ], tok_sexp_str, IOUtils.Format.txt)

                    # Parse ast sexp
                    ast_sexp_list: List[SexpNode] = SexpParser.parse_list(
                        ast_sexp_str)
                    tok_sexp_list: List[SexpNode] = SexpParser.parse_list(
                        tok_sexp_str)

                    # Verify the tokenizer if requested
                    if is_verifying_tokenizer:
                        if not cls.verify_tokenizer(tok_sexp_list, source_code,
                                                    unicode_offsets):
                            LoggingUtils.log_and_raise(
                                cls.logger,
                                "Tokenized content doesn't match original file!",
                                Exception)
                        # end if
                    # end if

                    # Parse the document
                    coq_document = CoqParser.parse_document(
                        source_code,
                        ast_sexp_list,
                        tok_sexp_list,
                        unicode_offsets=unicode_offsets)

                    # Save the parsed document (printed format) to raw_files
                    data_mgr.dump_data(
                        [FilesManager.RAW_FILES, project.full_name, coq_file],
                        coq_document.str_with_space(), IOUtils.Format.txt)

                    # Set meta data
                    coq_document.file_name = coq_file
                    coq_document.project_name = project.full_name
                    coq_document.revision = project.revision

                    coq_documents.append(coq_document)
                except KeyboardInterrupt:
                    cls.logger.warning("Keyboard interrupt!")
                    raise
                except:
                    cls.logger.warning(
                        f"File {coq_file} failed! Exception was: {traceback.format_exc()}"
                    )
                    continue
                # end try
            # end for
        # end with

        return coq_documents
Example 13
 def get_git_url(cls):
     with IOUtils.cd(Macros.project_dir):
         return BashUtils.run(f"git config --get remote.origin.url",
                              expected_return_code=0).stdout.strip()
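
Nearly every example calls BashUtils.run(cmd, expected_return_code=0).stdout: run a shell command, optionally assert its exit code, and capture its output. A minimal sketch of such a helper built on subprocess.run, assuming roughly that interface; the real BashUtils may behave differently:

import subprocess
from typing import Optional

class RunResult:
    def __init__(self, return_code: int, stdout: str, stderr: str):
        self.return_code = return_code
        self.stdout = stdout
        self.stderr = stderr

def run(cmd: str, expected_return_code: Optional[int] = None) -> RunResult:
    # Hypothetical stand-in for BashUtils.run: execute through the shell, capture text output.
    completed = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if expected_return_code is not None and completed.returncode != expected_return_code:
        raise RuntimeError(
            f"Command {cmd!r} returned {completed.returncode}, "
            f"expected {expected_return_code}; stderr:\n{completed.stderr}")
    return RunResult(completed.returncode, completed.stdout, completed.stderr)

# Usage mirroring Example 13:
#   url = run("git config --get remote.origin.url", expected_return_code=0).stdout.strip()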
Example 14
    def collect_project(self, project_name: str, project_url: str):
        Environment.require_collector()

        # 0. Download repo
        downloads_dir = self.repos_downloads_dir / project_name
        results_dir = self.repos_results_dir / project_name

        # Remove previous results if any
        IOUtils.rm_dir(results_dir)
        IOUtils.mk_dir(results_dir)

        # Clone the repo if not exists
        if not downloads_dir.exists():
            with IOUtils.cd(self.repos_downloads_dir):
                with TimeUtils.time_limit(300):
                    BashUtils.run(f"git clone {project_url} {project_name}",
                                  expected_return_code=0)
                # end with
            # end with
        # end if

        project_data = ProjectData.create()
        project_data.name = project_name
        project_data.url = project_url

        # 1. Get list of revisions
        with IOUtils.cd(downloads_dir):
            git_log_out = BashUtils.run(f"git log --pretty=format:'%H %P'",
                                        expected_return_code=0).stdout
            for line in git_log_out.splitlines()[:self.MAX_REVISIONS]:
                shas = line.split()
                project_data.revisions.append(shas[0])
                project_data.parent_revisions[shas[0]] = shas[1:]
            # end for
        # end with

        # 2. Get revisions in different year
        with IOUtils.cd(downloads_dir):
            for year in self.YEARS:
                git_log_out = BashUtils.run(
                    f"git rev-list -1 --before=\"Jan 1 {year}\" origin",
                    expected_return_code=0).stdout
                project_data.year_revisions[str(year) +
                                            "_Jan_1"] = git_log_out.rstrip()
            # end for
        # end with

        project_data_file = results_dir / "project.json"
        IOUtils.dump(project_data_file, IOUtils.jsonfy(project_data),
                     IOUtils.Format.jsonPretty)

        # 2. Start java collector
        # Prepare config
        log_file = results_dir / "collector-log.txt"
        output_dir = results_dir / "collector"

        config = {
            "collect": True,
            "projectDir": str(downloads_dir),
            "projectDataFile": str(project_data_file),
            "logFile": str(log_file),
            "outputDir": str(output_dir),
            "year":
            True  # To indicate whether to collect all evo data or yearly data
        }
        config_file = results_dir / "collector-config.json"
        IOUtils.dump(config_file, config, IOUtils.Format.jsonPretty)

        self.logger.info(
            f"Starting the Java collector. Check log at {log_file} and outputs at {output_dir}"
        )
        rr = BashUtils.run(
            f"java -jar {Environment.collector_jar} {config_file}",
            expected_return_code=0)
        if rr.stderr:
            self.logger.warning(f"Stderr of collector:\n{rr.stderr}")
        # end if

        # 3. In some cases, save collected data to appropriate location or database
        # TODO private info
        # On luzhou server for user pynie, move it to a dedicated location at /user/disk2
        if BashUtils.run(
                f"hostname").stdout.strip() == "luzhou" and BashUtils.run(
                    f"echo $USER").stdout.strip() == "pynie":
            alter_results_dir = Path(
                "/home/disk2/pynie/csevo-results") / project_name
            IOUtils.rm_dir(alter_results_dir)
            IOUtils.mk_dir(alter_results_dir.parent)
            BashUtils.run(f"mv {results_dir} {alter_results_dir}")
            self.logger.info(f"Results moved to {alter_results_dir}")
        # end if

        # -1. Remove repo
        IOUtils.rm_dir(downloads_dir)
        return
Example 15
    def make_plot_draft_learning_curve(self,
            training_log_path: Path,
            output_name: str,
    ):
        special_plots_dir = self.plots_dir / "draft-learning-curve"
        IOUtils.mk_dir(special_plots_dir)

        fig: plt.Figure = plt.figure(figsize=(12,9))

        # TODO: these metrics may be specific to Code2Seq only
        x_field = "batch"
        yl_field = "training_loss"
        yr_field = "eval F1"

        x_min = 0
        x_max = -np.Inf
        yl_min = np.Inf
        yl_max = -np.Inf
        yr_min = np.Inf
        yr_max = -np.Inf

        # First, get ranges for all metrics (we want to use same ranges in all subplots)
        tvt_2_training_log = dict()
        tvt_2_x = dict()
        tvt_2_yl = dict()
        tvt_2_yr = dict()

        for tvt in [Macros.lat_lat, Macros.evo_lat, Macros.lat_evo, Macros.evo_evo]:
            # TODO: this path is hardcoded and work for Code2Seq 1 trial
            training_log = IOUtils.load(training_log_path / tvt / "trial-0" / "logs" / "train_log.json", IOUtils.Format.json)
            x = [d[x_field] for d in training_log]
            yl = [d[yl_field] for d in training_log]
            yr = [d[yr_field] for d in training_log]

            tvt_2_training_log[tvt] = training_log
            tvt_2_x[tvt] = x
            tvt_2_yl[tvt] = yl
            tvt_2_yr[tvt] = yr

            x_min = min(x_min, min(x))
            x_max = max(x_max, max(x))
            yl_min = min(yl_min, min(yl))
            yl_max = max(yl_max, max(yl))
            yr_min = min(yr_min, min(yr))
            yr_max = max(yr_max, max(yr))
        # end for

        x_lim = (x_min - (x_max - x_min) / 30, x_max + (x_max - x_min) / 30)
        yl_lim = (np.exp(np.log(yl_min) - (np.log(yl_max) - np.log(yl_min)) / 30), np.exp(np.log(yl_max) + (np.log(yl_max) - np.log(yl_min)) / 30))
        yr_lim = (yr_min - (yr_max - yr_min) / 30, yr_max + (yr_max - yr_min) / 30)

        for t_i, t in enumerate([Macros.lat, Macros.evo]):
            for vt_i, vt in enumerate([Macros.lat, Macros.evo]):
                tvt = f"{t}-{vt}"
                tvt_i = (t_i)*2+(vt_i)+1

                x = tvt_2_x[tvt]
                yl = tvt_2_yl[tvt]
                yr = tvt_2_yr[tvt]

                axl: plt.Axes = fig.add_subplot(2, 2, tvt_i)
                axr = axl.twinx()

                colorl = "tab:red"
                colorr = "tab:blue"

                axl.plot(x, yl, color=colorl)
                axr.plot(x, yr, color=colorr)

                axl.set_xlabel(x_field)
                axl.set_xlim(x_lim[0], x_lim[1])

                axl.set_ylabel(yl_field, color=colorl)
                axl.set_yscale("log")
                axl.set_ylim(yl_lim[0], yl_lim[1])

                axr.set_ylabel(yr_field, color=colorr)
                axr.set_ylim(yr_lim[0], yr_lim[1])

                axl.set_title(tvt)
            # end for
        # end for

        fig.tight_layout()
        with IOUtils.cd(special_plots_dir):
            fig.savefig(f"{output_name}.eps")
        # end with
        return
Example 16
    def eval_impl(self, processed_data_dir: Path, model_dir: Path,
                  beam_search_size: int,
                  k: int) -> List[List[Tuple[str, float]]]:
        from roosterize.ml.onmt.MultiSourceTranslator import MultiSourceTranslator
        from onmt.utils.misc import split_corpus
        from onmt.utils.parse import ArgumentParser
        from translate import _get_parser as translate_get_parser

        src_path = processed_data_dir / "src.txt"
        tgt_path = processed_data_dir / "tgt.txt"

        best_step = IOUtils.load(model_dir / "best-step.json",
                                 IOUtils.Format.json)
        self.logger.info(f"Taking best step at {best_step}")

        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        with IOUtils.cd(self.open_nmt_path):
            parser = translate_get_parser()
            opt = parser.parse_args(
                f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
                f" -src {src_path}"
                f" -tgt {tgt_path}")
            opt.output = f"{model_dir}/last-pred.txt"
            opt.beam_size = beam_search_size
            opt.gpu = 0 if torch.cuda.is_available() else -1
            opt.n_best = k
            opt.block_ngram_repeat = 1
            opt.ignore_when_blocking = ["_"]

            # translate.main
            ArgumentParser.validate_translate_opts(opt)

            translator = MultiSourceTranslator.build_translator(
                self.config.get_src_types(), opt, report_score=False)

            has_target = True
            raw_data_keys = [
                f"src.{src_type}" for src_type in self.config.get_src_types()
            ] + (["tgt"] if has_target else [])
            raw_data_paths: Dict[str, str] = {
                k: f"{processed_data_dir}/{k}.txt"
                for k in raw_data_keys
            }
            raw_data_shards: Dict[str, list] = {
                k: list(split_corpus(p, opt.shard_size))
                for k, p in raw_data_paths.items()
            }

            # src_shards = split_corpus(opt.src, opt.shard_size)
            # tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
            # shard_pairs = zip(src_shards, tgt_shards)

            for i in range(len(list(raw_data_shards.values())[0])):
                self.logger.info("Translating shard %d." % i)
                _, _, candidates_logprobs_shard = translator.translate(
                    {k: v[i]
                     for k, v in raw_data_shards.items()},
                    has_target,
                    src_dir=None,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug)
                candidates_logprobs.extend(candidates_logprobs_shard)
            # end for
        # end with

        # Reformat candidates
        candidates_logprobs: List[List[Tuple[str, float]]] = [[
            ("".join(c), l) for c, l in cl
        ] for cl in candidates_logprobs]

        return candidates_logprobs
Example 17
    def train_impl(
        self,
        train_processed_data_dir: Path,
        val_processed_data_dir: Path,
        output_model_dir: Path,
    ) -> NoReturn:
        self.preprocess(train_processed_data_dir, val_processed_data_dir,
                        output_model_dir)

        from train import _get_parser as train_get_parser
        from train import ErrorHandler, batch_producer
        from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
        from onmt.inputters.inputter import old_style_vocab, load_old_vocab
        import onmt.utils.distributed
        from onmt.utils.parse import ArgumentParser

        with IOUtils.cd(self.open_nmt_path):
            parser = train_get_parser()
            opt = parser.parse_args(
                f" -data {output_model_dir}/processed-data"
                f" -save_model {output_model_dir}/models/ckpt")
            opt.gpu_ranks = [0]
            opt.early_stopping = self.config.early_stopping_threshold
            opt.report_every = 200
            opt.valid_steps = 200
            opt.save_checkpoint_steps = 200
            opt.keep_checkpoint_max = self.config.ckpt_keep_max

            opt.optim = "adam"
            opt.learning_rate = self.config.learning_rate
            opt.max_grad_norm = self.config.max_grad_norm
            opt.batch_size = self.config.batch_size

            opt.encoder_type = self.config.encoder
            opt.decoder_type = self.config.decoder
            opt.dropout = [self.config.dropout]
            opt.src_word_vec_size = self.config.dim_embed
            opt.tgt_word_vec_size = self.config.dim_embed
            opt.layers = self.config.rnn_num_layers
            opt.enc_rnn_size = self.config.dim_encoder_hidden
            opt.dec_rnn_size = self.config.dim_decoder_hidden
            opt.__setattr__("num_srcs", len(self.config.get_src_types()))
            if self.config.use_attn:
                opt.global_attention = "general"
            else:
                opt.global_attention = "none"
            # end if
            if self.config.use_copy:
                opt.copy_attn = True
                opt.copy_attn_type = "general"
            # end if

            # train.main
            ArgumentParser.validate_train_opts(opt)
            ArgumentParser.update_model_opts(opt)
            ArgumentParser.validate_model_opts(opt)

            # Load checkpoint if we resume from a previous training.
            if opt.train_from:
                self.logger.info('Loading checkpoint from %s' % opt.train_from)
                checkpoint = torch.load(
                    opt.train_from, map_location=lambda storage, loc: storage)
                self.logger.info('Loading vocab from checkpoint at %s.' %
                                 opt.train_from)
                vocab = checkpoint['vocab']
            else:
                vocab = torch.load(opt.data + '.vocab.pt')
            # end if

            # check for code where vocab is saved instead of fields
            # (in the future this will be done in a smarter way)
            if old_style_vocab(vocab):
                fields = load_old_vocab(vocab,
                                        opt.model_type,
                                        dynamic_dict=opt.copy_attn)
            else:
                fields = vocab
            # end if

            if len(opt.data_ids) > 1:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                # end for
                train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                    self.config.get_src_types(), train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = MultiSourceInputter.build_dataset_iter(
                    self.config.get_src_types(), shard_base, fields, opt)
            # end if

            nb_gpu = len(opt.gpu_ranks)

            if opt.world_size > 1:
                queues = []
                mp = torch.multiprocessing.get_context('spawn')
                semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
                # Create a thread to listen for errors in the child processes.
                error_queue = mp.SimpleQueue()
                error_handler = ErrorHandler(error_queue)
                # Train with multiprocessing.
                procs = []
                for device_id in range(nb_gpu):
                    q = mp.Queue(opt.queue_size)
                    queues += [q]

                    def run(opt, device_id, error_queue, batch_queue,
                            semaphore):
                        """ run process """
                        try:
                            gpu_rank = onmt.utils.distributed.multi_init(
                                opt, device_id)
                            if gpu_rank != opt.gpu_ranks[device_id]:
                                raise AssertionError(
                                    "An error occurred in Distributed initialization"
                                )
                            self.train_single(opt, device_id, batch_queue,
                                              semaphore)
                        except KeyboardInterrupt:
                            pass  # killed by parent, do nothing
                        except Exception:
                            # propagate exception to parent process, keeping original traceback
                            import traceback
                            error_queue.put((opt.gpu_ranks[device_id],
                                             traceback.format_exc()))
                        # end try

                    # end def

                    procs.append(
                        mp.Process(target=run,
                                   args=(opt, device_id, error_queue, q,
                                         semaphore),
                                   daemon=True))
                    procs[device_id].start()
                    self.logger.info(" Starting process pid: %d  " %
                                     procs[device_id].pid)
                    error_handler.add_child(procs[device_id].pid)
                # end for
                producer = mp.Process(target=batch_producer,
                                      args=(
                                          train_iter,
                                          queues,
                                          semaphore,
                                          opt,
                                      ),
                                      daemon=True)
                producer.start()
                error_handler.add_child(producer.pid)

                for p in procs:
                    p.join()
                producer.terminate()

            elif nb_gpu == 1:  # case 1 GPU only
                self.train_single(output_model_dir, opt, 0)
            else:  # case only CPU
                self.train_single(output_model_dir, opt, -1)
            # end if
        # end with
        return
Example 18
    def extract_data_project(
        cls,
        project_path: Path,
        files: Optional[List[str]],
        exclude_files: Optional[List[str]],
        exclude_pattern: Optional[str],
        serapi_options: str,
        output_path: Path,
    ):
        # 1. Prepare output path
        if output_path.is_dir():
            cls.logger.warning(
                f"{output_path} already exists, will overwrite the files.")
        elif output_path.is_file():
            LoggingUtils.log_and_raise(
                cls.logger,
                f"{output_path} already exists as a file. Aborting.",
                Exception)
        else:
            IOUtils.mk_dir(output_path)
        # end if

        # 2. Extract documents, tok.sexp and ast.sexp
        coq_documents: Dict[str, CoqDocument] = collections.OrderedDict()
        ast_sexp_lists: Dict[str, List[SexpNode]] = dict()
        tok_sexp_lists: Dict[str, List[SexpNode]] = dict()

        with IOUtils.cd(project_path):
            coq_files: List[str] = BashUtils.run(
                f"find -name '*.v' -type f").stdout.split("\n")[:-1]
            coq_files = [coq_file[2:] for coq_file in coq_files]

            if files is not None:
                coq_files = [f for f in coq_files if f in files]
            # end if

            if exclude_files is not None:
                coq_files = [f for f in coq_files if f not in exclude_files]
            # end if

            if exclude_pattern is not None:
                re_exclude_pattern = re.compile(exclude_pattern)
                coq_files = [
                    f for f in coq_files if not re_exclude_pattern.fullmatch(f)
                ]
            # end if

            for i, coq_file in enumerate(tqdm(coq_files)):
                try:
                    # Read file
                    with open(coq_file, "r", newline="") as f:
                        source_code = f.read()
                    # end with

                    # Get unicode offsets
                    unicode_offsets = ParserUtils.get_unicode_offsets(
                        source_code)

                    # Call SerAPI
                    ast_sexp_str: str = BashUtils.run(
                        f"sercomp {serapi_options} --mode=sexp -- {coq_file}",
                        expected_return_code=0).stdout
                    tok_sexp_str: str = BashUtils.run(
                        f"sertok {serapi_options} -- {coq_file}",
                        expected_return_code=0).stdout

                    # Parse ast sexp
                    ast_sexp_list: List[SexpNode] = SexpParser.parse_list(
                        ast_sexp_str)
                    tok_sexp_list: List[SexpNode] = SexpParser.parse_list(
                        tok_sexp_str)

                    # Parse the document
                    coq_document = CoqParser.parse_document(
                        source_code,
                        ast_sexp_list,
                        tok_sexp_list,
                        unicode_offsets=unicode_offsets)

                    # Set meta data
                    coq_document.file_name = coq_file
                    coq_document.project_name = project_path.name

                    coq_documents[coq_file] = coq_document
                    ast_sexp_lists[coq_file] = ast_sexp_list
                    tok_sexp_lists[coq_file] = tok_sexp_list
                except KeyboardInterrupt:
                    cls.logger.warning("Keyboard interrupt!")
                    raise
                except:
                    cls.logger.warning(
                        f"File {coq_file} failed! Exception was: {traceback.format_exc()}"
                    )
                    continue
                # end try
            # end for

            # 3. Extract and save lemmas and definitions
            lemmas: List[Lemma] = list()
            definitions: List[Definition] = list()

            # Increase recursion limit because the backend sexps are CRAZZZZY deep
            sys.setrecursionlimit(10000)

            for file_path, doc in tqdm(coq_documents.items()):
                ast_sexp_list = ast_sexp_lists[file_path]
                lemmas_doc = cls.collect_lemmas_doc(doc, ast_sexp_list,
                                                    serapi_options)
                lemmas.extend(lemmas_doc)
                definitions_doc = cls.collect_definitions_doc(
                    doc, ast_sexp_list)
                definitions.extend(definitions_doc)
            # end for

            IOUtils.dump(output_path / "lemmas.json", IOUtils.jsonfy(lemmas),
                         IOUtils.Format.json)
            IOUtils.dump(output_path / "definitions.json",
                         IOUtils.jsonfy(definitions), IOUtils.Format.json)
        # end with
        return