def train(self):
    num_ex = self.config_params['num_valid_to_show']

    if num_ex > 0:
        print('Showing examples')
        preproc = self.config_params['preproc']
        show_ex_fn = preproc['show_ex']
        rlut1 = baseline.revlut(self.feat2index1['word'])
        rlut2 = baseline.revlut(self.feat2index2['word'])
        self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                 self.valid_data, rlut1, rlut2,
                                                                                 self.embeddings2['word'],
                                                                                 preproc['mxlen'], False, 0,
                                                                                 num_ex, reverse=False)
    super(EncoderDecoderTask, self).train()
def train(self, checkpoint=None):
    num_ex = self.config_params['num_valid_to_show']

    rlut1 = baseline.revlut(self.feat2src[self.primary_key])
    rlut2 = baseline.revlut(self.feat2tgt)
    if num_ex > 0:
        logger.info('Showing examples')
        preproc = self.config_params.get('preproc', {})
        show_ex_fn = preproc['show_ex']
        self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                 self.valid_data, rlut1, rlut2,
                                                                                 self.feat2tgt,
                                                                                 preproc['mxlen'], False, 0,
                                                                                 num_ex, reverse=False)
    self.config_params['train']['tgt_rlut'] = rlut2
    super(EncoderDecoderTask, self).train(checkpoint)
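# Illustrative sketch, not from the original source: both train() overrides above
# rely on baseline.revlut to invert a vocabulary ({token: index}) into an
# index-to-token table so decoded validation examples can be printed in readable
# form. The vocabulary below is a hypothetical placeholder.
_example_vocab = {'<PAD>': 0, 'hello': 1, 'world': 2}
_example_rlut = {index: token for token, index in _example_vocab.items()}
assert _example_rlut[2] == 'world'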
def run(input_files=[], input_pattern='*.txt', codes=None, vocab=None, nctx=256, fmt='json',
        fields=['x_str', 'y_str'], output=None, prefix=None, suffix=None, max_file_size=100,
        tok_on_eol="<EOS>", cased=True, mask_type="mlm", module=None, pad_y=True,
        extra_tokens=['[CLS]', '[MASK]'], world_size=1, world_offset=0, input_field='text',
        tokenizer_type=None, **kwargs):

    def parse_json_line(x):
        return json.loads(x)[input_field]

    if module:
        logger.warning("Loading custom user module %s for masking rules and tokenizers", module)
        baseline.import_user_module(module)

    get_line = lambda x: x.strip()
    if os.path.isdir(input_files):
        if '.json' in input_pattern:
            get_line = parse_json_line
        # Compute the default output location from the directory before it is replaced by the file list
        if not output:
            output = os.path.join(input_files, 'records')
        input_files = list(glob.glob(os.path.join(input_files, input_pattern)))
    else:
        if '.json' in input_files:
            get_line = parse_json_line
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    if len(input_files) < world_size:
        raise Exception(f"The number of input shards ({len(input_files)}) should be at least the world_size: {world_size}")

    logger.info('Output [%s]', output)
    transform = baseline.lowercase if not cased else lambda x: x
    vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=codes, vocab_file=vocab, mxlen=1024, extra_tokens=extra_tokens)
    lookup_indices = []
    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    tokenizer = create_tokenizer(tokenizer_type)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    # Reserve room in the context window for optional prefix/suffix tokens
    if prefix:
        nctx -= 1
        prefix = vectorizer.vocab[prefix]

    if suffix:
        nctx -= 1
        suffix = vectorizer.vocab[suffix]

    fw = create_file_writer(fmt, output, fields, max_file_size, 1000 * world_offset)
    num_samples = 0
    for i, text in enumerate(input_files):
        # Each worker only reads the shards assigned to it
        if i % world_size != world_offset:
            continue
        with TextFile(text) as rf:
            print(f"Reading from {text}...")
            for line in rf:
                to_bpe = tokenizer(get_line(line))
                if not to_bpe:
                    continue
                to_bpe += [tok_on_eol]

                ids, available = vectorizer.run(to_bpe, vectorizer.vocab)
                # Pack subword ids into fixed-size context windows of nctx tokens
                while available > 0:
                    if len(lookup_indices) == nctx:
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    needed = nctx - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += ids[:needed].tolist()
                        ids = ids[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += ids[:available].tolist()
                        available = 0

    fw.close()
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples}, os.path.join(root_dir, f_name))
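# Usage sketch (assumption, not from the original source): how run() above might be
# invoked for one worker of a sharded MLM preprocessing job. All paths and BPE
# artifacts below are hypothetical placeholders; each worker only processes the
# input files whose index satisfies index % world_size == world_offset.
def _example_mlm_preproc():
    run(input_files='/data/corpus',        # directory of *.txt files
        codes='/data/bpe/codes.30k',       # hypothetical BPE codes file
        vocab='/data/bpe/vocab.30k',       # hypothetical BPE vocab file
        nctx=256,
        fmt='json',
        output='/data/records/mlm',
        mask_type='mlm',
        world_size=4,
        world_offset=0)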
        args.output = os.path.join(args.input_files, 'records')
else:
    input_files = [args.input_files]
    if not args.output:
        args.output = f'{args.input_files}.records'

print(args.output)
transform = baseline.lowercase if not args.cased else lambda x: x
vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=args.codes, vocab_file=args.vocab, mxlen=1024)
lookup_indices = []
words = []
indices2word = baseline.revlut(vectorizer.vocab)
vocab_size = max(vectorizer.vocab.values()) + 1
nctx = args.nctx
mask_value = vectorizer.vocab['[MASK]']
prefix = suffix = None
root_dir = os.path.dirname(args.output)
if not os.path.exists(root_dir):
    os.makedirs(root_dir)

if args.prefix:
    nctx -= 1
    prefix = vectorizer.vocab[args.prefix]

if args.suffix:
    nctx -= 1
    suffix = vectorizer.vocab[args.suffix]
def run(input_files=[], input_pattern='*.txt', codes=None, vocab=None, nctx=256, fmt='json',
        fields=['x_str', 'y_str'], output=None, x_prefix=None, x_suffix=None, y_prefix=None, y_suffix=None,
        max_file_size=100, cased=True, mask_type="mlm", module=None, pad_y=True,
        extra_tokens=['[CLS]', '[MASK]'], tgt_nctx=None, world_size=1, world_offset=0,
        subword_type='bpe', **kwargs):
    timer = Timer()

    if module:
        logger.warning("Loading custom user module %s for masking rules", module)
        baseline.import_user_module(module)

    if os.path.isdir(input_files):
        import glob
        # Compute the default output location from the directory before it is replaced by the file list
        if not output:
            output = os.path.join(input_files, 'records')
        input_files = list(glob.glob(os.path.join(input_files, input_pattern)))
    else:
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    logger.info('Output [%s]', output)
    if not tgt_nctx:
        tgt_nctx = 64
    transform = baseline.lowercase if not cased else lambda x: x
    Vec1D = get_subword_vec1d(subword_type)
    vectorizer = Vec1D(transform_fn=transform, model_file=codes, vocab_file=vocab, mxlen=1024, extra_tokens=extra_tokens)

    # Resolve optional prefix/suffix tokens to their vocabulary indices
    if x_prefix:
        x_prefix = vectorizer.vocab[x_prefix]
    if x_suffix:
        x_suffix = vectorizer.vocab[x_suffix]
    if y_prefix:
        y_prefix = vectorizer.vocab[y_prefix]
    if y_suffix:
        y_suffix = vectorizer.vocab[y_suffix]

    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    # Create a file writer for this shard
    fw = create_file_writer(fmt, output, fields, max_file_size, 1000 * world_offset)
    num_read = -1
    num_samples_this_worker = 0
    for text in input_files:
        with open(text, encoding='utf-8') as rf:
            print(f"Reading from {text}...")
            for line in rf:
                num_read += 1
                # Each worker only processes the lines assigned to it
                if num_read % world_size != world_offset:
                    continue

                to_bpe = line.strip().split()
                if not to_bpe:
                    continue

                ids, available = vectorizer.run(to_bpe, vectorizer.vocab)
                x, y = masking(ids[:available], False, False)
                if x_prefix:
                    x = [x_prefix] + x
                if y_prefix:
                    y = [y_prefix] + y
                if x_suffix:
                    x += [x_suffix]
                if y_suffix:
                    y += [y_suffix]

                x = x[:nctx]
                y = y[:tgt_nctx]
                x_t = np.zeros(nctx, dtype=ids.dtype)
                y_t = np.zeros(tgt_nctx, dtype=ids.dtype)
                x_t[:len(x)] = x
                y_t[:len(y)] = y
                record = {
                    'x': x_t,
                    'y': y_t,
                    'x_str': [indices2word[s] for s in x_t],
                    'y_str': [indices2word[s] for s in y_t]
                }
                if masking.is_valid(record):
                    fw.write(record)
                    num_samples_this_worker += 1

    fw.close()

    duration = timer.elapsed()
    print("Processed {:,} samples in {:.2f}s".format(num_samples_this_worker, duration))
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples_this_worker}, os.path.join(root_dir, f_name))
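# Usage sketch (assumption, not from the original source): one worker of a paired
# (encoder-decoder) preprocessing job using the run() defined above. Paths and BPE
# artifacts are hypothetical placeholders; tgt_nctx falls back to 64 when omitted.
def _example_paired_preproc():
    run(input_files='/data/pretokenized',  # directory of whitespace-tokenized *.txt
        input_pattern='*.txt',
        codes='/data/bpe/codes.30k',
        vocab='/data/bpe/vocab.30k',
        nctx=256,
        tgt_nctx=64,
        output='/data/records/paired',
        mask_type='mlm',
        subword_type='bpe',
        world_size=1,
        world_offset=0)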
def write_files(annot_files, doc_files, fw, output_dir, pg_name):
    num_samples = 0
    indices2word = baseline.revlut(VECTORIZER.vocab)
    indices2labels = baseline.revlut(LABELS)
    lookup_indices = []
    pg = create_progress_bar(len(annot_files), name=pg_name)
    for annot in pg(annot_files):
        doc = os.path.join(doc_files, annot.name)
        assert os.path.exists(doc)
        td = dict_doc(doc, DOC2WORD)
        ad = annot_doc(annot)

        # For each document
        for doc_id in ad.keys():
            yd = []
            this_doc = td[doc_id]
            for sent in this_doc:
                yd.append(['O'] * len(sent))
            this_annot = ad[doc_id]
            for annotation in this_annot:
                sid, start, end, label = annotation
                label = label2word[label]
                if (start + 1) >= end:
                    yd[sid][start] = f"S-{label}"
                else:
                    yd[sid][start] = f"B-{label}"
                    yd[sid][end - 1] = f"E-{label}"
                    for k in range(start + 1, end - 1):
                        yd[sid][k] = f"I-{label}"

            # For each document, run BPE over the whole thing
            for j, sentence in enumerate(this_doc):
                output = [pair for pair in convert_to_pairs(VECTORIZER, sentence, yd[j], LABELS)]
                available = len(output)

                while available > 0:
                    if len(lookup_indices) == NCTX:
                        record = create_record(lookup_indices, indices2word, indices2labels, PREFIX, SUFFIX)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []

                    needed = NCTX - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += output[:needed]
                        output = output[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word, indices2labels, PREFIX, SUFFIX)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += output[:available]
                        available = 0

    fw.close()
    write_yaml({'num_samples': num_samples}, os.path.join(output_dir, 'md.yml'))
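# Illustrative sketch (assumption, not from the original source): the span-to-tag
# logic used in write_files() above follows a BIOES scheme. A span over tokens
# [start, end) becomes S- when it covers a single token, otherwise B- on the first
# token, E- on the last, and I- in between.
def _span_to_bioes(tags, start, end, label):
    if (start + 1) >= end:
        tags[start] = f"S-{label}"
    else:
        tags[start] = f"B-{label}"
        tags[end - 1] = f"E-{label}"
        for k in range(start + 1, end - 1):
            tags[k] = f"I-{label}"
    return tags

assert _span_to_bioes(['O'] * 4, 1, 4, 'PER') == ['O', 'B-PER', 'I-PER', 'E-PER']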