def __init__(self, paths: dict = None):
    paths = paths if paths is not None else {
        'train': 'train.tsv',
        'dev': 'dev.tsv',
        'test': 'test.tsv',
    }
    MatchingLoader.__init__(self, paths=paths)
    CSVLoader.__init__(self, sep='\t',
                       headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
def _load(self, path):
    ds = JsonLoader._load(self, path)

    parentheses_table = str.maketrans({'(': None, ')': None})

    ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
             new_field_name=Const.INPUTS(0))
    ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
             new_field_name=Const.INPUTS(1))
    # samples labelled '-' have no annotator consensus; drop them
    ds.drop(lambda x: x[Const.TARGET] == '-')
    return ds
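# A minimal illustration (made-up parse string) of the translate/strip/split chain used in
# _load() above: the binary-parse brackets are deleted and plain whitespace tokenization is
# applied to what remains.
_parse = '( ( A man ) ( eats ) )'
_table = str.maketrans({'(': None, ')': None})
assert _parse.translate(_table).strip().split() == ['A', 'man', 'eats']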
def __init__(self, paths: dict = None):
    fields = {
        'sentence1_binary_parse': Const.INPUTS(0),
        'sentence2_binary_parse': Const.INPUTS(1),
        'gold_label': Const.TARGET,
    }
    paths = paths if paths is not None else {
        'train': 'snli_1.0_train.jsonl',
        'dev': 'snli_1.0_dev.jsonl',
        'test': 'snli_1.0_test.jsonl',
    }
    MatchingLoader.__init__(self, paths=paths)
    JsonLoader.__init__(self, fields=fields)
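# Schematic of the SNLI jsonl records the field mapping above expects (values shortened;
# the real files carry additional keys such as the plain 'sentence1'/'sentence2' strings):
# {"gold_label": "entailment",
#  "sentence1_binary_parse": "( ( A man ) ( eats ) )",
#  "sentence2_binary_parse": "( Someone ( is eating ) )", ...}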
def __init__(self, paths: dict = None):
    paths = paths if paths is not None else {
        'train': 'train.tsv',
        'dev': 'dev.tsv',
        'test': 'test.tsv',  # the test set has no labels
    }
    MatchingLoader.__init__(self, paths=paths)
    self.fields = {
        'question': Const.INPUTS(0),
        'sentence': Const.INPUTS(1),
        'label': Const.TARGET,
    }
    CSVLoader.__init__(self, sep='\t')
def __init__(self, paths: dict = None):
    paths = paths if paths is not None else {
        'train': 'train.tsv',
        'dev_matched': 'dev_matched.tsv',
        'dev_mismatched': 'dev_mismatched.tsv',
        'test_matched': 'test_matched.tsv',
        'test_mismatched': 'test_mismatched.tsv',
        # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
        # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
        # test_0.9_matched and test_0.9_mismatched come from MNLI version 0.9 (data source: Kaggle)
    }
    MatchingLoader.__init__(self, paths=paths)
    CSVLoader.__init__(self, sep='\t')
    self.fields = {
        'sentence1_binary_parse': Const.INPUTS(0),
        'sentence2_binary_parse': Const.INPUTS(1),
        'gold_label': Const.TARGET,
    }
def _load(self, path):
    ds = CSVLoader._load(self, path)

    for k, v in self.fields.items():
        if k in ds.get_field_names():
            ds.rename_field(k, v)
    if Const.TARGET in ds.get_field_names():
        # splits shipped without real labels (e.g. test sets) mark every label as 'hidden';
        # drop the target field entirely in that case
        if ds[0][Const.TARGET] == 'hidden':
            ds.delete_field(Const.TARGET)

    parentheses_table = str.maketrans({'(': None, ')': None})

    ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
             new_field_name=Const.INPUTS(0))
    ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
             new_field_name=Const.INPUTS(1))
    if Const.TARGET in ds.get_field_names():
        ds.drop(lambda x: x[Const.TARGET] == '-')
    return ds
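# Hedged usage sketch: assuming the TSV loaders above expose the process() method defined
# further down (MatchingLoader), loading and preprocessing all splits of an MNLI-style folder
# could look like the commented lines below; the loader class name and directory are placeholders.
# loader = MNLILoader()
# data_info = loader.process(paths='/path/to/MNLI', to_lower=True,
#                            seq_len_type='seq_len', get_index=True)
# for name, ds in data_info.datasets.items():
#     print(name, len(ds))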
def prepare_nli_data(self):
    index = 'index'
    ds = DataSet({index: list(range(N_SAMPLES))})
    ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
                   field_name=index, new_field_name=C.INPUTS(0), is_input=True)
    ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE),
                   field_name=index, new_field_name=C.INPUTS(1), is_input=True)
    ds.apply_field(lambda x: randrange(NUM_CLS),
                   field_name=index, new_field_name=C.TARGET, is_target=True)
    ds.apply_field(len, C.INPUTS(0), C.INPUT_LENS(0), is_input=True, is_target=True)
    ds.apply_field(len, C.INPUTS(1), C.INPUT_LENS(1), is_input=True, is_target=True)
    ds.set_input(C.INPUTS(0), C.INPUTS(1))
    ds.set_target(C.TARGET)
    return ds
def cache():
    bundle = CoReferencePipe(config).process_from_file({
        'train': config.train_path,
        'dev': config.dev_path,
        'test': config.test_path,
    })
    return bundle

data_bundle = cache()
print(data_bundle)
model = Model(data_bundle.get_vocab(Const.INPUTS(0)), config)
print(model)

loss = SoftmaxLoss()
metric = CRMetric()
optim = Adam(model.parameters(), lr=config.lr)
lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)

trainer = Trainer(
    model=model,
    train_data=data_bundle.datasets["train"],
    dev_data=data_bundle.datasets["dev"],
    loss=loss,
def process(self, paths):

    def truncate_articles(instance, max_nsents=self.max_nsents, max_ntokens=self.max_ntokens):
        article = [' '.join(sent.lower().split()[:max_ntokens]) for sent in instance['article']]
        return article[:max_nsents]

    def truncate_labels(instance):
        label = list(filter(lambda x: x < len(instance['article']), instance['label']))
        return label

    def bert_tokenize(instance, tokenizer, max_len, pad_value):
        article = instance['article']
        article = ' [SEP] [CLS] '.join(article)
        word_pieces = tokenizer.tokenize(article)[:(max_len - 2)]
        word_pieces = ['[CLS]'] + word_pieces + ['[SEP]']
        token_ids = tokenizer.convert_tokens_to_ids(word_pieces)
        while len(token_ids) < max_len:
            token_ids.append(pad_value)
        assert len(token_ids) == max_len
        return token_ids

    def get_seg_id(instance, max_len, sep_id):
        _segs = [-1] + [i for i, idx in enumerate(instance['article']) if idx == sep_id]
        segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
        segment_id = []
        for i, length in enumerate(segs):
            if i % 2 == 0:
                segment_id += length * [0]
            else:
                segment_id += length * [1]
        while len(segment_id) < max_len:
            segment_id.append(0)
        return segment_id

    def get_cls_id(instance, cls_id):
        classification_id = [i for i, idx in enumerate(instance['article']) if idx == cls_id]
        return classification_id

    def get_labels(instance):
        labels = [0] * len(instance['cls_id'])
        label_idx = list(filter(lambda x: x < len(instance['cls_id']), instance['label']))
        for idx in label_idx:
            labels[idx] = 1
        return labels

    datasets = {}
    for name in paths:
        datasets[name] = self._load(paths[name])

        # remove empty samples
        datasets[name].drop(lambda ins: len(ins['article']) == 0 or len(ins['label']) == 0)

        # truncate articles
        datasets[name].apply(lambda ins: truncate_articles(ins, self.max_nsents, self.max_ntokens),
                             new_field_name='article')

        # truncate labels
        datasets[name].apply(truncate_labels, new_field_name='label')

        # tokenize and convert tokens to ids
        datasets[name].apply(lambda ins: bert_tokenize(ins, self.tokenizer, self.max_len, self.pad_id),
                             new_field_name='article')

        # get segment ids
        datasets[name].apply(lambda ins: get_seg_id(ins, self.max_len, self.sep_id),
                             new_field_name='segment_id')

        # get classification ids
        datasets[name].apply(lambda ins: get_cls_id(ins, self.cls_id),
                             new_field_name='cls_id')

        # get labels
        datasets[name].apply(get_labels, new_field_name='label')

        # rename fields
        datasets[name].rename_field('article', Const.INPUTS(0))
        datasets[name].rename_field('segment_id', Const.INPUTS(1))
        datasets[name].rename_field('cls_id', Const.INPUTS(2))
        datasets[name].rename_field('label', Const.TARGET)

        # set input and target
        datasets[name].set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUTS(2))
        datasets[name].set_target(Const.TARGET)

        # set padding value (the article field has already been renamed to Const.INPUTS(0))
        datasets[name].set_pad_val(Const.INPUTS(0), 0)

    return DataInfo(datasets=datasets)
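# A small, self-contained walk-through (made-up token ids) of the segment/cls-id scheme used by
# process() above: sentences are joined as '[CLS] s1 [SEP] [CLS] s2 [SEP] ...', segment ids then
# alternate 0/1 per sentence, and cls_id records the position of every [CLS] token.
def _demo_seg_and_cls_ids():
    CLS, SEP = 101, 102                      # stand-in ids for [CLS] / [SEP]
    article = [CLS, 7, 8, SEP, CLS, 9, SEP]  # two sentences, already converted to ids

    # same logic as get_seg_id(): count tokens between consecutive [SEP]s and alternate 0/1
    _segs = [-1] + [i for i, idx in enumerate(article) if idx == SEP]
    segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
    segment_id = []
    for i, length in enumerate(segs):
        segment_id += length * [i % 2]
    assert segment_id == [0, 0, 0, 0, 1, 1, 1]

    # same logic as get_cls_id(): one [CLS] position per sentence
    cls_id = [i for i, idx in enumerate(article) if idx == CLS]
    assert cls_id == [0, 4]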
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None,
            to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None,
            cut_text: int = None, get_index=True, auto_pad_length: int = None,
            auto_pad_token: str = '<pad>', set_input: Union[list, str, bool] = True,
            set_target: Union[list, str, bool] = True,
            concat: Union[str, list, bool] = None, ) -> DataBundle:
    """
    :param paths: str or Dict[str, str]. If a str, it is either the folder containing the data sets
        or the full path of a single data file. For a folder, the data set names and file names are
        looked up in self.paths; for a Dict, it maps data set names (e.g. train, dev, test) to their
        full file paths.
    :param str dataset_name: if paths is the full path of a single data file, dataset_name names
        that data set; it defaults to train when not given.
    :param bool to_lower: whether to lower-case the text automatically. Defaults to False.
    :param str seq_len_type: which kind of sequence-length information to provide. ``seq_len``: a
        single number as the sentence length; ``mask``: a 0/1 mask matrix as the sentence length;
        ``bert``: a segment_type_id (0 for the first sentence, 1 for the second) plus a 0/1
        attention-mask matrix. Defaults to None, i.e. no seq_len is provided.
    :param str bert_tokenizer: path of the folder holding the vocabulary used by the BERT tokenizer.
    :param int cut_text: truncate content longer than cut_text. Defaults to None, i.e. no truncation.
    :param bool get_index: whether to convert the text to indices using the vocabulary.
    :param int auto_pad_length: automatically pad the text to this length (longer text is
        truncated). Defaults to no automatic padding.
    :param str auto_pad_token: the token used for automatic padding.
    :param set_input: if True, fields whose names contain Const.INPUT are automatically set as
        input; if False, no field is set as input. If a str or List[str] is passed, only the
        corresponding fields are set as input and the remaining fields are left untouched.
        Defaults to True.
    :param set_target: controls which fields may be set as target; usage is the same as set_input.
        Defaults to True.
    :param concat: whether to concatenate the two sentences. If False, they are not concatenated.
        If True, a <sep> token is inserted between the two sentences. If a list of length 4 is
        passed, its items are the markers inserted before the first sentence, after the first
        sentence, before the second sentence and after the second sentence, respectively. If the
        string ``bert`` is passed, BERT-style concatenation is used, which is equivalent to
        ['[CLS]', '[SEP]', '', '[SEP]'].
    :return:
    """
    if isinstance(set_input, str):
        set_input = [set_input]
    if isinstance(set_target, str):
        set_target = [set_target]
    if isinstance(set_input, bool):
        auto_set_input = set_input
    else:
        auto_set_input = False
    if isinstance(set_target, bool):
        auto_set_target = set_target
    else:
        auto_set_target = False
    if isinstance(paths, str):
        if os.path.isdir(paths):
            path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
        else:
            path = {dataset_name if dataset_name is not None else 'train': paths}
    else:
        path = paths

    data_info = DataBundle()
    for data_name in path.keys():
        data_info.datasets[data_name] = self._load(path[data_name])

    for data_name, data_set in data_info.datasets.items():
        if auto_set_input:
            data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
        if auto_set_target:
            if Const.TARGET in data_set.get_field_names():
                data_set.set_target(Const.TARGET)

    if to_lower:
        for data_name, data_set in data_info.datasets.items():
            data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]],
                           new_field_name=Const.INPUTS(0),
                           is_input=auto_set_input)
            data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]],
                           new_field_name=Const.INPUTS(1),
                           is_input=auto_set_input)

    if bert_tokenizer is not None:
        if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
            PRETRAIN_URL = _get_base_url('bert')
            model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
            model_url = PRETRAIN_URL + model_name
            model_dir = cached_path(model_url)  # check whether it is cached; download if not
        elif os.path.isdir(bert_tokenizer):
            model_dir = bert_tokenizer
        else:
            raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

        words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
        with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
            lines = f.readlines()
        lines = [line.strip() for line in lines]
        words_vocab.add_word_lst(lines)
        words_vocab.build_vocab()

        tokenizer = BertTokenizer.from_pretrained(model_dir)

        for data_name, data_set in data_info.datasets.items():
            for fields in data_set.get_field_names():
                if Const.INPUT in fields:
                    data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])),
                                   new_field_name=fields,
                                   is_input=auto_set_input)

    if isinstance(concat, bool):
        concat = 'default' if concat else None
    if concat is not None:
        if isinstance(concat, str):
            CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
                          'default': ['', '<sep>', '', '']}
            if concat.lower() in CONCAT_MAP:
                concat = CONCAT_MAP[concat]
            else:
                concat = 4 * [concat]
        assert len(concat) == 4, \
            f'Please pass a list of 4 symbols: the markers placed at the beginning of the first ' \
            f'sentence, the end of the first sentence, the beginning of the second sentence, ' \
            f'and the end of the second sentence. Your input is {concat}'

        for data_name, data_set in data_info.datasets.items():
            data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] +
                                     [concat[2]] + x[Const.INPUTS(1)] + [concat[3]],
                           new_field_name=Const.INPUT)
            data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0],
                           new_field_name=Const.INPUT,
                           is_input=auto_set_input)

    if seq_len_type is not None:
        if seq_len_type == 'seq_len':
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: len(x[fields]),
                                       new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                       is_input=auto_set_input)
        elif seq_len_type == 'mask':
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: [1] * len(x[fields]),
                                       new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                       is_input=auto_set_input)
        elif seq_len_type == 'bert':
            for data_name, data_set in data_info.datasets.items():
                if Const.INPUT not in data_set.get_field_names():
                    raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
                                   f'got {data_set.get_field_names()}')
                data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) +
                                         [1] * (len(x[Const.INPUTS(1)]) + 1),
                               new_field_name=Const.INPUT_LENS(0),
                               is_input=auto_set_input)
                data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                               new_field_name=Const.INPUT_LENS(1),
                               is_input=auto_set_input)

    if auto_pad_length is not None:
        cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)

    if cut_text is not None:
        for data_name, data_set in data_info.datasets.items():
            for fields in data_set.get_field_names():
                if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
                    data_set.apply(lambda x: x[fields][:cut_text],
                                   new_field_name=fields,
                                   is_input=auto_set_input)

    data_set_list = [d for n, d in data_info.datasets.items()]
    assert len(data_set_list) > 0, f'There are NO data sets in data info!'
    if bert_tokenizer is None:
        words_vocab = Vocabulary(padding=auto_pad_token)
        words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                               field_name=[n for n in data_set_list[0].get_field_names()
                                                           if (Const.INPUT in n)],
                                               no_create_entry_dataset=[d for n, d in data_info.datasets.items()
                                                                        if 'train' not in n])
    target_vocab = Vocabulary(padding=None, unknown=None)
    target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                             field_name=Const.TARGET)
    data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}

    if get_index:
        for data_name, data_set in data_info.datasets.items():
            for fields in data_set.get_field_names():
                if Const.INPUT in fields:
                    data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]],
                                   new_field_name=fields,
                                   is_input=auto_set_input)

            if Const.TARGET in data_set.get_field_names():
                data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]),
                               new_field_name=Const.TARGET,
                               is_input=auto_set_input,
                               is_target=auto_set_target)

    if auto_pad_length is not None:
        if seq_len_type == 'seq_len':
            raise RuntimeError(f'the sequence will be padded to length {auto_pad_length}, '
                               f'so seq_len_type cannot be `{seq_len_type}`!')
        for data_name, data_set in data_info.datasets.items():
            for fields in data_set.get_field_names():
                if Const.INPUT in fields:
                    data_set.apply(lambda x: x[fields] +
                                             [words_vocab.to_index(words_vocab.padding)] *
                                             (auto_pad_length - len(x[fields])),
                                   new_field_name=fields,
                                   is_input=auto_set_input)
                elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
                    data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
                                   new_field_name=fields,
                                   is_input=auto_set_input)

    for data_name, data_set in data_info.datasets.items():
        if isinstance(set_input, list):
            data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
        if isinstance(set_target, list):
            data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])

    return data_info
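# Hedged usage sketch (loader name and paths are placeholders): a SNLI-style loader built on
# this process() might be driven as in the commented lines below, yielding lower-cased,
# indexed token fields plus a per-sentence seq_len field for each split.
# data_bundle = SNLILoader().process(
#     paths='/path/to/snli_1.0',   # folder containing the *.jsonl splits
#     to_lower=True,
#     seq_len_type='seq_len',
#     get_index=True,
# )
# print(data_bundle.datasets['train'][0])
# print(data_bundle.vocabs[Const.INPUT])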