def setup_task(cls, args: argparse.Namespace, **kwargs):
    """Load the dictionaries and the label schema, cross-check them, build the task.

    Reads ``dict.txt`` (input vocabulary) and ``dict_term.txt`` (label
    vocabulary, loaded with ``add_mask=False``) from ``args.data`` and the
    label schema from ``args.term_schema``. The label vocabulary and the
    schema are then validated against each other in both directions.

    Args:
        args: Parsed options; ``args.data`` and ``args.term_schema`` are
            read here and ``args`` is forwarded to the helpers and the task.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        A ``MultiLabelTokenClassificationTask`` built from the loaded
        dictionaries, the word-beginning information and the label schema.

    Raises:
        AssertionError: If ``dict_term.txt`` contains a label that is not in
            the schema (special symbols, ``<mask>`` and ``madeupword*``
            padding excepted), or if the schema contains a label that is
            missing from ``dict_term.txt``.
    """
    data_dict = cls.load_dictionary(args, os.path.join(args.data, "dict.txt"))
    logger.info("[input] dictionary: {} types".format(len(data_dict)))

    is_word_initial = get_word_beginnings(args, data_dict)
    term_dict = cls.load_dictionary(
        args, os.path.join(args.data, "dict_term.txt"), add_mask=False
    )

    # Only the schema is needed; the first element of the returned pair is unused.
    _, label_schema = cls.load_label_dictionary(args, args.term_schema)
    logger.info("[label] dictionary: {} types".format(len(term_dict)))

    # Every real label in the label vocabulary must be known to the schema.
    # The failures below are raised explicitly (rather than via `assert`) so
    # that the validation still runs under `python -O`; AssertionError is kept
    # so callers observe the same exception type as before.
    for idx, lbl in enumerate(term_dict.symbols):
        # NOTE(review): `>` also exempts the symbol at index `nspecial`
        # itself -- confirm whether `>=` was intended.
        if (
            lbl not in label_schema.labels
            and idx > term_dict.nspecial  # ignore bos, eos, etc
            and lbl != "<mask>"
            and not lbl.startswith("madeupword")  # ignore vocabulary padding
        ):
            raise AssertionError(
                "Unexpected POS label item in term_dict.txt: {}".format(lbl)
            )

    # Conversely, every schema label must be present in the label vocabulary.
    seen = set(term_dict.symbols)
    for lbl in label_schema.labels:
        if lbl not in seen:
            raise AssertionError(
                "Unexpected POS label item in label_schema {}".format(lbl)
            )

    return MultiLabelTokenClassificationTask(
        args, data_dict, term_dict, is_word_initial, label_schema
    )
def setup_task(cls, args: argparse.Namespace, **kwargs):
    """Load the input and label dictionaries and build the multi-class task.

    Args:
        args: Parsed options; ``args.data`` is the directory holding
            ``dict.txt`` and ``dict_term.txt``.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        A ``MultiClassTokenClassificationTask`` built from both dictionaries
        and the word-beginning information of the input dictionary.
    """
    input_vocab_path = os.path.join(args.data, "dict.txt")
    label_vocab_path = os.path.join(args.data, "dict_term.txt")

    # Input vocabulary: "<mask>" is added on top of what the file provides.
    data_dict = Dictionary.load(input_vocab_path)
    data_dict.add_symbol("<mask>")
    logger.info("[input] dictionary: {} types".format(len(data_dict)))

    word_initial = get_word_beginnings(args, data_dict)

    # Label vocabulary: used exactly as stored on disk.
    term_dict = Dictionary.load(label_vocab_path)
    logger.info("[label] dictionary: {} types".format(len(term_dict)))

    return MultiClassTokenClassificationTask(
        args, data_dict, term_dict, word_initial
    )
def __init__(
    self,
    args: argparse.Namespace,
    src_dict: Dictionary,
    tgt_dict: Dictionary,
):
    """Initialise the task with glossary support.

    After the parent initialiser runs, this reads the glossary configuration
    from ``args``, makes sure the ``<c>`` and ``<sep>`` symbols exist in both
    dictionaries, checks that source and target dictionaries are equal,
    computes the word-beginning information, applies the positional
    monkey patch keyed on the ``<sep>`` index, and builds the BPE encoder.

    Args:
        args: Parsed options, forwarded to the parent initialiser, the
            glossary config, ``get_word_beginnings`` and the BPE builder.
        src_dict: Source dictionary, forwarded to the parent initialiser.
        tgt_dict: Target dictionary, forwarded to the parent initialiser.

    Raises:
        AssertionError: If the target dictionary differs from the source
            dictionary.
        ValueError: If ``get_word_beginnings`` returns ``None``.
    """
    super().__init__(args, src_dict, tgt_dict)  # type: ignore

    config = GlossaryTaskConfig.from_args(args)
    if config.enabled:
        logger.info("Glossary is ENABLED")
        logger.info(f"Glossary config: {config}")
    else:
        logger.info("Glossary is DISABLED")
    self.glossary_task_config = config

    # Ensure that <sep> and <c> are defined in the dictionaries.
    ensure_symbols_are_present(
        self.source_dictionary,
        ["<c>", "<sep>"],
        self.glossary_task_config.ok_to_increase_dict_size,
    )
    ensure_symbols_are_present(
        self.target_dictionary,
        ["<c>", "<sep>"],
        self.glossary_task_config.ok_to_increase_dict_size,
    )

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; AssertionError is kept so callers observe the same
    # exception type. The message is assembled from adjacent literals so no
    # source indentation leaks into the text.
    if not (self.target_dictionary == self.source_dictionary):
        raise AssertionError(
            "The target dictionary must be the same as the source dictionary, "
            "because we use is_word_initial based on a single dictionary "
            "and use it for both src and tgt."
        )

    is_word_initial = get_word_beginnings(args, self.source_dictionary)
    if is_word_initial is None:
        raise ValueError("The is_word_initial function is None.")
    self.is_word_initial = is_word_initial

    apply_monkey_patch_for_make_positions(
        positional_marker_symbol_idx=self.source_dictionary.index("<sep>"),
        positional_idx_restart_offset=(
            self.glossary_task_config.constraint_positional_start_idx
        ),
    )
    self.bpe = encoders.build_bpe(args)
def __init__(self, *args, **kwargs):
    """Run the parent initialiser, then store the word-beginning lookup.

    All positional and keyword arguments are forwarded unchanged to the
    parent class. ``self.args`` and ``self.task`` are read afterwards
    (presumably set by the parent initialiser -- confirm against the base
    class) to compute ``self.word_start_dict``.
    """
    super().__init__(*args, **kwargs)
    task_args = self.args
    vocabulary = self.task.dictionary
    self.word_start_dict = get_word_beginnings(task_args, vocabulary)