コード例 #1
0
    def __init__(self, config, *args, **kwargs):
        """Create the answer vocabulary, special-token ids and preprocessor.

        Args:
            config: Processor config. ``vocab_file`` is mandatory;
                ``preprocessor`` and ``num_answers`` are optional.
            *args, **kwargs: Handed unchanged to ``VocabDict``.

        Raises:
            AttributeError: When ``config`` has no ``vocab_file`` attribute.
            ValueError: When building the preprocessor yields ``None``.
        """
        if not hasattr(config, "vocab_file"):
            raise AttributeError(
                "'vocab_file' argument required, but not "
                "present in AnswerProcessor's config"
            )

        vocab = VocabDict(config.vocab_file, *args, **kwargs)
        self.answer_vocab = vocab

        # Special-token ids, looked up in the loaded vocabulary.
        self.PAD_IDX = vocab.word2idx("<pad>")
        self.BOS_IDX = vocab.word2idx("<s>")
        eos_idx = vocab.word2idx("</s>")
        self.UNK_IDX = vocab.UNK_INDEX
        # An EOS id equal to UNK means "</s>" is absent from the vocab; use
        # len(vocab), an id no real token can have, so EOS is never matched.
        self.EOS_IDX = len(vocab) if eos_idx == self.UNK_IDX else eos_idx

        self.preprocessor = None
        if hasattr(config, "preprocessor"):
            built = Processor(config.preprocessor)
            if built is None:
                raise ValueError(
                    f"No processor named {config.preprocessor} is defined."
                )
            self.preprocessor = built

        if not hasattr(config, "num_answers"):
            fallback = self.DEFAULT_NUM_ANSWERS
            self.num_answers = fallback
            warnings.warn(
                "'num_answers' not defined in the config. "
                f"Setting to default of {fallback}"
            )
        else:
            self.num_answers = config.num_answers
コード例 #2
0
ファイル: processors.py プロジェクト: zhouweixin/mmf
    def __init__(self, config, *args, **kwargs):
        """Set up the answer vocabulary and the optional answer preprocessor.

        Args:
            config: Processor config. ``vocab_file`` is required;
                ``preprocessor`` and ``num_answers`` may be omitted.
            *args, **kwargs: Passed through unchanged to ``VocabDict``.

        Raises:
            AttributeError: When ``config`` lacks ``vocab_file``.
            ValueError: When constructing the preprocessor yields ``None``.
        """
        # Fetched from the global registry before any config validation.
        self.writer = registry.get("writer")

        vocab_file_present = hasattr(config, "vocab_file")
        if not vocab_file_present:
            raise AttributeError(
                "'vocab_file' argument required, but not "
                "present in AnswerProcessor's config"
            )
        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)

        self.preprocessor = None
        if hasattr(config, "preprocessor"):
            candidate = Processor(config.preprocessor)
            if candidate is None:
                raise ValueError(
                    f"No processor named {config.preprocessor} is defined."
                )
            self.preprocessor = candidate

        if hasattr(config, "num_answers"):
            self.num_answers = config.num_answers
            return

        # Fall back to the class-level default and tell the user about it.
        default_count = self.DEFAULT_NUM_ANSWERS
        self.num_answers = default_count
        warnings.warn(
            "'num_answers' not defined in the config. "
            f"Setting to default of {default_count}"
        )
コード例 #3
0
    def __init__(self, config, *args, **kwargs):
        """Build the answer vocabulary, special-token ids and answer settings.

        Args:
            config: Processor config. Read here: ``vocab_file``,
                ``preprocessor``, ``num_answers``, ``max_length`` and
                ``max_copy_steps``.
            *args, **kwargs: Forwarded to the parent initializer and to
                ``VocabDict``.

        Raises:
            AssertionError: If ``<pad>``, ``<s>`` or ``</s>`` resolves to the
                unknown-token index, if ``<pad>`` is not at index 0, if
                building the preprocessor yields ``None``, or if
                ``max_copy_steps`` is not >= 1.
        """
        super().__init__(config, *args, **kwargs)

        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)
        self.PAD_IDX = self.answer_vocab.word2idx("<pad>")
        self.BOS_IDX = self.answer_vocab.word2idx("<s>")
        self.EOS_IDX = self.answer_vocab.word2idx("</s>")
        self.UNK_IDX = self.answer_vocab.UNK_INDEX

        # Make sure PAD_IDX, BOS_IDX and EOS_IDX are valid (not <unk>).
        # These are explicit checks rather than `assert` statements so they
        # still run under `python -O`; AssertionError is kept so the
        # exception type seen by callers is unchanged.
        special_tokens = (
            ("<pad>", self.PAD_IDX),
            ("<s>", self.BOS_IDX),
            ("</s>", self.EOS_IDX),
        )
        for token, idx in special_tokens:
            if idx == self.UNK_IDX:
                raise AssertionError(
                    f"Special token {token!r} is missing from answer "
                    f"vocabulary {config.vocab_file!r}"
                )
        if self.PAD_IDX != 0:
            raise AssertionError(
                f"'<pad>' must be at index 0 of the answer vocabulary, "
                f"got index {self.PAD_IDX}"
            )

        self.answer_preprocessor = Processor(config.preprocessor)
        if self.answer_preprocessor is None:
            raise AssertionError(
                f"Could not build answer preprocessor from "
                f"{config.preprocessor}"
            )

        self.num_answers = config.num_answers
        self.max_length = config.max_length
        self.max_copy_steps = config.max_copy_steps
        # Written as `not (... >= 1)` to reject exactly what the former
        # `assert self.max_copy_steps >= 1` rejected (including NaN).
        if not (self.max_copy_steps >= 1):
            raise AssertionError(
                f"max_copy_steps must be >= 1, got {self.max_copy_steps}"
            )

        # Initialised to False here; not otherwise read in this method.
        self.match_answer_to_unk = False
コード例 #4
0
ファイル: processors.py プロジェクト: weexiaolong/mmf
 def __init__(self, config: MultiClassFromFileConfig, *args, **kwargs):
     self.label_vocab = VocabDict(config.vocab_file, *args, **kwargs)