示例#1
0
    def dump_data(self,
            rel_path: Union[str, List[str]],
            data: Any,
            fmt: IOUtils.Format,
            is_batched: bool = False,
            per_batch: int = 100,
            exist_ok: bool = False,
    ):
        """
        Writes data under this manager's data directory, either as one file or as a directory of batch files.

        :param rel_path: the location relative to self.data_dir (resolved via self.assemble_rel_path).
        :param data: the data to write; in batched mode it must support len() and slicing.
        :param fmt: the format passed to IOUtils.dump; for json formats the data goes through IOUtils.jsonfy first.
        :param is_batched: if True, the target becomes a directory containing batch-<i>.<extension> files.
        :param per_batch: the number of entries stored in each batch file (batched mode only).
        :param exist_ok: if False, an IOError is logged and raised when the target already exists.
        """
        abs_path = self.data_dir / self.assemble_rel_path(rel_path)
        if abs_path.exists() and not exist_ok:
            LoggingUtils.log_and_raise(self.logger, f"Cannot rewrite existing data at {abs_path}", IOError)
        # end if

        abs_path.parent.mkdir(parents=True, exist_ok=True)

        if not is_batched:
            # Single-file mode: everything goes into one file
            payload = IOUtils.jsonfy(data) if self.is_json_format(fmt) else data
            IOUtils.dump(abs_path, payload, fmt)
            return
        # end if

        # Batched mode: whatever is at the target is removed and replaced by a fresh directory
        IOUtils.rm(abs_path)
        abs_path.mkdir(parents=True)

        num_batches = math.ceil(len(data)/per_batch)
        for batch_i in tqdm(range(num_batches)):
            beg = per_batch*batch_i
            chunk = data[beg : beg+per_batch]
            if self.is_json_format(fmt):
                chunk = IOUtils.jsonfy(chunk)
            # end if
            IOUtils.dump(abs_path/f"batch-{batch_i}.{fmt.get_extension()}", chunk, fmt)
        # end for
        return
示例#2
0
    def load_data(self,
            rel_path: Union[str, List[str]],
            fmt: IOUtils.Format,
            is_batched: bool = False,
            clz = None,
    ) -> Any:
        """
        Loads data stored under this manager's data directory, from one file or from a directory of batch files.

        :param rel_path: the location relative to self.data_dir (resolved via self.assemble_rel_path).
        :param fmt: the format passed to IOUtils.load.
        :param is_batched: if True, the target is read as a directory of batch-<i>.<extension> files, in
            ascending order of <i>, and the batches are concatenated into one list.
        :param clz: for json formats, the class handed to IOUtils.dejsonfy; if None, the loaded data is
            returned without that conversion (and a warning is logged).
        :return: the loaded data.
        """
        if self.is_json_format(fmt) and clz is None:
            self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz (at {traceback.format_stack()})")
        # end if

        abs_path = self.data_dir / self.assemble_rel_path(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        if not is_batched:
            loaded = IOUtils.load(abs_path, fmt)
            if self.is_json_format(fmt) and clz is not None:
                loaded = IOUtils.dejsonfy(loaded, clz)
            # end if
            return loaded
        # end if

        # Batched mode: the batch number is the part of the file stem after the first "-"
        merged = []
        batch_numbers = sorted(int(str(f.stem).split("-")[1]) for f in abs_path.iterdir())
        for batch_number in tqdm(batch_numbers):
            batch_file = abs_path / f"batch-{batch_number}.{fmt.get_extension()}"
            loaded_batch = IOUtils.load(batch_file, fmt)
            if self.is_json_format(fmt) and clz is not None:
                loaded_batch = IOUtils.dejsonfy(loaded_batch, clz)
            # end if
            merged.extend(loaded_batch)
        # end for
        return merged
示例#3
0
    def extract_data_from_corpus(cls,
            corpus_path: Path,
            trainevals: List[str],
            groups: List[str],
            output_path: Path,
    ):
        """
        Extracts, for each combination of traineval and group, the filtered lemmas and the definitions whose
        data index is listed for that combination, and writes them to <output_path>/<group>-<traineval>/.

        :param corpus_path: the corpus directory, used to create the FilesManager that loads lemmas and definitions.
        :param trainevals: the train/eval splits to extract; each must be in Macros.DS_TRAINEVALS.
        :param groups: the groups to extract; each must be in Macros.DS_GROUPS or equal Macros.DS_GROUP_TA.
        :param output_path: the output directory; created if absent, reused (with a warning) if it already is a directory.
        :raises Exception: if output_path already exists as a file.
        """
        # 1. Prepare output path
        if output_path.is_dir():
            cls.logger.warning(f"{output_path} already exists, will overwrite the files.")
        elif output_path.is_file():
            LoggingUtils.log_and_raise(cls.logger, f"{output_path} already exists as a file. Aborting.", Exception)
        else:
            IOUtils.mk_dir(output_path)
        # end if

        assert all([traineval in Macros.DS_TRAINEVALS for traineval in trainevals])
        assert all([group in Macros.DS_GROUPS+[Macros.DS_GROUP_TA] for group in groups])

        data_mgr = FilesManager(corpus_path)

        # 2. Load lemmas and definitions
        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
        definitions: List[Definition] = data_mgr.load_data([FilesManager.DEFINITIONS, "definitions.json"], IOUtils.Format.json, clz=Definition)

        # 3. Output to output_path for each combination of traineval and group
        for traineval in trainevals:
            for group in groups:
                IOUtils.mk_dir(output_path/f"{group}-{traineval}")
                # NOTE(review): `project_dir` is not defined inside this method; presumably it is available
                # from the enclosing module -- confirm. The stray "]" that made this call a syntax error is removed.
                data_indexes = IOUtils.load(project_dir/"training"/f"{group}-{traineval}.json", IOUtils.Format.json, clz=str)
                IOUtils.dump(output_path/f"{group}-{traineval}/lemmas.json", IOUtils.jsonfy([l for l in lemmas_filtered if l.data_index in data_indexes]), IOUtils.Format.json)
                IOUtils.dump(output_path/f"{group}-{traineval}/definitions.json", IOUtils.jsonfy([d for d in definitions if d.data_index in data_indexes]), IOUtils.Format.json)
            # end for
        # end for
        return
示例#4
0
    def collect_data(cls, **options) -> None:
        """
        Runs one data collection task, selected by the "task" option.

        Recognized options:
        - "task" (required): one of the cls.TASK_* constants handled below.
        - "corpus" (optional): path of the projects file; defaults to
          <cls.dataset_dir>/projects-standalone-8.10.yml.
        - "files" (optional): used by the coq-documents and lemma tasks.
        - "verify-tokenizer" (optional): used by the coq-documents task.

        :raises KeyError: if the "task" option is missing.
        :raises ValueError: (via LoggingUtils.log_and_raise) if the task is unknown.
        """
        data_mgr = FilesManager(cls.dataset_dir)

        task = options["task"]

        projects_path = Path(options.get("corpus", cls.dataset_dir / "projects-standalone-8.10.yml"))
        # NOTE(review): the format is passed as the string "json" here, while other call sites use
        # IOUtils.Format members -- confirm that IOUtils.load accepts it.
        projects: List[Project] = IOUtils.dejsonfy(IOUtils.load(projects_path, "json"), Project)

        if task == cls.TASK_COQ_DOCUMENTS:
            files = Utils.get_option_as_list(options, "files", None)
            is_verifying_tokenizer = Utils.get_option_as_boolean(options, "verify-tokenizer")
            cls.collect_coq_documents_projects(data_mgr, projects, files, is_verifying_tokenizer)
        elif task == cls.TASK_DATA_INDEXES:
            cls.collect_data_indexes(data_mgr, projects)
        elif task == cls.TASK_DEFINITIONS:
            cls.collect_definitions(data_mgr)
        elif task == cls.TASK_INSTALL_COQ_PROJECTS:
            cls.install_coq_projects(projects)
        elif task == cls.TASK_LEMMA:
            files = Utils.get_option_as_list(options, "files", None)
            cls.collect_lemmas(data_mgr, projects, files)
        elif task == cls.TASK_LEMMA_BACKEND_SEXP_TRANSFORMATIONS:
            cls.collect_lemmas_backend_sexp_transformations(data_mgr)
        elif task == cls.TASK_LEMMA_FILTERED:
            cls.filter_lemmas(data_mgr)
        elif task == cls.TASK_LEMMA_FOREEND_SEXP_TRANSFORMATIONS:
            cls.collect_lemmas_foreend_sexp_transformations(data_mgr)
        else:
            LoggingUtils.log_and_raise(cls.logger, f"Unknown task {task}", ValueError)
        # end if
        return
示例#5
0
def normalize_options(opts: dict) -> dict:
    """
    Applies the global command line options to the logging setup and to Environment, then returns opts unchanged.

    Handled options: "log_path", "debug", "parallel" and "random_seed". The "debug" and "parallel" flags
    count as enabled when present with any value whose string form is not "false" (case-insensitive).
    """

    def is_enabled(name: str) -> bool:
        # A flag is on when it is given and its string form is anything other than "false"
        return name in opts and str(opts[name]).lower() != "false"

    # Set a different log file
    if "log_path" in opts:
        logger.info(f"Switching to log file {opts['log_path']}")
        LoggingUtils.setup(filename=opts['log_path'])
    # end if

    # Set debug mode
    if is_enabled("debug"):
        Environment.is_debug = True
        logger.debug("Debug mode on")
        logger.debug(f"Command line options: {opts}")
    # end if

    # Set parallel mode - all automatic installations are disabled
    if is_enabled("parallel"):
        Environment.is_parallel = True
        logger.warning(f"Parallel mode on")
    # end if

    # Set/report random seed: the given one, or the current time in nanoseconds
    Environment.random_seed = int(opts["random_seed"]) if "random_seed" in opts else time.time_ns()
    random.seed(Environment.random_seed)
    logger.info(f"Random seed is {Environment.random_seed}")

    # Automatically update data and results repo
    Environment.require_data()
    Environment.require_results()
    return opts
示例#6
0
文件: main.py 项目: mfkiwl/hdlp
def normalize_options(opts: dict) -> dict:
    """
    Applies the "log-path" and "random-seed" options, then returns opts unchanged.

    If "random-seed" is given it seeds the random module directly; otherwise the current time in
    nanoseconds is used as the seed and reported in the log.
    """
    if "log-path" in opts:
        logger.info(f"Switching to log file {opts['log-path']}")
        LoggingUtils.setup(filename=opts['log-path'])
    # end if

    if "random-seed" not in opts:
        # No seed requested: derive one from the clock and report it
        seed = time.time_ns()
        random.seed(seed)
        logger.info(f"Random seed is {seed}")
        return opts
    # end if

    random.seed(opts["random-seed"])
    return opts
示例#7
0
    def collect_definitions(cls, data_mgr: FilesManager):
        """
        Collects the definitions of all coq documents and saves them to <DEFINITIONS>/definitions.json.

        A document that fails to process is logged and skipped so that the remaining documents are still
        processed. If any document failed, the collected (data index, traceback) pairs are saved to
        <DEFINITIONS>/errors.txt and an Exception is raised; definitions.json is not written in that case.

        :param data_mgr: the files manager used to load the inputs and save the outputs.
        :raises Exception: (via LoggingUtils.log_and_raise) if at least one document failed.
        """
        data_mgr.clean_path([FilesManager.DEFINITIONS])
        data_mgr.resolve([FilesManager.DEFINITIONS]).mkdir(parents=True)

        # Load coq-documents
        coq_documents: List[CoqDocument] = cls.load_coq_documents(data_mgr)

        definitions: List[Definition] = list()

        errors: List[Tuple[str, str]] = list()

        for doc in tqdm(coq_documents):
            try:
                # Load AST sexp; the file name is the data index with its last two characters replaced by ".ast.sexp"
                ast_sexp_list: List[SexpNode] = SexpParser.parse_list(
                    data_mgr.load_data([
                        FilesManager.RAW_FILES,
                        doc.get_data_index()[:-2] + ".ast.sexp"
                    ], IOUtils.Format.txt))
                definitions_doc: List[
                    Definition] = cls.collect_definitions_doc(
                        doc, ast_sexp_list)

                definitions.extend(definitions_doc)
            except KeyboardInterrupt:
                cls.logger.warning(f"Keyboard Interrupt!")
                raise
            except Exception:
                # Format the traceback once so the log and the error record agree
                error_trace = traceback.format_exc()
                cls.logger.warning(
                    f"Error while parsing {doc.get_data_index()}: {error_trace}"
                )
                cls.logger.warning(
                    f"The script will continue on other files before it returns with failure. Use Ctrl+C to cut it early."
                )
                errors.append((doc.get_data_index(), error_trace))
                continue
            # end try
        # end for

        if len(errors) > 0:
            # Save the errors BEFORE raising: log_and_raise is expected to raise, so nothing after it runs
            data_mgr.dump_data([FilesManager.DEFINITIONS, "errors.txt"],
                               errors, IOUtils.Format.jsonPretty)
            LoggingUtils.log_and_raise(
                cls.logger,
                f"There were {len(errors)} errors during collection.",
                Exception)
        # end if

        data_mgr.dump_data([FilesManager.DEFINITIONS, "definitions.json"],
                           definitions, IOUtils.Format.json)
        return
示例#8
0
    def collect_lemmas(cls, data_mgr: FilesManager, projects: List[Project], files: List[str] = None):
        """
        Collects the lemmas of the coq documents, assigns them consecutive uids, and saves them in batches
        under <LEMMAS>.

        A document that fails to process is logged and skipped so that the remaining documents are still
        processed. If any document failed, the collected (data index, traceback) pairs are saved to
        <LEMMAS>/errors.txt and an Exception is raised; the lemmas are not written in that case.

        :param data_mgr: the files manager used to load the inputs and save the outputs.
        :param projects: the projects, used to look up the serapi options by project full name.
        :param files: if not None, only documents whose file_name is in this list are processed.
        :raises Exception: (via LoggingUtils.log_and_raise) if at least one document failed.
        """
        data_mgr.clean_path([FilesManager.LEMMAS])
        data_mgr.resolve([FilesManager.LEMMAS]).mkdir(parents=True)

        # Increase recursion limit because the backend sexps are CRAZZZZY deep
        sys.setrecursionlimit(10000)

        # Load coq-documents
        coq_documents: List[CoqDocument] = cls.load_coq_documents(data_mgr)
        if files is not None:  coq_documents = [d for d in coq_documents if d.file_name in files]

        lemmas: List[Lemma] = list()

        # Prepare serapi_options
        project_2_serapi_options: Dict[str, str] = {p.full_name: p.data["serapi_options"] for p in projects}

        errors: List[Tuple[str, str]] = list()

        for doc_i, doc in enumerate(tqdm(coq_documents)):
            try:
                cls.logger.info(f"Collecting from file {doc.get_data_index()} ({doc_i}/{len(coq_documents)}). Collected: {len(lemmas)}")

                # Load AST sexp; the file name is the data index with its last two characters replaced by ".ast.sexp"
                ast_sexp_list: List[SexpNode] = SexpParser.parse_list(data_mgr.load_data([FilesManager.RAW_FILES, doc.get_data_index()[:-2] + ".ast.sexp"], IOUtils.Format.txt))

                # Collect lemmas from this doc
                lemmas_doc: List[Lemma] = cls.collect_lemmas_doc(doc, ast_sexp_list, project_2_serapi_options[doc.project_name])
                lemmas.extend(lemmas_doc)
            except KeyboardInterrupt:
                cls.logger.warning(f"Keyboard Interrupt!")
                raise
            except Exception:
                # Format the traceback once so the log and the error record agree
                error_trace = traceback.format_exc()
                cls.logger.warning(f"Error while parsing {doc.get_data_index()}: {error_trace}")
                cls.logger.warning(f"The script will continue on other files before it returns with failure. Use Ctrl+C to cut it early.")
                errors.append((doc.get_data_index(), error_trace))
                continue
            # end try
        # end for

        if len(errors) > 0:
            # Save the errors BEFORE raising: log_and_raise is expected to raise, so nothing after it runs
            data_mgr.dump_data([FilesManager.LEMMAS, "errors.txt"], errors, IOUtils.Format.jsonPretty)
            LoggingUtils.log_and_raise(cls.logger, f"There were {len(errors)} errors during collection.", Exception)
        # end if

        # Assign uids
        for lemma_i, lemma in enumerate(lemmas):  lemma.uid = lemma_i

        data_mgr.dump_data([FilesManager.LEMMAS], lemmas, IOUtils.Format.json, is_batched=True, per_batch=5000)
        return
示例#9
0
class Environment:
    """
    Container of class-level settings (only an excerpt of the class is visible here).
    """

    # Logger named after this module
    logger = LoggingUtils.get_logger(__name__)

    # =====
    # Random seed
    # Defaults to None; presumably assigned during option handling -- TODO confirm against callers
    random_seed: int = None
示例#10
0
def dispatch(
    argv: List[str],
    targets: Dict[str, Callable],
    default_target: str = "main",
):
    """
    Parses the arguments and calls the selected target with them.

    The first free argument, if any, names the target and is removed from the free arguments before the
    call; otherwise the default target is used.
    :param argv: the arguments to parse.
    :param targets: the available targets, keyed by name.
    :param default_target: the name of the target used when there is no free argument.
    :return: whatever the selected target returns.
    :raises RuntimeError: if the selected target name is not in targets.
    """
    logger = LoggingUtils.get_logger("args.main")
    logger.info("Starting")

    args = parse(argv)

    # Start from the default and override it when a target is named explicitly
    target = default_target
    if len(args.free) > 0:
        target = args.free[0]
        args = args._replace(free=args.free[1:])

    if target not in targets:
        raise RuntimeError(
            f"Cannot find target {target} in the available set of targets: {targets.keys()}"
        )

    # Bind the parsed arguments to the target's signature, then invoke it
    func = targets[target]
    bound = args.fill_signature(inspect.Signature.from_callable(func))
    result = func(*bound.args, **bound.kwargs)
    logger.info("Terminating")
    return result
示例#11
0
文件: Vocabulary.py 项目: mfkiwl/hdlp
class VocabularyBuilder(Generic[VocabT]):
    """
    Builds a vocabulary, e.g., on training data, with counting the frequency of each word and do some filtering based on that.

    Although not recommended, it is possible to do weird things via changing #counter_words before generating the vocabulary.
    """
    logger = LoggingUtils.get_logger(__name__)

    def __init__(self, pad_token: VocabT, unk_token: VocabT):
        """
        :param pad_token: the padding token; automatically secured.
        :param unk_token: the unknown-word token; automatically secured.
        """
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.counter_words: Counter[VocabT] = collections.Counter()
        self.secured_words: Set[VocabT] = set()

        self.secure_word(pad_token)
        self.secure_word(unk_token)
        return

    def add_word(self, word: VocabT):
        """
        Counts one occurrence of word; occurrences of secured words are not counted.
        """
        if word not in self.secured_words:
            self.counter_words[word] += 1
        # end if
        return

    def secure_word(self, word: VocabT):
        """
        Secures word to make it definitely appear in the final vocab, ignoring frequency_threshold filtering, and breaks max_size limit if necessary.

        PAD and UNK are automatically secured.
        """
        self.secured_words.add(word)
        # Clear it from the counter so that it doesn't compete with other words
        self.counter_words[word] = -1
        return

    def build(self,
              frequency_threshold: int = 0,
              max_size: Optional[int] = None) -> Vocabulary[VocabT]:
        """
        Builds the vocabulary based on the counter in this builder, and according to filtering arguments.
        :param frequency_threshold: a word need to appear at least such times to be in the vocabulary. Default is 0.
        :param max_size: the maximum size of the vocabulary (including all secured words, e.g., PAD and UNK). None means no limit.
        If the number of secured words exceeds max_size, the built vocabulary will contain (only) the secured words ignoring the limit.
        :return: the built vocabulary.
        """
        # Secured words are added unconditionally below, so keep them out of the frequency-based
        # selection; otherwise (with a negative threshold) they would use up part of the size budget.
        candidates = [
            w for w, c in self.counter_words.items()
            if c >= frequency_threshold and w not in self.secured_words
        ]
        # Most frequent first; the sort is stable, so ties keep their counter order
        selected_words = sorted(candidates,
                                key=lambda w: self.counter_words[w],
                                reverse=True)
        if max_size is not None:
            # Clamp at 0: a negative slice bound would drop words from the tail instead of
            # selecting none when the secured words alone already exceed max_size.
            budget = max(0, max_size - len(self.secured_words))
            selected_words = selected_words[:budget]
        vocab: Vocabulary[VocabT] = Vocabulary(self.pad_token, self.unk_token)
        for word in self.secured_words:
            vocab.add_word(word)
        for word in selected_words:
            vocab.add_word(word)
        return vocab
示例#12
0
class TransformerProcessor:
    """
    Prepares the data for the transformer model by running the onmt_preprocess command.
    """

    logger = LoggingUtils.get_logger(
        __name__,
        LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)

    def __init__(self):
        super().__init__()

    def process_data(self, model_data_dir: Path, data_prefix: str):
        """
        Runs onmt_preprocess on the raw data files, which are assumed to have been generated by the
        Bi-LSTM model processor: src-train.txt, tgt-train.txt, src-val.txt, tgt-val.txt.
        :param model_data_dir: the dir for storing the data for transformer.
        :param data_prefix: e.g. evo-2020, mixedproj-2020
        :return: None; the command is required to exit with return code 0.
        """
        self.logger.info(f"Start processing")

        # Train and validation inputs are read from, and the output saved to, sub directories of model_data_dir
        preprocess_cmd = (
            f"onmt_preprocess -train_src {model_data_dir}/{data_prefix}-{Macros.train}/src-train.txt "
            f"-train_tgt {model_data_dir}/{data_prefix}-{Macros.train}/tgt-train.txt "
            f"-valid_src {model_data_dir}/{data_prefix}-{Macros.val}/src-val.txt "
            f"-valid_tgt {model_data_dir}/{data_prefix}-{Macros.val}/tgt-val.txt "
            f"-save_data {model_data_dir}/{data_prefix}-{Macros.train}/transformer --src_seq_length 200"
            f" --src_seq_length_trunc 200 --shard_size 0"
        )
        BashUtils.run(preprocess_cmd, expected_return_code=0)
示例#13
0
class MultiSourceNMTModel(nn.Module):
    """
    An encoder-decoder model with several encoders (one per source) feeding a single decoder.
    """

    logger = LoggingUtils.get_logger(__name__)

    def __init__(self,
            encoders: List[EncoderBase],
            decoder: DecoderBase,
    ):
        """
        :param encoders: the encoders, one per source; each is registered as sub module "encoder-<i>".
        :param decoder: the decoder that consumes the outputs of all encoders.
        """
        super().__init__()
        self.encoders = encoders
        # The encoders are kept in a plain list, so each one is registered explicitly as a sub module
        for enc_i, encoder in enumerate(self.encoders):  self.add_module(f"encoder-{enc_i}", encoder)
        self.decoder = decoder
        return

    def forward(self,
            src_list: List[torch.Tensor],
            tgt: torch.LongTensor,
            lengths_list: List[torch.LongTensor],
            bptt: bool = False,
    ) -> Tuple[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
        """Forward propagate the sources and the `tgt` for training.
        Possible initialized with a beginning decoder state.

        Args:
            src_list (List[Tensor]): The source sequences, one per encoder
                (``src_list[i]`` is passed to ``self.encoders[i]``).
                Typically each is a padded `LongTensor` of size
                ``(len, batch, features)``. However, may be an image or
                other generic input depending on encoder.
            tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``.
            lengths_list (List[LongTensor]): The src lengths of each source,
                pre-padding ``(batch,)``. NOTE: this list is updated in
                place with the lengths returned by each encoder.
            bptt (Boolean): A flag indicating if truncated bptt is set.
                The decoder state is (re)initialized only when it is False.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary attention dists of ``(tgt_len, batch, src_len)``
        """
        tgt = tgt[:-1]  # exclude last target from inputs

        enc_state_list: List = list()
        memory_bank_list: List = list()
        for enc_i, encoder in enumerate(self.encoders):
            enc_state, memory_bank, lengths = encoder(src_list[enc_i], lengths_list[enc_i])
            enc_state_list.append(enc_state)
            memory_bank_list.append(memory_bank)
            lengths_list[enc_i] = lengths
        # end for

        if bptt is False:
            self.decoder.init_state(src_list, memory_bank_list, enc_state_list)
        # end if
        dec_out, attns = self.decoder(tgt, memory_bank_list, memory_lengths_list=lengths_list)
        return dec_out, attns

    def update_dropout(self, dropout):
        """
        Propagates the new dropout value to every encoder and to the decoder.
        """
        # This model has no single `encoder` attribute; the encoders live in self.encoders
        for encoder in self.encoders:
            encoder.update_dropout(dropout)
        # end for
        self.decoder.update_dropout(dropout)
示例#14
0
    def load_ckpt(self, rel_path: Union[str, List[str]],
            load_func: Callable[[str], Any],
            ckpt_id: Optional[int] = None,
    ) -> Any:
        """
        Loads one checkpoint from a directory whose entries are named by integer checkpoint ids.

        :param rel_path: the checkpoints directory, relative to self.data_dir (resolved via self.assemble_rel_path).
        :param load_func: called with the path (as str) of the selected checkpoint; its result is returned.
        :param ckpt_id: the checkpoint to load; if None, the entry with the largest integer name is used.
        :return: whatever load_func returns.
        """
        abs_path = self.data_dir / self.assemble_rel_path(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        if ckpt_id is None:
            # No checkpoint requested: the latest one is the numerically largest entry name
            ckpt_id = max(int(entry.name) for entry in abs_path.iterdir())
            self.logger.info(f"Loading the latest checkpoint {ckpt_id} at {abs_path}")
        # end if

        ckpt_path = abs_path / str(ckpt_id)
        return load_func(str(ckpt_path))
示例#15
0
    def require_special_repo(cls, directory: Path, branch: str):
        """
        Makes sure directory holds an up-to-date checkout of the repository at cls.get_git_url().

        If directory does not exist, the given branch is cloned into it (single branch); if it exists, it
        must be a git repository and "git pull" is run in it.

        :param directory: the directory of the local repository.
        :param branch: the branch to clone; only used when the directory does not exist yet.
        :raises Exception: (via LoggingUtils.log_and_raise) if directory exists but is not a git repository.
        """
        cls.logger.info(f"Updating {directory} to {branch} branch")

        if not directory.exists():
            # Nothing there yet: create the directory and clone into it
            IOUtils.mk_dir(directory)
            with IOUtils.cd(directory):
                BashUtils.run(
                    f"git clone --single-branch -b {branch} -- {cls.get_git_url()} .",
                    expected_return_code=0)
            # end with
            return
        # end if

        # Something exists: it has to be a directory containing a .git directory
        if not (directory.is_dir() and (directory / ".git").is_dir()):
            LoggingUtils.log_and_raise(
                cls.logger,
                f"Path {directory} already exists but is not a proper git repository!",
                Exception)
        # end if

        with IOUtils.cd(directory):
            BashUtils.run(f"git pull", expected_return_code=0)
        # end with
示例#16
0
def normalize_options(opts: dict) -> dict:
    """
    Applies the "log_path", "random_seed" and "debug" options, then returns opts unchanged.

    The random seed is the given one, or the current time in nanoseconds; it is logged, used to seed the
    random module, and recorded in Environment.random_seed. Debug mode is turned on whenever the "debug"
    key is present, whatever its value.
    """
    if "log_path" in opts:
        logger.info(f"Switching to log file {opts['log_path']}")
        LoggingUtils.setup(filename=opts['log_path'])
    # end if

    # Use the requested seed, falling back to the clock
    seed = opts["random_seed"] if "random_seed" in opts else time.time_ns()
    logger.info(f"Random seed is {seed}")
    random.seed(seed)
    Environment.random_seed = seed

    if "debug" in opts:
        # Imported here so that the Debug module is only loaded when debugging is requested
        from roosterize.Debug import Debug
        Debug.is_debug = True
        logger.debug(f"options: {opts}")
    # end if
    return opts
示例#17
0
    def iter_batched_data(self,
            rel_path: Union[str, List[str]],
            fmt: IOUtils.Format,
            clz = None,
    ) -> Iterator:
        """
        Iterates over the entries of data saved in batched mode, loading one batch file at a time.

        :param rel_path: the batch directory, relative to self.data_dir (resolved via self.assemble_rel_path).
        :param fmt: the format passed to IOUtils.load.
        :param clz: for json formats, the class handed to IOUtils.dejsonfy for each entry; if None, the
            entries are yielded without that conversion (and a warning is logged).
        :return: an iterator over the entries, batch by batch in ascending order of batch number.
        """
        if self.is_json_format(fmt) and clz is None:
            self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz")
        # end if

        abs_path = self.data_dir / self.assemble_rel_path(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        # The batch number is the part of the file stem after the first "-"
        batch_numbers = sorted(int(str(f.stem).split("-")[1]) for f in abs_path.iterdir())
        for batch_number in batch_numbers:
            batch_file = abs_path / f"batch-{batch_number}.{fmt.get_extension()}"
            for raw_entry in IOUtils.load(batch_file, fmt):
                needs_conversion = self.is_json_format(fmt) and clz is not None
                yield IOUtils.dejsonfy(raw_entry, clz) if needs_conversion else raw_entry
            # end for
        # end for
示例#18
0
    def get_joint_history(self):
        """
        Joins self.training_history and self.step_history record by record.

        The two histories must have the same length and the same "step" at every position. Each joint
        record has the step, elapsed time and learning rate from the training record, plus the
        accuracy/ppl/xent of the training record (prefixed "train_") and of the step record (prefixed "val_").

        :return: the list of joint records.
        :raises Exception: (via LoggingUtils.log_and_raise) if the histories do not line up.
        """
        num_records = len(self.training_history)
        if num_records != len(self.step_history):
            LoggingUtils.log_and_raise(self.logger, f"Cannot join two mismatch history!", Exception)
        # end if

        joint_history: List[dict] = []
        for idx in range(num_records):
            train_rec = self.training_history[idx]
            val_rec = self.step_history[idx]
            if train_rec["step"] != val_rec["step"]:
                LoggingUtils.log_and_raise(self.logger, f"Cannot join two mismatch history!", Exception)
            # end if

            joint_rec = {
                "step": train_rec["step"],
                "elapsed_time": train_rec["elapsed_time"],
                "learning_rate": train_rec["learning_rate"],
            }
            for metric in ("accuracy", "ppl", "xent"):
                joint_rec[f"train_{metric}"] = train_rec[metric]
            # end for
            for metric in ("accuracy", "ppl", "xent"):
                joint_rec[f"val_{metric}"] = val_rec[metric]
            # end for
            joint_history.append(joint_rec)
        # end for
        return joint_history
示例#19
0
class MultiSourceCopyGeneratorLoss(nn.Module):
    """Copy generator criterion.

    Computes an unreduced negative log-likelihood from scores that cover
    the fixed vocabulary followed by the copy (dynamic) vocabulary.
    """

    logger = LoggingUtils.get_logger(__name__)

    def __init__(self,
                 vocab_size,
                 force_copy,
                 unk_index=0,
                 ignore_index=-100,
                 eps=1e-20):
        """
        Args:
            vocab_size: offset added to ``align`` to index the copy part
                of ``scores``.
            force_copy: if true, the vocabulary probability is only added
                where ``align`` equals ``unk_index``.
            unk_index: the index treated as "unknown" in both ``align``
                and ``target``.
            ignore_index: target value whose loss is set to 0.
            eps: added to the copy probabilities to avoid ``log(0)``.
        """
        super().__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
                whose sum along dim 1 is less than or equal to 1, i.e. cols
                softmaxed.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``

        Returns:
            The elementwise (not reduced) loss ``-log(prob)``, with the
            entries where ``target == ignore_index`` set to 0.
        """
        # probabilities assigned by the model to the gold targets
        vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)

        # probability of tokens copied from source
        # (the copy scores are stored after the first vocab_size columns)
        copy_ix = align.unsqueeze(1) + self.vocab_size
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        # Set scores for unk to 0 and add eps
        copy_tok_probs[align == self.unk_index] = 0
        copy_tok_probs += self.eps  # to avoid -inf logs

        # find the indices in which you do not use the copy mechanism
        non_copy = align == self.unk_index
        if not self.force_copy:
            # without force_copy, the vocab probability is also added wherever the target is not unk
            non_copy = non_copy | (target != self.unk_index)

        # where non_copy holds, both probabilities count; elsewhere only the copy probability does
        probs = torch.where(non_copy, copy_tok_probs + vocab_probs,
                            copy_tok_probs)

        loss = -probs.log()  # just NLLLoss; can the module be incorporated?
        # Drop padding.
        loss[target == self.ignore_index] = 0
        return loss
示例#20
0
class AbstractFilter:
    """
    Base class of the filters; subclasses implement process_data.
    """

    logger = LoggingUtils.get_logger(
        __name__,
        LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)
    # The years 2013 to 2020, inclusive
    YEARS = list(range(2013, 2021))

    def __init__(self):
        pass

    @abc.abstractmethod
    def process_data(self, project_dir: Path):
        """
        Processes the list of yearly method data to produce evo data.
        :param project_dir: the directory of the project to process (presumably where the yearly
            method data lives -- confirm against the subclasses).
        """
        raise NotImplementedError
示例#21
0
class AbstractProcessor:
    """
    Base class of the model-specific data processors; subclasses implement process_data.
    """

    logger = LoggingUtils.get_logger(
        __name__,
        LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)

    def __init__(self):
        pass

    @abc.abstractmethod
    def process_data(self,
                     method_data_list: List[dict],
                     data_type: str,
                     output_dir: Path,
                     traversal="None") -> List[int]:
        """
        Processes the list of method data, for the given data_type.
        :param method_data_list: list of MethodData
        :param data_type: the data_type (one of {train, val, test})
        :param output_dir: the directory to put the processed data, prepared for this model
        :param traversal: defaults to the string "None"; its meaning is defined by the subclasses.
        :return: the list of data indexes (in the method_data_list) that failed to process
        """
        raise NotImplementedError
示例#22
0
from csevo.Utils import Utils

# Check seutil version: exit with an explanation on stderr unless the installed version is exactly the expected one
EXPECTED_SEUTIL_VERSION = "0.4.12"
if pkg_resources.get_distribution("seutil").version != EXPECTED_SEUTIL_VERSION:
    print(
        f"seutil version does not meet expectation! Expected version: {EXPECTED_SEUTIL_VERSION}, current installed version: {pkg_resources.get_distribution('seutil').version}",
        f"Hint: either upgrade seutil, or modify the expected version (after confirmation that the version will work)",
        sep="\n",
        file=sys.stderr)
    sys.exit(-1)
# end if

# Send the logs to experiment.log under the python directory; done before the module logger is obtained
logging_file = Macros.python_dir / "experiment.log"
LoggingUtils.setup(filename=str(logging_file))

logger = LoggingUtils.get_logger(__name__)

# ==========
# Table & Plot


def make_tables(**options):
    """
    Makes the tables selected by the "which" option; all options are forwarded to Table.make_tables.
    """
    from csevo.Table import Table

    selected = Utils.get_option_as_list(options, "which")
    Table().make_tables(selected, options)
    return
示例#23
0
class MultiSourceAPGlobalAttention(GlobalAttention):
    """
    Global attention over several sources: the memory banks of all sources are concatenated along the
    source-length dimension and attended to as one memory bank.
    """

    logger = LoggingUtils.get_logger(__name__)

    def forward(
        self,
        source: torch.FloatTensor,  # [batch, tgt_len, dim]
        memory_bank_list: List[
            torch.FloatTensor],  # [num_srcs] x [batch, src_len, dim]
        memory_lengths_list: List[
            torch.FloatTensor] = None,  # [num_srcs] x [batch]
        coverage=None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Computes the attention of source (the query) over the concatenation of all memory banks.

        :param source: the query vectors, [batch, tgt_len, dim]; a 2-dim [batch, dim] input is treated
            as a single step.
        :param memory_bank_list: the memory bank of each source, [batch, src_len, dim] each.
        :param memory_lengths_list: the lengths of each source, [batch] each; if given, the positions
            beyond the lengths are masked out before normalization.
        :param coverage: not supported, must be None.
        :return: (attn_h, align_vectors); for a single step they are [batch, dim] and
            [batch, total_src_len], otherwise [tgt_len, batch, dim] and [tgt_len, batch, total_src_len],
            where total_src_len is the sum of the src_len of all memory banks.
        """
        assert coverage is None

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False
        # end if

        # Join memory bank
        memory_bank = torch.cat(memory_bank_list, dim=1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        # The coverage branches below only run if the assertion above is disabled
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths_list is not None:
            # One mask per source, concatenated in the same order as the memory banks
            mask = torch.cat([
                sequence_mask(memory_lengths,
                              max_len=memory_bank_list[src_i].size(1))
                for src_i, memory_lengths in enumerate(memory_lengths_list)
            ],
                             dim=1)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            # Invert with a comparison instead of `1 - mask`: subtraction is not supported on bool
            # tensors, while `== 0` works for both bool and integer masks.
            align.masked_fill_(mask == 0, -float('inf'))
        # end if

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l),
                                      -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l),
                                      -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        # end if

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            # Put the target length first
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)
        # end if

        return attn_h, align_vectors
示例#24
0
class BeamSearch(DecodeStrategy):
    """Generation beam search.

    Note that the attributes list is not exhaustive. Rather, it highlights
    tensors to document their shape. (Since the state variables' "batch"
    size decreases as beams finish, we denote this axis with a B rather than
    ``batch_size``).

    Args:
        beam_size (int): Number of beams to use (see base ``parallel_paths``).
        batch_size (int): See base.
        pad (int): See base.
        bos (int): See base.
        eos (int): See base.
        n_best (int): Don't stop until at least this many beams have
            reached EOS.
        mb_device (torch.device or str): See base ``device``.
        global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
        min_length (int): See base.
        max_length (int): See base.
        return_attention (bool): See base.
        block_ngram_repeat (int): See base.
        exclusion_tokens (set[int]): See base.
        memory_lengths (LongTensor): Lengths of encodings. Used for
            masking attentions.

    Attributes:
        top_beam_finished (ByteTensor): Shape ``(B,)``.
        _batch_offset (LongTensor): Shape ``(B,)``.
        _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``.
        alive_seq (LongTensor): See base.
        topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
            are the scores used for the topk operation.
        select_indices (LongTensor or NoneType): Shape
            ``(B x beam_size,)``. This is just a flat view of the
            ``_batch_index``.
        topk_scores (FloatTensor): Shape
            ``(B, beam_size)``. These are the
            scores a sequence will receive if it finishes.
        topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the
            word indices of the topk predictions.
        _batch_index (LongTensor): Shape ``(B, beam_size)``.
        _prev_penalty (FloatTensor or NoneType): Shape
            ``(B, beam_size)``. Initialized to ``None``.
        _coverage (FloatTensor or NoneType): Shape
            ``(1, B x beam_size, inp_seq_len)``.
        hypotheses (list[list[Tuple[Tensor]]]): Contains a tuple
            of score (float), sequence (long), and attention (float or None).
    """

    logger = LoggingUtils.get_logger(__name__)

    def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
                 global_scorer, min_length, max_length, return_attention,
                 block_ngram_repeat, exclusion_tokens, memory_lengths,
                 stepwise_penalty, ratio):
        """Create the beam state and the reusable top-k buffers.

        Args:
            beam_size (int): Number of beams kept per batch entry.
            batch_size (int): Number of batch entries at the first step.
            pad, bos, eos: Forwarded to the base class.
            n_best (int): Number of hypotheses required before a batch
                entry may be reported as done.
            mb_device: Device on which the score buffers are allocated.
            global_scorer: Scorer exposing ``has_cov_pen`` (and, in other
                methods, the length and coverage penalties).
            min_length, max_length, return_attention, block_ngram_repeat,
                exclusion_tokens: Forwarded to the base class.
            memory_lengths: Source lengths, stored as given.
            stepwise_penalty: Whether the coverage penalty is applied at
                every step rather than once on the final scores.
            ratio: Length ratio used by the stop rule in
                ``update_finished``.
        """
        super().__init__(pad, bos, eos, batch_size, mb_device, beam_size,
                         min_length, block_ngram_repeat, exclusion_tokens,
                         return_attention, max_length)

        # Search configuration.
        self.global_scorer = global_scorer
        self.beam_size = beam_size
        self.n_best = n_best
        self.batch_size = batch_size
        self.ratio = ratio
        self._memory_lengths = memory_lengths

        # One list of finished hypotheses per batch entry.
        self.hypotheses = [[] for _ in range(batch_size)]

        # Per-batch bookkeeping. ``top_beam_finished`` and ``_batch_offset``
        # are created without an explicit device.
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)
        self.best_scores = torch.full(
            [batch_size], -1e10, dtype=torch.float, device=mb_device)

        # Start index of every batch entry in the flattened beam axis.
        flat_size = batch_size * beam_size
        self._beam_offset = torch.arange(
            0, flat_size, step=beam_size, dtype=torch.long, device=mb_device)

        # Only the first beam of each batch entry starts with a finite
        # score, so the first top-k cannot pick duplicates of one prefix.
        first_step_scores = [0.0] + [float("-inf")] * (beam_size - 1)
        self.topk_log_probs = torch.tensor(
            first_step_scores, device=mb_device).repeat(batch_size)
        self.select_indices = None

        # Buffers for the top-k scores, token ids and the 'backpointer'.
        buffer_shape = (batch_size, beam_size)
        self.topk_scores = torch.empty(
            buffer_shape, dtype=torch.float, device=mb_device)
        self.topk_ids = torch.empty(
            buffer_shape, dtype=torch.long, device=mb_device)
        self._batch_index = torch.empty(
            buffer_shape, dtype=torch.long, device=mb_device)
        self.done = False

        # "Global state" carried over from the previous step.
        self._prev_penalty = None
        self._coverage = None

        # Coverage-penalty mode flags.
        has_cov_pen = self.global_scorer.has_cov_pen
        self._stepwise_cov_pen = (stepwise_penalty and has_cov_pen)
        self._vanilla_cov_pen = (not stepwise_penalty and has_cov_pen)
        self._cov_pen = has_cov_pen

    @property
    def current_predictions(self):
        """Return the last column of ``alive_seq`` (latest token per beam)."""
        latest_tokens = self.alive_seq[:, -1]
        return latest_tokens

    @property
    def current_origin(self):
        """Return ``select_indices``, the flat origin index of every beam."""
        origin = self.select_indices
        return origin

    @property
    def current_backptr(self):
        """Return the within-batch beam index each beam came from.

        Intended for testing: ``select_indices`` is viewed as
        ``(batch_size, beam_size)`` and reduced modulo ``beam_size``.
        """
        per_batch_origin = self.select_indices.view(
            self.batch_size, self.beam_size)
        return per_batch_origin.fmod(self.beam_size)

    def advance(self, log_probs, attn):
        """Advance every live beam by one decoding step.

        Selects the ``beam_size`` best continuations per batch entry,
        appends the chosen tokens to ``alive_seq``, updates the attention
        and coverage state, and sets ``is_finished`` for beams that
        emitted EOS.

        Args:
            log_probs: Next-token scores, indexed here as
                ``(_B * beam_size, vocab_size)``. Modified in place.
            attn: Attention of the current step; reordered along
                dimension 1 by the selected beam origins. Only read when
                attention is returned or a coverage penalty is active.
        """
        vocab_size = log_probs.size(-1)

        # using integer division to get an integer _B without casting
        _B = log_probs.shape[0] // self.beam_size

        if self._stepwise_cov_pen and self._prev_penalty is not None:
            # Swap the penalty of the previous step for the one that
            # includes the current attention.
            self.topk_log_probs += self._prev_penalty
            self.topk_log_probs -= self.global_scorer.cov_penalty(
                self._coverage + attn,
                self.global_scorer.beta).view(_B, self.beam_size)

        # force the output to be longer than self.min_length
        step = len(self)
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        self.block_ngram_repeats(log_probs)

        # if the sequence ends now, then the penalty is the current
        # length + 1, to include the EOS token
        length_penalty = self.global_scorer.length_penalty(
            step + 1, alpha=self.global_scorer.alpha)

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs / length_penalty
        curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
        torch.topk(curr_scores,
                   self.beam_size,
                   dim=-1,
                   out=(self.topk_scores, self.topk_ids))

        # Recover log probs.
        # Length penalty is just a scalar. It doesn't matter if it's applied
        # before or after the topk.
        torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        # Floor division keeps the result integral. ``torch.div`` on integer
        # tensors performs true division in recent PyTorch releases and
        # cannot write its float result into the long ``_batch_index``
        # buffer. The ids are non-negative, so floor and truncation agree.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)

        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat([
            self.alive_seq.index_select(0, self.select_indices),
            self.topk_ids.view(_B * self.beam_size, 1)
        ], -1)
        if self.return_attention or self._cov_pen:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
                # update global state (step == 1)
                if self._cov_pen:  # coverage penalty
                    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
                    self._coverage = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)
                # update global state (step > 1)
                if self._cov_pen:
                    self._coverage = self._coverage.index_select(
                        1, self.select_indices)
                    self._coverage += current_attn
                    self._prev_penalty = self.global_scorer.cov_penalty(
                        self._coverage,
                        beta=self.global_scorer.beta).view(_B, self.beam_size)

        if self._vanilla_cov_pen:
            # shape: (batch_size x beam_size, 1)
            cov_penalty = self.global_scorer.cov_penalty(
                self._coverage, beta=self.global_scorer.beta)
            self.topk_scores -= cov_penalty.view(_B, self.beam_size)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        """Store finished hypotheses and drop batch entries that are done.

        Every beam flagged in ``is_finished`` is appended to
        ``hypotheses`` of its original batch entry. A batch entry is done
        when its stop rule fires and at least ``n_best`` hypotheses were
        collected; its best hypotheses are then copied to ``scores``,
        ``predictions`` and ``attention`` and its row is removed from all
        state tensors.

        Returns:
            ``None`` when every batch entry is done (``done`` is set to
            ``True``); otherwise the tensor of surviving row indices,
            moved to the device of ``topk_ids``.
        """
        # Penalize beams that finished.
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # 1 greater than the step in advance
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        # on real data (newstest2017) with the pretrained transformer,
        # it's faster to not move this back to the original device
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (self.alive_attn.view(step - 1, _B_old, self.beam_size,
                                          self.alive_attn.size(-1))
                     if self.alive_attn is not None else None)

        # ``_memory_lengths`` is assigned once in ``__init__`` and never
        # pruned, so it must be addressed by the ORIGINAL batch id ``b``,
        # not by the shrinking row index ``i``. The stride covers both a
        # per-beam tiled tensor (batch-major layout, as used for
        # ``_beam_offset``) and a plain per-batch tensor.
        lengths_stride = None
        if self._memory_lengths is not None:
            lengths_stride = len(self._memory_lengths) // self.batch_size

        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            src_len = (self._memory_lengths[int(b) * lengths_stride]
                       if lengths_stride is not None else None)
            finished_hyp = self.is_finished[i].nonzero().view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:
                if self.ratio > 0:
                    s = self.topk_scores[i, j] / (step + 1)
                    if self.best_scores[b] < s:
                        self.best_scores[b] = s

                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start_token.
                    attention[:, i, j, :src_len]
                    if attention is not None else None))
            # End condition is the top beam finished and we can return
            # n_best hypotheses.
            if self.ratio > 0:
                pred_len = src_len * self.ratio
                finish_flag = ((self.topk_scores[i, 0] / pred_len)
                               <= self.best_scores[b]) or \
                    self.is_finished[i].all()
            else:
                finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(self.hypotheses[b],
                                  key=lambda x: x[0],
                                  reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= self.n_best:
                        break
                    self.scores[b].append(score)
                    self.predictions[b].append(pred)
                    self.attention[b].append(attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
        non_finished = torch.tensor(non_finished_batch)
        # If all sentences are translated, no need to go further.
        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step.
        self.top_beam_finished = self.top_beam_finished.index_select(
            0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)
            if self._cov_pen:
                self._coverage = self._coverage \
                    .view(1, _B_old, self.beam_size, inp_seq_len) \
                    .index_select(1, non_finished) \
                    .view(1, _B_new * self.beam_size, inp_seq_len)
                if self._stepwise_cov_pen:
                    self._prev_penalty = self._prev_penalty.index_select(
                        0, non_finished)
        return non_finished
示例#25
0
class MultiSourceTypeAppendedTranslator(CustomTranslator):

    logger = LoggingUtils.get_logger(__name__)

    @classmethod
    def build_translator(cls, src_types, opt, report_score=True, logger=None, out_file=None):
        """Load the model named by ``opt`` and wrap it in a translator.

        Args:
            src_types: Source types handled by the model; forwarded to the
                model loader and to :func:`from_opt`.
            opt: Options namespace; ``opt.models`` must contain exactly one
                entry and ``opt.output`` names the default output file.
            report_score (bool): Forwarded to :func:`from_opt`.
            logger: Forwarded to :func:`from_opt`.
            out_file: Stream receiving the predictions. When ``None``,
                ``opt.output`` is opened for writing as UTF-8.

        Returns:
            The translator built by :func:`from_opt`.

        Raises:
            AssertionError: If ``opt.models`` does not hold exactly one model.
        """
        # Validate first: opening with 'w+' truncates the output file, which
        # must not happen (nor leak a handle) for a rejected configuration.
        assert len(opt.models) == 1, "ensemble model is not supported"

        if out_file is None:
            out_file = codecs.open(opt.output, 'w+', 'utf-8')

        fields, model, model_opt = MultiSourceTypeAppendedModelBuilder.load_test_model(src_types, opt)

        scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)

        translator = cls.from_opt(
            src_types,
            model,
            fields,
            opt,
            model_opt,
            global_scorer=scorer,
            out_file=out_file,
            report_score=report_score,
            logger=logger
        )
        return translator

    def __init__(
            self,
            src_types,
            model,
            fields,
            src_reader,
            tgt_reader,
            gpu=-1,
            n_best=1,
            min_length=0,
            max_length=100,
            ratio=0.,
            beam_size=30,
            random_sampling_topk=1,
            random_sampling_temp=1,
            stepwise_penalty=None,
            dump_beam=False,
            block_ngram_repeat=0,
            ignore_when_blocking=frozenset(),
            replace_unk=False,
            phrase_table="",
            data_type="text",
            verbose=False,
            report_bleu=False,
            report_rouge=False,
            report_time=False,
            copy_attn=False,
            global_scorer=None,
            out_file=None,
            report_score=True,
            logger=None,
            use_reranker=False,
            seed=-1):
        """Initialise the base translator and the multi-source additions.

        All parameters except ``src_types`` and ``use_reranker`` are passed
        positionally to the base class. ``src_types`` is stored on the
        instance; ``use_reranker`` decides whether a
        ``MultiSourceReranker`` is created. The instance logger is replaced
        by the logger of ``MultiSourceTypeAppendedModelBuilder`` after the
        base initialiser has run.
        """
        super().__init__(
            model, fields, src_reader, tgt_reader, gpu, n_best, min_length,
            max_length, ratio, beam_size, random_sampling_topk,
            random_sampling_temp, stepwise_penalty, dump_beam,
            block_ngram_repeat, ignore_when_blocking, replace_unk,
            phrase_table, data_type, verbose, report_bleu, report_rouge,
            report_time, copy_attn, global_scorer, out_file, report_score,
            logger, seed)
        self.src_types = src_types
        self.logger = MultiSourceTypeAppendedModelBuilder.logger

        if use_reranker:
            self.reranker = MultiSourceReranker(2*self.beam_size, self.src_types)
        else:
            self.reranker = None

    @classmethod
    def from_opt(cls,
            src_types,
            model,
            fields,
            opt,
            model_opt,
            global_scorer=None,
            out_file=None,
            report_score=True,
            logger=None):
        """Alternate constructor driven by option namespaces.

        Args:
            src_types: See :func:`__init__()`.
            model: See :func:`__init__()`.
            fields: See :func:`__init__()`.
            opt: Command line options; supplies the decoding settings.
            model_opt: Options saved with the model checkpoint; only
                ``copy_attn`` is read from it.
            global_scorer: See :func:`__init__()`.
            out_file: See :func:`__init__()`.
            report_score (bool): See :func:`__init__()`.
            logger: See :func:`__init__()`.

        Returns:
            A new instance of ``cls``.
        """
        # Source and target are both read with the "text" reader.
        text_reader_cls = inputters.str2reader["text"]
        src_reader = text_reader_cls.from_opt(opt)
        tgt_reader = text_reader_cls.from_opt(opt)

        settings = dict(
            gpu=opt.gpu,
            n_best=opt.n_best,
            min_length=opt.min_length,
            max_length=opt.max_length,
            ratio=opt.ratio,
            beam_size=opt.beam_size,
            random_sampling_topk=opt.random_sampling_topk,
            random_sampling_temp=opt.random_sampling_temp,
            stepwise_penalty=opt.stepwise_penalty,
            dump_beam=opt.dump_beam,
            block_ngram_repeat=opt.block_ngram_repeat,
            ignore_when_blocking=set(opt.ignore_when_blocking),
            replace_unk=opt.replace_unk,
            phrase_table=opt.phrase_table,
            data_type=opt.data_type,
            verbose=opt.verbose,
            report_bleu=opt.report_bleu,
            report_rouge=opt.report_rouge,
            report_time=opt.report_time,
            copy_attn=model_opt.copy_attn,
            global_scorer=global_scorer,
            out_file=out_file,
            report_score=report_score,
            logger=logger,
            use_reranker=opt.use_reranker,
            seed=opt.seed,
        )
        return cls(src_types, model, fields, src_reader, tgt_reader, **settings)

    def translate(self,
            raw_data_shard: Dict,
            has_target: bool,
            src_dir=None,
            batch_size=None,
            attn_debug=False,
            phrase_table=""):
        """Translate one shard of raw multi-source data.

        Predictions are written to ``self.out_file`` as they are produced.

        Args:
            raw_data_shard: Mapping read under the keys
                ``"src.<src_type>"`` for every entry of ``self.src_types``
                and, when ``has_target`` is true, ``"tgt"``.
            has_target: Whether the ``"tgt"`` entry is fed to the dataset.
            src_dir: Not used in this method body.
            batch_size (int): size of examples per mini-batch; required.
            attn_debug (bool): enables the attention logging
            phrase_table: Not used in this method body
                (``self.phrase_table`` is passed to the builder instead).

        Returns:
            (`list`, `list`, `list`)

            * all_scores is a list with, per example, up to `n_best` scores
            * all_predictions is a list with, per example, up to `n_best`
                space-joined predictions
            * candidates_logprobs is a list with, per example, up to
                `n_best` ``(tokens, score)`` tuples

        Raises:
            ValueError: If ``batch_size`` is ``None``.
        """
        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        if batch_size is None:
            raise ValueError("batch_size must be set")

        raw_data_keys = [f"src.{src_type}" for src_type in self.src_types] + (["tgt"] if has_target else [])
        tgt = raw_data_shard.get("tgt")

        data = MultiSourceDataset(
            self.src_types,
            self.fields,
            readers=([self.src_reader] * len(self.src_types) + ([self.tgt_reader] if self.tgt_reader else [])),
            data=[(k, raw_data_shard[k]) for k in raw_data_keys],
            dirs=[None] * len(raw_data_keys),
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred,
            can_copy=self.copy_attn,
        )

        data_iter = MultiSourceInputter.OrderedIterator(
            src_types=self.src_types,
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False
        )

        xlation_builder = MultiSourceTranslationBuilder(
            self.src_types,
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(
                batch, data.src_vocabs, attn_debug, xlation_builder,
            )
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                               for pred in trans.pred_sents[:self.n_best]]
                all_predictions += [n_best_preds]

                # Slice instead of indexing with range(n_best) so that an
                # example with fewer than n_best hypotheses does not raise.
                candidates_logprobs.append([
                    (pred_sent, pred_score.item())
                    for pred_sent, pred_score in zip(
                        trans.pred_sents[:self.n_best],
                        trans.pred_scores[:self.n_best])
                ])
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    # Work on a copy: the same list object is already stored
                    # in candidates_logprobs and must not gain an extra token.
                    preds = trans.pred_sents[0] + ['</s>']
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Mark the most attended source position of the row
                        # with a '*' fill character.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # An empty shard yields no predictions; skip the average rather
            # than dividing by zero.
            if all_predictions:
                self._log("Average translation time (s): %f" % (
                    total_time / len(all_predictions)))
            self._log("Tokens per second: %f" % (
                pred_words_total / total_time))

        if self.dump_beam:
            import json
            # NOTE(review): ``self.translator`` is not assigned anywhere in
            # the code visible here; confirm this attribute exists.
            beam_accum = self.translator.beam_accum
            # Close the dump file deterministically instead of leaking it.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(beam_accum, beam_file)
        return all_scores, all_predictions, candidates_logprobs

    def translate_batch(self, batch, src_vocabs, attn_debug, xlation_builder):
        """Translate one batch without tracking gradients.

        A beam size of 1 is decoded with random sampling; any other beam
        size is decoded with beam search. Attention is requested when
        ``attn_debug`` or ``self.replace_unk`` is set.
        """
        with torch.no_grad():
            if self.beam_size != 1:
                return self._translate_batch(
                    batch,
                    src_vocabs,
                    self.max_length,
                    min_length=self.min_length,
                    ratio=self.ratio,
                    n_best=self.n_best,
                    return_attention=attn_debug or self.replace_unk,
                    xlation_builder=xlation_builder,
                )

            return self._translate_random_sampling(
                batch,
                src_vocabs,
                self.max_length,
                min_length=self.min_length,
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk,
                return_attention=attn_debug or self.replace_unk)

    def _translate_batch(
            self,
            batch,
            src_vocabs,
            max_length,
            min_length=0,
            ratio=0.,
            n_best=1,
            return_attention=False,
            xlation_builder=None,
    ):
        """Decode one batch with beam search over all source encoders.

        Returns:
            dict with the keys ``"predictions"``, ``"scores"``,
            ``"attention"`` (taken from the beam), ``"batch"`` and
            ``"gold_score"``.
        """
        # TODO: support these blacklisted features.
        assert not self.dump_beam

        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        beam_size = self.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoders on the sources.
        src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch)
        self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list)

        gold_score = self._gold_score(
            batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map,
            enc_states_list, batch_size, src_list)
        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": gold_score,
        }

        # (2) Repeat src objects `beam_size` times (batch_size x beam_size).
        src_map_list = [
            tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None
            for src_type in self.src_types
        ]

        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        memory_lengths_list = []
        untiled_lengths = []
        for src_i, bank in enumerate(memory_bank_list):
            if isinstance(bank, tuple):
                tiled_bank = tuple(tile(x, beam_size, dim=1) for x in bank)
                mb_device = tiled_bank[0].device
            else:
                tiled_bank = tile(bank, beam_size, dim=1)
                mb_device = tiled_bank.device
            # end if
            memory_bank_list[src_i] = tiled_bank
            memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size))
            untiled_lengths.append(src_lengths_list[src_i])
        # end for
        # The beam receives the per-example sum of all source lengths.
        memory_lengths = tile(torch.stack(untiled_lengths, dim=0).sum(dim=0), beam_size)

        indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            batch_size=batch_size,
            n_best=n_best,
            pad=self._tgt_pad_idx,
            bos=self._tgt_bos_idx,
            eos=self._tgt_eos_idx,
            mb_device=mb_device,
            global_scorer=self.global_scorer,
            min_length=min_length,
            max_length=max_length,
            ratio=ratio,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank_list,
                batch,
                src_vocabs,
                memory_lengths_list=memory_lengths_list,
                src_map_list=src_map_list,
                step=step,
                batch_offset=beam._batch_offset)

            if self.reranker is not None:
                log_probs = self.reranker.rerank_step_beam_batch(
                    batch,
                    beam,
                    self.beam_size,
                    indexes,
                    log_probs,
                    attn,
                    self.fields["tgt"].base_field.vocab,
                    xlation_builder,
                )
            # end if

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states so they line up with the surviving beams.
                for src_i, bank in enumerate(memory_bank_list):
                    if isinstance(bank, tuple):
                        memory_bank_list[src_i] = tuple(
                            x.index_select(1, select_indices) for x in bank)
                    else:
                        memory_bank_list[src_i] = bank.index_select(1, select_indices)
                    # end if

                    memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices)
                # end for

                if use_src_map and src_map_list[0] is not None:
                    for src_i, src_map in enumerate(src_map_list):
                        src_map_list[src_i] = src_map.index_select(1, select_indices)
                    # end for
                # end if

                indexes = indexes.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results

    def _run_encoder(self, batch):
        """Run every encoder of the model on its source in ``batch``.

        Returns:
            ``(src_list, enc_states_list, memory_bank_list,
            src_lengths_list)``: ``src_list`` is a dict keyed by source
            type; the other three are lists with one entry per encoder, in
            the iteration order of ``self.model.encoders``.
        """
        src_list = {}
        src_lengths_list = {}
        for src_type in self.src_types:
            batch_src = getattr(batch, f"src.{src_type}")
            # A tuple field carries (tensor, lengths); otherwise no lengths.
            if isinstance(batch_src, tuple):
                src, src_lengths = batch_src
            else:
                src, src_lengths = batch_src, None
            src_list[src_type] = src
            src_lengths_list[src_type] = src_lengths
        # end for

        enc_states_list = []
        memory_bank_list = []
        new_src_lengths_list = []
        for enc_type, encoder in self.model.encoders.items():
            # The "l" encoder is paired with the "type" source, every other
            # encoder with the "patype" source.
            companion_key = "type" if enc_type == "l" else "patype"
            enc_states, memory_bank, src_lengths = encoder(
                src_list[enc_type], src_list[companion_key], src_lengths_list[enc_type])
            if src_lengths is None:
                assert not isinstance(memory_bank, tuple), \
                    'Ensemble decoding only supported for text data'
                # Fall back to the first dimension of the memory bank for
                # every example in the batch.
                full_length = memory_bank.size(0)
                src_lengths = torch.Tensor(batch.batch_size) \
                                   .type_as(memory_bank) \
                                   .long() \
                                   .fill_(full_length)
            # end if
            enc_states_list.append(enc_states)
            memory_bank_list.append(memory_bank)
            new_src_lengths_list.append(src_lengths)
        # end for
        return src_list, enc_states_list, memory_bank_list, new_src_lengths_list

    def _decode_and_generate(
            self,
            decoder_in,
            memory_bank_list,
            batch,
            src_vocabs,
            memory_lengths_list,
            src_map_list=None,
            step=None,
            batch_offset=None):
        """Run the decoder and the generator for the given decoder input.

        Returns:
            ``(log_probs, attn)``. Without copy attention ``attn`` is the
            ``"std"`` attention when the decoder returned one, else
            ``None``; with copy attention it is the ``"copy"`` attention.
        """
        if self.copy_attn:
            # Turn any copied words into UNKs.
            copied_mask = decoder_in.gt(self._tgt_vocab_len - 1)
            decoder_in = decoder_in.masked_fill(copied_mask, self._tgt_unk_idx)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank_list, memory_lengths_list=memory_lengths_list, step=step
        )

        # Generator forward.
        if self.copy_attn:
            attn = dec_attn["copy"]
            flat_dec_out = dec_out.view(-1, dec_out.size(2))
            flat_attn = attn.view(-1, attn.size(2))
            scores = self.model.generator(flat_dec_out, flat_attn, src_map_list)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is not None:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            else:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            scores = collapse_copy_scores(
                scores,
                batch,
                self._tgt_vocab,
                src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset
            )
            scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
        else:
            attn = dec_attn["std"] if "std" in dec_attn else None
            log_probs = self.model.generator(dec_out.squeeze(0))
        # log_probs is [(batch_size x beam_size) , vocab ] when 1 step
        # or [ tgt_len, batch_size, vocab ] when full sentence

        return log_probs, attn
    
    def _score_target(self, batch, memory_bank_list, src_lengths_list,
                      src_vocabs, src_map_list):
        """Score the gold target of ``batch`` under the model.

        The target without its last step is fed to the decoder; the scores
        of the target shifted by one step are gathered, with the padding
        index zeroed, and summed over dimension 0.

        Returns:
            Flattened tensor of summed gold scores.
        """
        gold_full = batch.tgt
        decoder_input = gold_full[:-1]
        gold_next = gold_full[1:]

        log_probs, _ = self._decode_and_generate(
            decoder_input, memory_bank_list, batch, src_vocabs,
            memory_lengths_list=src_lengths_list, src_map_list=src_map_list)

        # Padding positions must not contribute to the sum.
        log_probs[:, :, self._tgt_pad_idx] = 0
        per_step_scores = log_probs.gather(2, gold_next)
        return per_step_scores.sum(dim=0).view(-1)

    def _gold_score(self, batch, memory_bank, src_lengths, src_vocabs,
                    use_src_map, enc_states, batch_size, src):
        """Return the gold-target scores of ``batch``.

        When the batch has no ``tgt`` attribute a list of ``batch_size``
        zeros is returned. Otherwise the target is scored and the decoder
        state is re-initialised afterwards, because scoring advanced it.
        """
        if "tgt" not in batch.__dict__:
            return [0] * batch_size

        if use_src_map:
            src_map_list = [
                getattr(batch, f"src_map.{src_type}")
                for src_type in self.src_types
            ]
        else:
            src_map_list = None
        # end if

        gold_scores = self._score_target(
            batch, memory_bank, src_lengths, src_vocabs,
            src_map_list)
        self.model.decoder.init_state(src, memory_bank, enc_states)
        return gold_scores
示例#26
0
class MultiSourceAPTranslationBuilder:
    """
    Turns the raw output of a translation batch into
    `MultiSourceAPTranslation` objects, for inputs made of several source
    sequences (one per entry of `src_types`).
    """

    logger = LoggingUtils.get_logger(__name__)

    def __init__(self,
                 src_types,
                 data,
                 fields,
                 n_best=1,
                 replace_unk=False,
                 has_tgt=False,
                 phrase_table=""):
        """
        :param src_types: names of the sources; batch/example attributes are
            looked up as `src.<src_type>`.
        :param data: dataset exposing `examples` and `src_vocabs`.
        :param fields: fields; `dict(fields)["tgt"].base_field` is used.
        :param n_best: number of predictions converted per example.
        :param replace_unk: replace unknown tokens using attention.
        :param has_tgt: whether batches carry a reference target.
        :param phrase_table: path of a `|||`-separated phrase table, or ""
            to disable the lookup.
        """
        self.src_types = src_types
        self.data = data
        self.fields = fields
        self.n_best = n_best
        self.replace_unk = replace_unk
        self.has_tgt = has_tgt
        self.phrase_table = phrase_table
        self._has_text_src = True  # PN: all text for now

    def _build_target_tokens(self, src_list, src_vocab, src_raw, pred, attn):
        """
        Converts predicted indices into tokens, stopping at (and dropping)
        the end-of-sentence token.

        :param src_list: per-source index sequences for one example.
        :param src_vocab: vocabulary used for indices at or beyond the
            target vocabulary size.
        :param src_raw: per-source raw token lists for the same example.
        :param pred: iterable of predicted indices.
        :param attn: per-token attention, or None.
        :return: the list of target tokens.
        """
        # PN: Flatten the src_list and src_raw (w/ padding): each raw source
        # is padded with None up to the length of its index sequence.
        flat_raw = []
        for src_i, one_src_raw in enumerate(src_raw):
            missing = len(src_list[src_i]) - len(one_src_raw)
            flat_raw.extend(one_src_raw + [None] * missing)
        # end for
        flat_src = []
        for one_src in src_list:
            flat_src.extend(one_src)
        # end for

        tgt_field = dict(self.fields)["tgt"].base_field
        vocab = tgt_field.vocab

        tokens = []
        for tok in pred:
            if tok < len(vocab):
                word = vocab.itos[tok]
            else:
                # Indices past the target vocabulary index into src_vocab.
                word = src_vocab.itos[tok - len(vocab)]
            # end if
            if word == tgt_field.eos_token:
                break
            tokens.append(word)
        # end for

        if self.replace_unk and attn is not None and flat_src[0] is not None:
            for i, token in enumerate(tokens):
                if token == tgt_field.unk_token:
                    # Take the raw source token with the highest attention.
                    _, max_index = attn[i][:len(flat_raw)].max(0)
                    aligned = flat_raw[max_index.item()]
                    tokens[i] = aligned
                    if self.phrase_table != "":
                        # The last phrase-table line starting with the
                        # aligned source token wins.
                        with open(self.phrase_table, "r") as f:
                            for line in f:
                                if line.startswith(aligned):
                                    tokens[i] = line.split('|||')[1].strip()
        return tokens

    def from_batch(self, translation_batch):
        """
        Builds one `MultiSourceAPTranslation` per example of the batch.

        :param translation_batch: dict with the keys "batch", "predictions",
            "scores", "attention" and "gold_score".
        :return: list of translations, one per batch position.
        """
        batch = translation_batch["batch"]
        assert (len(translation_batch["gold_score"]) == len(
            translation_batch["predictions"]))
        batch_size = batch.batch_size

        # Order all per-example results by the example's index.
        by_index = sorted(
            zip(translation_batch["predictions"],
                translation_batch["scores"],
                translation_batch["attention"],
                translation_batch["gold_score"],
                batch.indices.data),
            key=lambda x: x[-1])
        preds, pred_score, attn, gold_score, _ = list(zip(*by_index))

        # Sorting
        inds, perm = torch.sort(batch.indices)
        if self._has_text_src:
            src_list = [
                getattr(batch, f"src.{src_type}")[0][:, :, 0].index_select(
                    1, perm)
                for src_type in self.src_types
            ]
        else:
            src_list = [None] * len(self.src_types)
        # end if

        if self.has_tgt:
            tgt = batch.tgt[:, :, 0].index_select(1, perm)
        else:
            tgt = None
        # end if

        def sources_at(b):
            # Fresh list holding column b of every source.
            return [src[:, b] if src is not None else None for src in src_list]

        translations = []
        for b in range(batch_size):
            if self._has_text_src:
                src_vocab = self.data.src_vocabs[inds[b]] \
                    if self.data.src_vocabs else None
                example = None
                src_raw_list = []
                for src_type in self.src_types:
                    example = self.data.examples[inds[b]]
                    src_raw_list.append(
                        getattr(example, f"src.{src_type}")[0])
                # end for
            else:
                src_vocab = [None] * len(self.src_types)
                src_raw_list = [None]
            # end if

            pred_sents = []
            for n in range(self.n_best):
                pred_sents.append(
                    self._build_target_tokens(sources_at(b), src_vocab,
                                              src_raw_list, preds[b][n],
                                              attn[b][n]))
            # end for

            gold_sent = None
            if tgt is not None:
                gold_sent = self._build_target_tokens(
                    sources_at(b), src_vocab, src_raw_list, tgt[1:, b], None)
            # end if

            translations.append(
                MultiSourceAPTranslation(sources_at(b), src_raw_list,
                                         pred_sents, attn[b], pred_score[b],
                                         gold_sent, gold_score[b]))
        # end for

        return translations
示例#27
0
class MultiSourceAPTrainer(CustomTrainer):
    """
    Trainer for models fed with several source sequences per example.

    Each batch is expected to expose one `src.<src_type>` attribute per
    entry of `src_types`; these are collected into lists and handed to the
    model together with the target.
    """

    logger = LoggingUtils.get_logger(__name__)

    def __init__(self,
                 src_types,
                 model,
                 train_loss,
                 valid_loss,
                 optim,
                 trunc_size=0,
                 shard_size=32,
                 norm_method="sents",
                 accum_count=[1],
                 accum_steps=[0],
                 n_gpu=1,
                 gpu_rank=1,
                 gpu_verbose_level=0,
                 report_manager=None,
                 model_saver=None,
                 average_decay=0,
                 average_every=1,
                 model_dtype='fp32',
                 earlystopper=None,
                 dropout=[0.3],
                 dropout_steps=[0]):
        """
        Stores `src_types` and forwards every other argument, unchanged and
        in order, to the parent trainer.

        NOTE: the list defaults are shared between calls; they are only
        forwarded here and must not be mutated.
        """
        super().__init__(model, train_loss, valid_loss, optim, trunc_size,
                         shard_size, norm_method, accum_count, accum_steps,
                         n_gpu, gpu_rank, gpu_verbose_level, report_manager,
                         model_saver, average_decay, average_every,
                         model_dtype, earlystopper, dropout, dropout_steps)

        self.src_types = src_types

        return

    @classmethod
    def build_loss_compute(cls, src_types, model, tgt_field, opt, train=True):
        """
        Returns a LossCompute subclass which wraps around an nn.Module subclass
        (such as nn.NLLLoss) which defines the loss criterion. The LossCompute
        object allows this loss to be computed in shards and passes the relevant
        data to a Statistics object which handles training/validation logging.
        The NMTLossCompute class handles all loss computation except when a
        copy mechanism is used, in which case
        MultiSourceAPCopyGeneratorLossCompute is built instead.

        Args:
            src_types: names of the sources, passed to the copy loss compute
            model: the model; its `generator` is used by the loss compute
            tgt_field: target field providing the vocab, pad and unk tokens
            opt: user options (copy_attn, label_smoothing, ...)
            train (bool): label smoothing is only applied when True
        """
        device = torch.device(
            "cuda" if onmt.utils.misc.use_gpu(opt) else "cpu")

        padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
        unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]

        if opt.lambda_coverage != 0:
            assert opt.coverage_attn, "--coverage_attn needs to be set in " \
                                      "order to use --lambda_coverage != 0"

        if opt.copy_attn:
            criterion = onmt.modules.CopyGeneratorLoss(
                len(tgt_field.vocab),
                opt.copy_attn_force,
                unk_index=unk_idx,
                ignore_index=padding_idx)
        elif opt.label_smoothing > 0 and train:
            criterion = LabelSmoothingLoss(opt.label_smoothing,
                                           len(tgt_field.vocab),
                                           ignore_index=padding_idx)
        elif isinstance(model.generator[-1], LogSparsemax):
            criterion = SparsemaxLoss(ignore_index=padding_idx,
                                      reduction='sum')
        else:
            criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')

        # if the loss function operates on vectors of raw logits instead of
        # probabilities, only the first part of the generator needs to be
        # passed to the NMTLossCompute. At the moment, the only supported
        # loss function of this kind is the sparsemax loss.
        use_raw_logits = isinstance(criterion, SparsemaxLoss)
        loss_gen = model.generator[0] if use_raw_logits else model.generator
        if opt.copy_attn:
            compute = MultiSourceAPCopyGeneratorLossCompute(
                src_types,
                criterion,
                loss_gen,
                tgt_field.vocab,
                opt.copy_loss_by_seqlength,
                lambda_coverage=opt.lambda_coverage)
        else:
            compute = NMTLossCompute(criterion,
                                     loss_gen,
                                     lambda_coverage=opt.lambda_coverage)
        compute.to(device)

        return compute

    @classmethod
    def build_trainer(cls,
                      src_types,
                      opt,
                      device_id,
                      model,
                      fields,
                      optim,
                      model_saver=None):
        """
        Simplify `Trainer` creation based on user `opt`s*

        Args:
            src_types (list): names of the sources of the model
            opt (:obj:`Namespace`): user options (usually from argument parsing)
            device_id (int): index into `opt.gpu_ranks`; a negative value
                selects rank 0 with no GPU
            model (:obj:`onmt.models.NMTModel`): the model to train
            fields (dict): dict of fields
            optim (:obj:`onmt.utils.Optimizer`): optimizer used during training
            model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object
                used to save the model; only kept when the gpu rank is 0
        """

        tgt_field = dict(fields)["tgt"].base_field
        train_loss = cls.build_loss_compute(src_types, model, tgt_field, opt)
        valid_loss = cls.build_loss_compute(src_types,
                                            model,
                                            tgt_field,
                                            opt,
                                            train=False)

        trunc_size = opt.truncated_decoder  # Badly named...
        shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
        norm_method = opt.normalization
        accum_count = opt.accum_count
        accum_steps = opt.accum_steps
        n_gpu = opt.world_size
        average_decay = opt.average_decay
        average_every = opt.average_every
        dropout = opt.dropout
        dropout_steps = opt.dropout_steps
        if device_id >= 0:
            gpu_rank = opt.gpu_ranks[device_id]
        else:
            gpu_rank = 0
            n_gpu = 0
        gpu_verbose_level = opt.gpu_verbose_level

        earlystopper = onmt.utils.EarlyStopping(
            opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \
            if opt.early_stopping > 0 else None

        # Customized report manager
        report_manager = CustomReportMgr(opt.report_every, start_time=-1)

        trainer = cls(src_types,
                      model,
                      train_loss,
                      valid_loss,
                      optim,
                      trunc_size,
                      shard_size,
                      norm_method,
                      accum_count,
                      accum_steps,
                      n_gpu,
                      gpu_rank,
                      gpu_verbose_level,
                      report_manager,
                      model_saver=model_saver if gpu_rank == 0 else None,
                      average_decay=average_decay,
                      average_every=average_every,
                      model_dtype=opt.model_dtype,
                      earlystopper=earlystopper,
                      dropout=dropout,
                      dropout_steps=dropout_steps)
        return trainer

    def _gradient_accumulation(self, true_batches, normalization, total_stats,
                               report_stats):
        """
        Runs forward/backward over `true_batches` and steps the optimizer.

        With `accum_count == 1` gradients are zeroed and the optimizer is
        stepped for every truncated window; with `accum_count > 1` gradients
        are zeroed once up front and a single step is taken at the end.

        Args:
            true_batches: the batches to accumulate over
            normalization: forwarded to the training loss
            total_stats: statistics updated with every batch
            report_stats: statistics updated with every batch; also receives
                the number of source words when lengths are available
        """
        if self.accum_count > 1:
            self.optim.zero_grad()

        for k, batch in enumerate(true_batches):
            target_size = batch.tgt.size(0)
            # Truncated BPTT: reminder not compatible with accum > 1
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            # Collect every source (and its lengths, when the batch attribute
            # is a (tensor, lengths) tuple) in the order of self.src_types.
            src_list = list()
            src_lengths_list = list()
            for src_type in self.src_types:
                batch_src = getattr(batch, f"src.{src_type}")
                src, src_lengths = batch_src if isinstance(
                    batch_src, tuple) else (batch_src, None)
                if src_lengths is not None:
                    report_stats.n_src_words += src_lengths.sum().item()
                # end if
                src_list.append(src)
                src_lengths_list.append(src_lengths)
            # end for

            tgt_outer = batch.tgt

            bptt = False
            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.accum_count == 1:
                    self.optim.zero_grad()
                outputs, attns = self.model(src_list,
                                            tgt,
                                            src_lengths_list,
                                            bptt=bptt)
                bptt = True

                # 3. Compute loss.
                try:
                    loss, batch_stats = self.train_loss(
                        batch,
                        outputs,
                        attns,
                        normalization=normalization,
                        shard_size=self.shard_size,
                        trunc_start=j,
                        trunc_size=trunc_size)

                    if loss is not None:
                        self.optim.backward(loss)

                    total_stats.update(batch_stats)
                    report_stats.update(batch_stats)

                except Exception:
                    # Deliberate best-effort: a failing batch is reported and
                    # dropped instead of aborting the training run.
                    traceback.print_exc()
                    self.logger.info(
                        "At step %d, we removed a batch - accum %d",
                        self.optim.training_step, k)

                # 4. Update the parameters and statistics.
                if self.accum_count == 1:
                    # Multi GPU gradient gather
                    if self.n_gpu > 1:
                        grads = [
                            p.grad.data for p in self.model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        onmt.utils.distributed.all_reduce_and_rescale_tensors(
                            grads, float(1))
                    self.optim.step()

                # If truncated, don't backprop fully.
                # TO CHECK
                # if dec_state is not None:
                #    dec_state.detach()
                if self.model.decoder.state is not None:
                    self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                onmt.utils.distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()

    def validate(self, valid_iter, moving_average=None):
        """ Validate model.
            valid_iter: validate data iterator
            moving_average: when truthy, validation runs on a deep copy of
                the model whose parameters are replaced by the values in
                `self.moving_average`; otherwise the model itself is used
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        if moving_average:
            valid_model = deepcopy(self.model)
            for avg, param in zip(self.moving_average,
                                  valid_model.parameters()):
                param.data = avg.data
        else:
            # This branch belongs to the `if`, not to the `for` above: as a
            # for-else it would always replace the averaged copy by the live
            # model and leave `valid_model` unbound without moving average.
            valid_model = self.model

        # Set model in validating mode.
        valid_model.eval()

        with torch.no_grad():
            stats = onmt.utils.Statistics()

            for batch in valid_iter:
                src_list = list()
                src_lengths_list = list()
                for src_type in self.src_types:
                    batch_src = getattr(batch, f"src.{src_type}")
                    src, src_lengths = batch_src if isinstance(
                        batch_src, tuple) else (batch_src, None)
                    src_list.append(src)
                    src_lengths_list.append(src_lengths)
                # end for
                tgt = batch.tgt

                # F-prop through the model.
                outputs, attns = valid_model(src_list, tgt, src_lengths_list)

                # Compute loss.
                _, batch_stats = self.valid_loss(batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)

        if moving_average:
            # The averaged copy was only needed for this validation pass.
            del valid_model
        else:
            # Set model back to training mode.
            valid_model.train()

        return stats
示例#28
0
class Utils:
    """
    Some utilities that don't tie to a specific other file. TODO: move them into seutil at some point.
    """
    logger = LoggingUtils.get_logger(__name__)

    @classmethod
    def get_option_as_boolean(cls, options, opt, default=False) -> bool:
        """
        Reads option `opt` as a boolean.

        :param options: mapping of option names to values.
        :param opt: name of the option.
        :param default: returned when `opt` is absent.
        :return: `default` if absent; otherwise False only when the value's
            string form equals "false" (case-insensitive), True for anything
            else.
        """
        if opt not in options:
            return default
        # end if

        # Due to limitations of CliUtils...
        return str(options.get(opt, "false")).lower() != "false"

    @classmethod
    def get_option_as_list(cls, options, opt, default=None) -> list:
        """
        Reads option `opt` as a list.

        :param options: mapping of option names to values.
        :param opt: name of the option.
        :param default: deep-copied and returned when `opt` is absent.
        :return: the value itself if it is a list, otherwise the value
            wrapped in a one-element list.
        """
        if opt not in options:
            return copy.deepcopy(default)
        # end if

        value = options[opt]
        return value if isinstance(value, list) else [value]

    # Summary functions keyed by name. All but CNT return NaN on empty input.
    # `np.nan` is used instead of the `np.NaN` alias, which NumPy 2.0 removed.
    SUMMARIES_FUNCS: Dict[str, Callable[[Union[list, np.ndarray]], Union[
        int, float]]] = {
            "AVG":
            lambda l: np.mean(l) if len(l) > 0 else np.nan,
            "SUM":
            lambda l: sum(l) if len(l) > 0 else np.nan,
            "MAX":
            lambda l: max(l) if len(l) > 0 else np.nan,
            "MIN":
            lambda l: min(l) if len(l) > 0 else np.nan,
            "MEDIAN":
            lambda l: np.median(l)
            if len(l) > 0 and np.nan not in l else np.nan,
            "STDEV":
            lambda l: np.std(l) if len(l) > 0 else np.nan,
            "CNT":
            lambda l: len(l),
        }

    # Per-summary flag, keyed like SUMMARIES_FUNCS; presumably tells callers
    # whether an integer result should stay an integer -- confirm with users.
    SUMMARIES_PRESERVE_INT: Dict[str, bool] = {
        "AVG": False,
        "SUM": True,
        "MAX": True,
        "MIN": True,
        "MEDIAN": False,
        "STDEV": False,
        "CNT": True,
    }

    # Matches https GitHub clone urls, capturing the `user` and `repo` parts.
    RE_GITHUB_URL = re.compile(
        r"https://github\.com/(?P<user>[^/]*)/(?P<repo>.*)\.git")

    @classmethod
    def lod_to_dol(cls, list_of_dict: List[dict]) -> Dict[Any, List]:
        """
        Converts a list of dict to a dict of list.

        Every key appearing in any dict becomes a key of the result, in
        first-seen order; dicts lacking a key contribute None at their
        position. An empty input yields an empty dict.
        """
        # dict.fromkeys gives a deterministic key order and, unlike
        # set.union(*[]), also works when the input list is empty.
        keys = dict.fromkeys(k for d in list_of_dict for k in d)
        return {k: [d.get(k) for d in list_of_dict] for k in keys}

    @classmethod
    def counter_most_common_to_pretty_yaml(
            cls, most_common: List[Tuple[Any, int]]) -> str:
        """
        Renders `(item, count)` pairs as a bracketed list with one
        `[<json of item>, <count>],` entry per line.
        """
        entries = "".join(
            f"[{json.dumps(x)}, {c}],\n" for x, c in most_common)
        return "[\n" + entries + "]\n"

    @classmethod
    def modify_and_import(cls, module_name, package, modification_func):
        """
        Imports a module from a modified version of its source code.

        :param module_name: name of the module to locate.
        :param package: package anchor passed to `importlib.util.find_spec`.
        :param modification_func: maps the original source text to the source
            that is actually compiled and executed.
        :return: the new module, which is also stored in `sys.modules` under
            `module_name`.
        """
        spec = importlib.util.find_spec(module_name, package)
        source = spec.loader.get_source(module_name)
        new_source = modification_func(source)
        module = importlib.util.module_from_spec(spec)
        codeobj = compile(new_source, module.__spec__.origin, 'exec')
        # NOTE: this executes whatever modification_func returns; only use
        # it with trusted modules and trusted modification functions.
        exec(codeobj, module.__dict__)
        sys.modules[module_name] = module
        return module
示例#29
0
class Table:
    # Class-level logger; uses the DEBUG level when Environment.is_debug is
    # set and INFO otherwise.
    logger = LoggingUtils.get_logger(
        __name__,
        LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)

    # Separator marker constants (their consumers are not visible here).
    COLSEP = "COLSEP"
    ROWSEP = "ROWSEP"

    # LaTeX Greek-letter commands. make_table_models_results indexes this
    # list by the position of each pair in the sign-test results, so at most
    # len(SYMBOLS) pairs can be labelled.
    SYMBOLS = [
        r"\alpha",
        r"\beta",
        r"\gamma",
        r"\delta",
        r"\epsilon",
        r"\zeta",
        r"\eta",
        r"\theta",
        r"\iota",
        r"\kappa",
        r"\lambda",
        r"\mu",
        r"\nu",
        r"\tau",
        r"\pi",
        r"\rho",
    ]

    def __init__(self):
        """
        Sets up the output directory for tables (created if needed) and the
        directory metrics are read from.
        """
        tables_dir = Macros.paper_dir / "tables"
        self.tables_dir: Path = tables_dir
        IOUtils.mk_dir(tables_dir)

        self.metrics_dir: Path = Macros.results_dir / "metrics"

    def make_tables(self, which, options):
        """
        Generates the requested tables / numbers files.

        :param which: list of table names; the single-element list ["all"]
            is expanded to ["dataset-metrics"].
        :param options: mapping with the extra options some tables read
            ("results-path", "output-name", "dataset", "filter", "model",
            "task").
        """
        if len(which) == 1 and which[0] == "all":
            which = ["dataset-metrics"]
        # end if

        for item in which:
            if item == "dataset-metrics":
                self.make_numbers_dataset_metrics()
                for version in ("main", "split"):
                    self.make_table_dataset_metrics(version=version)
                # end for
                continue
            # end if

            if item == "draft-model-results":
                # TODO: outdated (->remove)
                results_path = Path(options.get("results-path"))
                output_name = options.get("output-name")
                self.make_table_draft_model_results(results_path, output_name)
                continue
            # end if

            if item == "time-wise-dataset-metrics":
                # TODO: outdated (->archive)
                self.make_numbers_timewise_dataset_metrics()
                self.make_table_timewise_dataset_metrics()
                continue
            # end if

            if item == "time-wise-filtered-dataset-metrics":
                # TODO: outdated (->archive)
                self.make_numbers_timewise_filtered_dataset_metrics(
                    options.get("dataset"), options.get("filter"))
                self.make_table_timewise_filtered_dataset_metrics(
                    options.get("dataset"), options.get("filter"))
                continue
            # end if

            if item == "models-numbers":
                self.make_numbers_model_results(options.get("model"))
                continue
            # end if

            if item == "evo-models-results":
                # TODO: outdated (->archive)
                self.make_table_evo_results()
                continue
            # end if

            if item == "models-results":
                self.make_table_models_results(options.get("task"))
                continue
            # end if

            self.logger.warning(f"Unknown table name is {item}")
        # end for
        return

    def make_table_timewise_filtered_dataset_metrics(self,
                                                     dataset: str = "large",
                                                     filter: str = "beta"):
        """
        Writes the LaTeX table(s) of per-year statistics of a filtered
        dataset, one row per one-year window from 2013-2014 to 2019-2020.

        For the "beta" filter a single table with the number of methods is
        written; for any other filter two tables are written, one for
        "method" and one for "comment" token statistics.

        :param dataset: dataset name used in file and macro names.
        :param filter: filter name used in file and macro names.
        """
        # (row label, time key used in macro names) for every window.
        windows = [(f"{year}-{year + 1}", f"{year}_Jan_1-{year + 1}_Jan_1")
                   for year in range(2013, 2020)]

        def open_table(file_name, caption, tabular_line, headers):
            # Creates the file and emits everything up to the first row.
            tex = latex.File(self.tables_dir / file_name)
            tex.append(r"\begin{table*}")
            tex.append(r"\begin{small}")
            tex.append(r"\begin{center}")
            tex.append(r"\caption{" + caption + "}")
            tex.append(tabular_line)
            tex.append(r"\toprule")

            tex.append(r" ")
            for header in headers:
                tex.append(r" &")
                tex.append(f"{header} ")
            # end for
            tex.append(r" \\")
            tex.append(r"\midrule")
            return tex

        def close_table(tex):
            # Footer
            for line in (r"\bottomrule", r"\end{tabular}", r"\end{center}",
                         r"\end{small}", r"\vspace{\TVDatasetMetrics}",
                         r"\end{table*}"):
                tex.append(line)
            # end for
            tex.save()

        if filter == "beta":
            tex = open_table(
                f"table-time-wise-{filter}-filtered-{dataset}-dataset-metrics.tex",
                "Method naming statistics after filtering",
                r"\begin{tabular}{l | c }",
                ["num-methods"])
            for label, window in windows:
                tex.append(label)
                tex.append(" & " + latex.Macro(
                    f"{dataset}-{filter}-{window}-num-methods").use())
                tex.append(r"\\")
            # end for
            close_table(tex)
            return
        # end if

        # Per item: column headers, and the matching macro-name suffixes.
        layouts = (
            ("method",
             ["num-methods", "len-avg", "len-mode", "len-median",
              "len<100", "len<150", "len<200"],
             ["avg", "mode", "median", "less-100", "less-150", "less-200"]),
            ("comment",
             ["num-methods", "len-avg", "len-mode", "len-median",
              "len<20", "len<30", "len<50"],
             ["avg", "mode", "median", "less-20", "less-30", "less-50"]),
        )
        for item, headers, stats in layouts:
            tex = open_table(
                f"table-time-wise-{filter}-filtered-{item}-{dataset}-dataset-metrics.tex",
                f"{item} statistics after filtering",
                r"\begin{tabular}{l | c c c c c c c}",
                headers)
            for label, window in windows:
                tex.append(label)
                tex.append(
                    " & " +
                    latex.Macro(f"{dataset}-{window}-num-methods").use())
                for stat in stats:
                    tex.append(" & " + latex.Macro(
                        f"{dataset}-{window}-{item}-tokens-{stat}").use())
                # end for
                tex.append(r"\\")
            # end for
            close_table(tex)
        # end for
        return

    def make_table_timewise_dataset_metrics(self, dataset: str = "large"):
        """
        Writes the LaTeX table with one column per year (2013-2020) and one
        row for each of "num-methods", "num-projs" and "delta"; every cell
        uses the macro `<dataset>-<year>_Jan_1-<row>`.

        :param dataset: dataset name used in the file name, the caption and
            the macro names.
        """
        tex = latex.File(self.tables_dir /
                         f"table-time-wise-{dataset}-dataset-metrics.tex")
        years = range(2013, 2021)

        # Header
        for line in (r"\begin{table*}", r"\begin{small}", r"\begin{center}"):
            tex.append(line)
        # end for
        tex.append(r"\caption{" + r"Dataset statistics " + dataset + "}")
        tex.append(r"\begin{tabular}{l | r r r r r r r r}")
        tex.append(r"\toprule")

        tex.append(r"  &"
                   r"2013 & "
                   r"2014 & "
                   r"2015 & "
                   r"2016 & "
                   r"2017 & "
                   r"2018 & "
                   r"2019 & "
                   r"2020 \\")
        tex.append(r"\midrule")

        # One row per metric, one cell per year.
        for row_name in ("num-methods", "num-projs", "delta"):
            tex.append(f"{row_name}")
            for year in years:
                cell = latex.Macro(f"{dataset}-{year}_Jan_1-{row_name}")
                tex.append(" & " + cell.use())
            # end for
            tex.append(r"\\")
        # end for

        # Footer
        for line in (r"\bottomrule", r"\end{tabular}", r"\end{center}",
                     r"\end{small}", r"\vspace{\TVDatasetMetrics}",
                     r"\end{table*}"):
            tex.append(line)
        # end for

        tex.save()
        return

    def make_numbers_model_results(self, model: str):
        """
        Writes one LaTeX macro per statistic found in
        `results-stat-<model>.json`.

        The loaded data is walked as experiment -> test set -> metric ->
        statistic; each number becomes the macro
        `<exp>-<test_set>-<metric>-<model>-<stat>`, formatted with thousands
        separators and two decimals. Missing values (a float NaN or the
        string "NaN") are rendered as `\\Fix{NaN}`.

        :param model: model name used in the file names and macro names.
        """
        file = latex.File(self.tables_dir / f"numbers-{model}-results.tex")
        stat_results = IOUtils.load(Macros.results_dir / "metrics" /
                                    f"results-stat-{model}.json")

        for exp, exp_stat_results in stat_results.items():
            for test_set, set_stat_results in exp_stat_results.items():
                for metric, metric_stat_results in set_stat_results.items():
                    for stat, number in metric_stat_results.items():
                        macro_name = f"{exp}-{test_set}-{metric}-{model}-{stat}"
                        # NaN never compares equal to itself, so an `==`
                        # test cannot detect it; use np.isnan on floats.
                        is_nan = number == "NaN" or (
                            isinstance(number, (float, np.floating))
                            and np.isnan(number))
                        if is_nan:
                            macro_value = r"\Fix{NaN}"
                        else:
                            macro_value = f"{number:,.2f}"
                        file.append_macro(latex.Macro(macro_name, macro_value))

        file.save()
        return

    def make_numbers_timewise_dataset_metrics(self, dataset: str = "large"):
        """
        Writes one LaTeX macro `<dataset>-<time>-<key>` for every value in
        `time-wise-<dataset>-dataset-stats.json`, which is read as a mapping
        from time key to a mapping of metric values.

        :param dataset: dataset name used in the file names and macro names.
        """
        out_path = (self.tables_dir /
                    f"numbers-time-wise-{dataset}-dataset-metrics.tex")
        tex = latex.File(out_path)

        stats_path = (Macros.results_dir / "metrics" /
                      f"time-wise-{dataset}-dataset-stats.json")
        metrics = IOUtils.load(stats_path, IOUtils.Format.json)

        for t, per_time in metrics.items():
            for k, v in per_time.items():
                tex.append_macro(latex.Macro(f"{dataset}-{t}-{k}", f"{v}"))
            # end for
        # end for

        tex.save()
        return

    def make_numbers_timewise_filtered_dataset_metrics(self,
                                                       dataset: str = "large",
                                                       filter: str = "beta"):
        """
        Writes the LaTeX macro `<dataset>-<filter>-<time>-num-methods` for
        every time key in
        `time-wise-<filter>-filtered-<dataset>-dataset-stats.json`.
        All other metrics in that file are currently ignored.

        :param dataset: dataset name used in the file names and macro names.
        :param filter: filter name used in the file names and macro names.
        """
        out_path = (
            self.tables_dir /
            f"numbers-time-wise-{filter}-filtered-{dataset}-dataset-metrics.tex"
        )
        tex = latex.File(out_path)

        stats_path = (
            Macros.results_dir / "metrics" /
            f"time-wise-{filter}-filtered-{dataset}-dataset-stats.json")
        metrics = IOUtils.load(stats_path, IOUtils.Format.json)

        for t, per_time in metrics.items():
            for k, v in per_time.items():
                # TODO: change back -- the other metrics used to be emitted
                # as well, formatted with "{:.1f}".format(v).
                if k != "num-methods":
                    continue
                # end if
                tex.append_macro(
                    latex.Macro(f"{dataset}-{filter}-{t}-{k}", f"{v}"))
            # end for
        # end for

        tex.save()
        return

    def make_numbers_dataset_metrics(self):
        """
        For every task in `Macros.tasks`, writes
        `numbers-<task>-dataset-metrics.tex` with one macro per metric of
        the processed dataset (`ds-<task>-<key>`) followed by one per metric
        of the raw dataset (`raw-ds-<task>-<key>`).

        Values whose type is exactly `int` are formatted with thousands
        separators; everything else additionally gets two decimals.
        """
        for task in Macros.tasks:
            tex = latex.File(self.tables_dir /
                             f"numbers-{task}-dataset-metrics.tex")

            # (macro prefix, metrics file), processed dataset first.
            sources = (
                ("ds", f"{task}-dataset.json"),
                ("raw-ds", f"{task}-raw-dataset.json"),
            )
            for prefix, file_name in sources:
                loaded = IOUtils.load(
                    Macros.results_dir / "metrics" / file_name,
                    IOUtils.Format.json)
                for k, v in loaded.items():
                    spec = ",d" if type(v) is int else ",.2f"
                    tex.append_macro(
                        latex.Macro(f"{prefix}-{task}-{k}", format(v, spec)))
                # end for
            # end for

            tex.save()
        # end for
        return

    def make_table_models_results(self, task: str):
        """
        Generate ``table-<task>-models-results.tex`` in ``self.tables_dir``: one
        row per model, one column group per experiment, one column per metric.

        Each cell uses the macro ``<exp>-test_common-<metric>-<model>-AVG``.
        Entries of ``metrics/sign-test/<task>.json`` (loaded from
        ``Macros.results_dir``) are triples whose first two items identify an
        (exp, model, metric) cell; both cells of the i-th triple are tagged
        with ``self.SYMBOLS[i]`` as a superscript.

        :param task: "ComGen" or "MethNam".
        :raises ValueError: for any other task.
        """
        # task -> (models, metrics)
        task_setup = {
            "ComGen": (["Seq2seq", "Seq2seqAtt", "DeepCom"],
                       ["bleu", "xmatch"]),
            "MethNam": (["Bi-LSTM", "no-split-Bi-LSTM", "Code2Seq"],
                        ["f1", "precision", "recall", "xmatch"]),
        }
        if task not in task_setup:
            raise ValueError(f"Invalid task {task}")
        models, metrics = task_setup[task]
        exps = ["mixedproj-2020", "crossproj-2020", "evo-2020"]
        n_metrics = len(metrics)

        # Load stat sign test results
        no_diff_pairs = IOUtils.load(Macros.results_dir / "metrics" /
                                     "sign-test" / f"{task}.json")
        cell_2_symbols = collections.defaultdict(list)
        for pair_i, (cell_a, cell_b, _) in enumerate(no_diff_pairs):
            symbol = self.SYMBOLS[pair_i]
            for cell in (cell_a, cell_b):
                cell_2_symbols[tuple(cell)].append(symbol)

        file = latex.File(self.tables_dir / f"table-{task}-models-results.tex")

        # Header
        table_name = f"Results{task}"
        column_spec = "l" + ("|" + "r" * n_metrics) * len(exps)
        for header_line in (
                r"\begin{table*}",
                r"\begin{small}",
                r"\begin{center}",
                r"\caption{" + r"\TC" + table_name + "}",
                r"\begin{tabular}{" + column_spec + "}",
                r"\toprule",
        ):
            file.append(header_line)

        # Line 1: experiment names, each spanning all metric columns; every
        # group except the last one gets a vertical rule on its right.
        last_exp_i = len(exps) - 1
        for exp_i, exp in enumerate(exps):
            alignment = "c" if exp_i == last_exp_i else "c|"
            file.append(r" & \multicolumn{" + str(n_metrics) + r"}{" +
                        alignment + r"}{\UseMacro{TH-exp-" + exp + r"}}")
        file.append(r"\\")

        # Line 2: metric names, repeated for every experiment
        file.append(r"\multirow{-2}{*}{\THModel} ")
        for _ in exps:
            for metric in metrics:
                file.append(r" & \UseMacro{TH-metric-" + metric + r"}")
        file.append(r"\\")

        file.append(r"\midrule")

        # Body; only the AVG macro is shown (no STDEV is emitted)
        for model in models:
            file.append(r"\UseMacro{TH-model-" + model + r"}")
            for exp in exps:
                for metric in metrics:
                    symbols = cell_2_symbols[(exp, model, metric)]
                    superscript = ("$^{" + "".join(symbols) + "}$"
                                   if symbols else "")
                    value = latex.Macro(
                        f"{exp}-test_common-{metric}-{model}-AVG").use()
                    file.append(r" & " + value + superscript)
            file.append(r"\\")

        # Footer
        for footer_line in (
                r"\bottomrule",
                r"\end{tabular}",
                r"\end{center}",
                r"\end{small}",
                r"\vspace{\TV" + table_name + r"}",
                r"\end{table*}",
        ):
            file.append(footer_line)

        file.save()
        return

    def make_table_methd_name_results(self, task="Method-naming"):
        """
        Generate ``table-<task>-models-results.tex`` in ``self.tables_dir``.

        The table has one row per training window (2013-2014 through
        2017-2018) plus a "latest-mixed" and a "latest-cross-project" row, and
        one precision/recall/f1 column triple per model.  Every cell uses a
        macro named ``<lower-cased model>-<row key>-<metric>``.

        :param task: used in the output file name and in the caption.
        """
        models = ["Bi-LSTM", "no-split-Bi-LSTM", "Code2Seq"]
        metrics = ["precision", "recall", "f1"]
        file = latex.File(self.tables_dir / f"table-{task}-models-results.tex")

        def append_result_row(label, row_key):
            # One table row: the label, then one macro cell per
            # (model, metric), then the LaTeX row terminator.
            file.append(label)
            for model in models:
                model_key = model.lower()
                for metric in metrics:
                    cell = latex.Macro(f"{model_key}-{row_key}-{metric}")
                    file.append(" & " + cell.use())
            file.append(r"\\")

        # Header
        file.append(r"\begin{table*}")
        file.append(r"\begin{small}")
        file.append(r"\begin{center}")
        file.append(r"\caption{" + f"{task} results" + "}")
        # e.g. \begin{tabular}{l|c|c|c...}: three "|c" columns per model
        file.append(r"\begin{tabular}{l" + "|c" * (len(models) * 3) + "}")
        file.append(r"\toprule")

        # Column titles: model names on the first line, metrics below
        file.append(r" \multirow{2}{*}{Time-Metrics}")
        for model in models:
            file.append(r"& \multicolumn{3}{c}" + "{" + model + "}")
        file.append(r"\\")
        for _ in models:
            for metric in metrics:
                file.append("& " + metric)
        file.append(r"\\")
        file.append(r"\midrule")

        # evo results, one row per two-year training window
        for year in range(13, 18):
            append_result_row(f"20{year}-20{year + 1}-train",
                              f"{year}{year + 1}-train")
        # end for

        append_result_row("latest-mixed", "latest")
        append_result_row("latest-cross-project", "cross-proj-latest")

        # Footer
        file.append(r"\bottomrule")
        file.append(r"\end{tabular}")
        file.append(r"\end{center}")
        file.append(r"\end{small}")
        file.append(r"\vspace{\TVDatasetMetrics}")
        file.append(r"\end{table*}")

        file.save()
        return

    def make_table_dataset_metrics(self, version: str):
        """
        Generate, for every task in ``Macros.tasks``, the dataset-metrics table
        ``table-<task>-dataset-metrics-<version>.tex`` in ``self.tables_dir``.

        Rows are the metrics listed in ``metric_2_th``; columns are the dataset
        variants of the chosen version.  Cells use macros named
        ``ds-<task>-<metric>_<column>`` (``raw-ds-...`` for columns flagged as
        raw); the "num-proj" row uses dedicated macro names instead.

        :param version: "main" (a ``table`` with the columns all / 2020 /
            2019-2020) or "split" (a ``table*`` with train and val columns for
            mixedproj, crossproj and evo, plus a test column).
        :raises ValueError: if version is anything else; the check runs per
            task, so nothing is raised when ``Macros.tasks`` is empty.
        """
        for task in Macros.tasks:
            if version not in ("main", "split"):
                raise ValueError(f"Invalid version {version}")
            is_split = version == "split"

            file = latex.File(self.tables_dir /
                              f"table-{task}-dataset-metrics-{version}.tex")

            # metric -> leading cells (row header) of its table row.
            # NOTE(review): "len-name-le-3" and "len-name-le-5" point at the
            # "...-le2" and "...-le3" header macros -- confirm this is intended.
            metric_2_th = collections.OrderedDict([
                ("num-proj",
                 r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-project}}"),
                ("num-data",
                 r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-data}}"),
                ("len-meth-AVG", r"& \UseMacro{TH-ds-len-method-avg}"),
                ("len-meth-MODE", r"& \UseMacro{TH-ds-len-method-mode}"),
                ("len-meth-MEDIAN", r"& \UseMacro{TH-ds-len-method-median}"),
                ("len-meth-le-100", r"& \UseMacro{TH-ds-len-method-le100}"),
                ("len-meth-le-150", r"& \UseMacro{TH-ds-len-method-le150}"),
                ("len-meth-le-200",
                 r"\multirow{-6}{*}{\UseMacro{TH-ds-len-method}} & \UseMacro{TH-ds-len-method-le200}"),
                ("len-com-AVG", r"& \UseMacro{TH-ds-len-comment-avg}"),
                ("len-com-MODE", r"& \UseMacro{TH-ds-len-comment-mode}"),
                ("len-com-MEDIAN", r"& \UseMacro{TH-ds-len-comment-median}"),
                ("len-com-le-20", r"& \UseMacro{TH-ds-len-comment-le20}"),
                ("len-com-le-30", r"& \UseMacro{TH-ds-len-comment-le30}"),
                ("len-com-le-50",
                 r"\multirow{-6}{*}{\UseMacro{TH-ds-len-comment}} & \UseMacro{TH-ds-len-comment-le50}"),
                ("len-name-AVG", r"& \UseMacro{TH-ds-len-name-avg}"),
                ("len-name-MODE", r"& \UseMacro{TH-ds-len-name-mode}"),
                ("len-name-MEDIAN", r"& \UseMacro{TH-ds-len-name-median}"),
                ("len-name-le-3", r"& \UseMacro{TH-ds-len-name-le2}"),
                ("len-name-le-5", r"& \UseMacro{TH-ds-len-name-le3}"),
                ("len-name-le-6",
                 r"\multirow{-6}{*}{\UseMacro{TH-ds-len-name}} & \UseMacro{TH-ds-len-name-le6}"),
            ])

            # A \midrule is inserted after each of these rows
            sep_after_rows = (
                "num-data",
                "len-meth-le-200",
                "len-com-le-50",
            )

            # Everything that depends on the version is decided here, once:
            # table environment, names, column layout, the two header rows,
            # the data columns (column key -> is_raw) and the columns that are
            # followed by an empty spacer cell.
            if is_split:
                table_env = "table*"
                table_name = "DatasetMetricsSplit"
                tabular_begin = r"\begin{tabular}{ l@{\hspace{2pt}}|@{\hspace{2pt}}c@{\hspace{2pt}} | rr @{\hspace{5pt}}c@{\hspace{5pt}} rr @{\hspace{5pt}}c@{\hspace{5pt}} rr r}"
                header_rows = [
                    # Line 1
                    (r"\multicolumn{2}{c|}{}"
                     r" & \multicolumn{2}{c}{\UseMacro{TH-ds-mixedproj}} &"
                     r" & \multicolumn{2}{c}{\UseMacro{TH-ds-crossproj}} &"
                     r" & \multicolumn{2}{c}{\UseMacro{TH-ds-evo}}"
                     r" & \\\cline{3-4}\cline{6-7}\cline{9-10}"),
                    # Line 2
                    (r"\multicolumn{2}{c|}{\multirow{-2}{*}{\THDSStat}}"
                     r" & \UseMacro{TH-ds-mixedproj-train} & \UseMacro{TH-ds-mixedproj-val} &"
                     r" & \UseMacro{TH-ds-crossproj-train} & \UseMacro{TH-ds-crossproj-val} &"
                     r" & \UseMacro{TH-ds-evo-train} & \UseMacro{TH-ds-evo-val}"
                     r" & \multirow{-2}{*}{\UseMacro{TH-ds-test}} \\"),
                ]
                dt_2_is_raw = collections.OrderedDict(
                    (f"{exp}-2020-{dt}", False)
                    for exp in ["mixedproj", "crossproj", "evo"]
                    for dt in [Macros.train, Macros.val])
                dt_2_is_raw[f"2020-{Macros.test_common}"] = False
                sep_after_cols = [
                    f"mixedproj-2020-{Macros.val}",
                    f"crossproj-2020-{Macros.val}",
                ]
            else:
                table_env = "table"
                table_name = "DatasetMetricsMain"
                tabular_begin = r"\begin{tabular}{ l@{\hspace{2pt}}|@{\hspace{2pt}}c@{\hspace{2pt}} | r r r}"
                header_rows = [
                    # Line 1
                    r"\multicolumn{2}{c|}{} & & & \\",
                    # Line 2
                    r"\multicolumn{2}{c|}{\multirow{-2}{*}{\THDSStat}} & \multirow{-2}{*}{\UseMacro{TH-ds-all}} & \multirow{-2}{*}{\UseMacro{TH-ds-2020}} & \multirow{-2}{*}{\UseMacro{TH-ds-2019-2020}} \\",
                ]
                dt_2_is_raw = collections.OrderedDict([
                    ("all", True),
                    ("2020", False),
                    ("2019-2020", False),
                ])
                sep_after_cols = []
            # end if

            # Header
            file.append(r"\begin{" + table_env + "}")
            file.append(r"\begin{small}")
            file.append(r"\begin{center}")
            file.append(r"\caption{\TC" + table_name + "}")
            file.append(tabular_begin)
            file.append(r"\toprule")
            for header_row in header_rows:
                file.append(header_row)
            file.append(r"\midrule")

            # Body
            for metric, row_th in metric_2_th.items():
                file.append(row_th)

                for dt, is_raw in dt_2_is_raw.items():
                    if metric != "num-proj":
                        prefix = "raw-ds" if is_raw else "ds"
                        macro_name = f"{prefix}-{task}-{metric}_{dt}"
                    # The project count uses per-split macros only for the
                    # cross-project columns and the test column.
                    elif dt == f"crossproj-2020-{Macros.train}":
                        macro_name = f"ds-{task}-num-proj_{Macros.train}"
                    elif dt == f"crossproj-2020-{Macros.val}":
                        macro_name = f"ds-{task}-num-proj_{Macros.val}"
                    elif dt == f"2020-{Macros.test_common}":
                        macro_name = f"ds-{task}-num-proj_{Macros.test}"
                    else:
                        macro_name = f"ds-{task}-num-proj"
                    # end if

                    file.append(" & " + latex.Macro(macro_name).use())

                    if dt in sep_after_cols:
                        # Empty cell for the spacer column
                        file.append(" & ")
                # end for

                file.append(r"\\")

                if metric in sep_after_rows:
                    file.append(r"\midrule")
            # end for

            # Footer
            file.append(r"\bottomrule")
            file.append(r"\end{tabular}")
            file.append(r"\end{center}")
            file.append(r"\end{small}")
            file.append(r"\vspace{\TV" + table_name + "}")
            file.append(r"\end{" + table_env + "}")

            file.save()
        return

    def make_table_draft_model_results(
        self,
        results_path: Path,
        output_name: str,
    ):
        """
        Generate a draft results table from the ``test_results.json`` files
        found under ``results_path``.

        One row is written per training-testing setting (``Macros.lat_lat``,
        ``Macros.evo_lat``, ``Macros.lat_evo``, ``Macros.evo_evo``).  The
        columns are the sorted metric names of the first setting's results; a
        nested "Rouge" entry is first flattened into Rouge1-F1, Rouge2-F1 and
        RougeL-F1.  All values are printed with two decimals.

        :param results_path: directory holding one sub-directory per setting.
        :param output_name: stem of the .tex file created in
            ``self.tables_dir / "draft-model-results"``.
        """
        special_tables_dir = self.tables_dir / "draft-model-results"
        IOUtils.mk_dir(special_tables_dir)
        file = latex.File(special_tables_dir / f"{output_name}.tex")

        # Header
        for env in ("table*", "small", "center"):
            file.append(r"\begin{" + env + "}")
        # Underscores in the path are escaped for the caption
        escaped_path = str(results_path).replace("_", r"\_")
        file.append(r"\caption{Model Results (Draft) from " + escaped_path +
                    "}")

        # Flattened result key -> key inside the nested "Rouge" entry
        rouge_keys = (
            ("Rouge1-F1", "rouge-1"),
            ("Rouge2-F1", "rouge-2"),
            ("RougeL-F1", "rouge-l"),
        )

        metrics = None
        settings = (Macros.lat_lat, Macros.evo_lat, Macros.lat_evo,
                    Macros.evo_evo)
        for tvt in settings:
            results = IOUtils.load(results_path / tvt / "test_results.json")

            # Flatten Rouge scores; a Rouge value of 0 yields 0 for all three
            if "Rouge" in results:
                rouge = results["Rouge"]
                rouge_is_zero = rouge == 0
                for flat_key, nested_key in rouge_keys:
                    results[flat_key] = (0 if rouge_is_zero else
                                         rouge[nested_key]["f"])
                del results["Rouge"]
            # end if

            if metrics is None:
                # The first setting fixes the columns of the whole table
                metrics = sorted(results.keys())

                # Table header line
                file.append(r"\begin{tabular}{l | " + "r" * len(metrics) + "}")
                file.append(r"\toprule")
                file.append("Training-Testing & " + " & ".join(metrics) +
                            r"\\")
                file.append(r"\midrule")
            # end if

            file.append(tvt)
            for metric in metrics:
                file.append("& " + format(results[metric], ".2f"))
            # end for
            file.append(r"\\")
        # end for

        # Footer
        file.append(r"\bottomrule")
        file.append(r"\end{tabular}")
        for env in ("center", "small", "table*"):
            file.append(r"\end{" + env + "}")

        file.save()
        return
示例#30
0
class FilesManager:
    """
    Handles the loading/dumping of files in a dataset.

    Every ``rel_path`` argument is interpreted relative to ``data_dir`` and may
    be either a single string or a list of path segments (joined with "/").
    """
    logger = LoggingUtils.get_logger(__name__)

    # Names of the data kinds -- presumably used as rel_path values by the
    # callers; verify against the call sites.
    ALL_LEMMAS_BACKEND_SEXP_TRANSFORMATIONS = "all-lemmas-bsexp-transformations"
    ALL_LEMMAS_FOREEND_SEXP_TRANSFORMATIONS = "all-lemmas-fsexp-transformations"
    COQ_DOCUMENTS = "coq-documents"
    LEMMAS = "lemmas"
    LEMMAS_BACKEND_SEXP_TRANSFORMATIONS = "lemmas-bsexp-transformations"
    LEMMAS_FILTERED = "lemmas-filtered"
    LEMMAS_FOREEND_SEXP_TRANSFORMATIONS = "lemmas-fsexp-transformations"
    DATA_INDEXES = "data-indexes"
    RAW_FILES = "raw-files"
    ORIGINAL_FILES = "original-files"
    DEFINITIONS = "definitions"

    def __init__(self, data_dir: Path):
        """
        :param data_dir: root directory of the dataset; created (with parents)
            if it does not exist yet.
        """
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)
        return

    def clean_path(self, rel_path: Union[str, List[str]]):
        """
        Removes whatever exists at rel_path; does nothing if the path does not
        exist.
        """
        abs_path = self.resolve(rel_path)
        if abs_path.exists():
            self.logger.info(f"Removing existing things at {abs_path}")
            IOUtils.rm(abs_path)
        # end if
        return

    @classmethod
    def is_json_format(cls, fmt: IOUtils.Format) -> bool:
        """
        Returns True iff fmt is one of the json formats (json, jsonPretty,
        jsonNoSort).
        """
        return fmt in (IOUtils.Format.json, IOUtils.Format.jsonPretty, IOUtils.Format.jsonNoSort)

    @classmethod
    def _batch_numbers(cls, batch_dir: Path) -> List[int]:
        """
        Returns the sorted batch numbers of the files in batch_dir.

        Every entry of the directory is expected to be named
        ``batch-<number>.<ext>``; any other entry makes the parsing fail.
        """
        return sorted(int(str(f.stem).split("-")[1]) for f in batch_dir.iterdir())

    def dump_data(self,
            rel_path: Union[str, List[str]],
            data: Any,
            fmt: IOUtils.Format,
            is_batched: bool = False,
            per_batch: int = 100,
            exist_ok: bool = False,
    ):
        """
        Dumps data to rel_path, as a single file or as a directory of batches.

        For json formats the data (or each batch) is passed through
        IOUtils.jsonfy before being dumped.

        :param rel_path: where to dump the data, relative to data_dir.
        :param data: the data to dump; in batched mode it must support len()
            and slicing.
        :param fmt: the format to dump with.
        :param is_batched: if True, rel_path becomes a directory with files
            ``batch-<i>.<ext>`` holding per_batch entries each; anything
            previously at rel_path is removed first.
        :param per_batch: number of entries per batch file (batched mode only).
        :param exist_ok: if False, an IOError is raised (via
            LoggingUtils.log_and_raise) when rel_path already exists.
        """
        abs_path = self.resolve(rel_path)
        if abs_path.exists() and not exist_ok:
            LoggingUtils.log_and_raise(self.logger, f"Cannot rewrite existing data at {abs_path}", IOError)
        # end if

        abs_path.parent.mkdir(parents=True, exist_ok=True)
        needs_jsonfy = self.is_json_format(fmt)
        if not is_batched:
            if needs_jsonfy:
                data = IOUtils.jsonfy(data)
            # end if
            IOUtils.dump(abs_path, data, fmt)
        else:
            # In batched mode, the data need to be slice-able and sizable
            IOUtils.rm(abs_path)
            abs_path.mkdir(parents=True)

            for batch_i in tqdm(range(math.ceil(len(data)/per_batch))):
                data_batch = data[per_batch*batch_i : per_batch*(batch_i+1)]
                if needs_jsonfy:
                    data_batch = IOUtils.jsonfy(data_batch)
                # end if
                IOUtils.dump(abs_path/f"batch-{batch_i}.{fmt.get_extension()}", data_batch, fmt)
            # end for
        # end if
        return

    def load_data(self,
            rel_path: Union[str, List[str]],
            fmt: IOUtils.Format,
            is_batched: bool = False,
            clz = None,
    ) -> Any:
        """
        Loads the data at rel_path, either from a single file or from a
        directory of batches.

        :param rel_path: where to load the data from, relative to data_dir.
        :param fmt: the format the data was dumped with.
        :param is_batched: if True, rel_path is a directory of
            ``batch-<i>.<ext>`` files, which are loaded in increasing batch
            number and concatenated into one list.
        :param clz: for json formats, the loaded data (or each batch) is passed
            through IOUtils.dejsonfy with this class; if None, a warning is
            logged and the data is returned as loaded.
        :return: the loaded data (a list in batched mode).
        :raises IOError: (via LoggingUtils.log_and_raise) if rel_path does not
            exist.
        """
        is_json = self.is_json_format(fmt)
        if is_json and clz is None:
            self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz (at {traceback.format_stack()})")
        # end if

        abs_path = self.resolve(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        needs_dejsonfy = is_json and clz is not None
        if not is_batched:
            data = IOUtils.load(abs_path, fmt)
            if needs_dejsonfy:
                data = IOUtils.dejsonfy(data, clz)
            # end if
            return data
        else:
            data = list()
            for batch_number in tqdm(self._batch_numbers(abs_path)):
                batch_file = abs_path / f"batch-{batch_number}.{fmt.get_extension()}"
                data_batch = IOUtils.load(batch_file, fmt)
                if needs_dejsonfy:
                    data_batch = IOUtils.dejsonfy(data_batch, clz)
                # end if
                data.extend(data_batch)
            # end for
            return data
        # end if

    def iter_batched_data(self,
            rel_path: Union[str, List[str]],
            fmt: IOUtils.Format,
            clz = None,
    ) -> Iterator:
        """
        Iterates over the entries of batched data, loading one batch file at a
        time, in increasing batch number.

        This is a generator: the warning about a missing clz and the check
        that rel_path exists only happen once iteration starts, not when this
        method is called.

        :param rel_path: directory of ``batch-<i>.<ext>`` files, relative to
            data_dir.
        :param fmt: the format the data was dumped with.
        :param clz: for json formats, each entry is passed through
            IOUtils.dejsonfy with this class; if None, a warning is logged and
            entries are yielded as loaded.
        :raises IOError: (via LoggingUtils.log_and_raise) if rel_path does not
            exist.
        """
        is_json = self.is_json_format(fmt)
        if is_json and clz is None:
            self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz")
        # end if

        abs_path = self.resolve(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        needs_dejsonfy = is_json and clz is not None
        for batch_number in self._batch_numbers(abs_path):
            batch_file = abs_path / f"batch-{batch_number}.{fmt.get_extension()}"
            for data_entry in IOUtils.load(batch_file, fmt):
                if needs_dejsonfy:
                    data_entry = IOUtils.dejsonfy(data_entry, clz)
                # end if
                yield data_entry
            # end for
        # end for

    def dump_ckpt(self, rel_path: Union[str, List[str]], obj: Any, ckpt_id: int,
            dump_func: Callable[[Any, str], None],
            ckpt_keep_max: int = 5,
    ) -> None:
        """
        Dumps a checkpoint into the directory rel_path, then removes the oldest
        checkpoints.

        :param rel_path: checkpoint directory, relative to data_dir; created if
            missing.  Every entry in it must be named by an integer id.
        :param obj: the object to checkpoint.
        :param ckpt_id: id of this checkpoint; used as the file name.
        :param dump_func: called as dump_func(obj, path) to write the
            checkpoint.
        :param ckpt_keep_max: only the checkpoints with the ckpt_keep_max
            largest ids are kept; -1 disables the cleanup.  Note that 0 does
            not remove anything either, because the slice ``[:-0]`` is empty.
        """
        abs_path = self.resolve(rel_path)
        abs_path.mkdir(parents=True, exist_ok=True)

        ckpt_file_name = str(abs_path / str(ckpt_id))
        dump_func(obj, ckpt_file_name)

        # Remove older checkpoints
        if ckpt_keep_max != -1:
            existing_ids = sorted(int(str(f.name)) for f in abs_path.iterdir())
            # A separate loop variable keeps the ckpt_id parameter intact
            for stale_id in existing_ids[:-ckpt_keep_max]:
                IOUtils.rm(abs_path / str(stale_id))
            # end for
        # end if
        return

    def load_ckpt(self, rel_path: Union[str, List[str]],
            load_func: Callable[[str], Any],
            ckpt_id: Optional[int] = None,
    ) -> Any:
        """
        Loads a checkpoint from the directory rel_path.

        :param rel_path: checkpoint directory, relative to data_dir.
        :param load_func: called as load_func(path) to read the checkpoint.
        :param ckpt_id: id of the checkpoint to load; if None, the checkpoint
            with the largest id in the directory is loaded.
        :return: whatever load_func returns.
        :raises IOError: (via LoggingUtils.log_and_raise) if rel_path does not
            exist, or if ckpt_id is None and the directory holds no
            checkpoint.
        """
        abs_path = self.resolve(rel_path)
        if not abs_path.exists():
            LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
        # end if

        if ckpt_id is None:
            # Find the latest ckpt
            ckpt_ids = [int(str(f.name)) for f in abs_path.iterdir()]
            if not ckpt_ids:
                # Report an empty directory explicitly instead of letting
                # max() fail on an empty sequence
                LoggingUtils.log_and_raise(self.logger, f"Cannot find any checkpoint at {abs_path}", IOError)
            # end if
            ckpt_id = max(ckpt_ids)
            self.logger.info(f"Loading the latest checkpoint {ckpt_id} at {abs_path}")
        # end if

        return load_func(str(abs_path / str(ckpt_id)))

    def resolve(self, rel_path: Union[str, List[str]]) -> Path:
        """
        Returns the path under data_dir that rel_path refers to.
        """
        return self.data_dir / self.assemble_rel_path(rel_path)

    @classmethod
    def assemble_rel_path(cls, rel_path: Union[str, List[str]]) -> str:
        """
        Returns rel_path as a single string: a string is returned unchanged,
        anything else is joined with "/".
        """
        if not isinstance(rel_path, str):
            rel_path = "/".join(rel_path)
        # end if
        return rel_path