Code Example #1
    def train(self):

        num_ex = self.config_params['num_valid_to_show']

        if num_ex > 0:
            print('Showing examples')
            preproc = self.config_params['preproc']
            show_ex_fn = preproc['show_ex']
            rlut1 = baseline.revlut(self.feat2index1['word'])
            rlut2 = baseline.revlut(self.feat2index2['word'])
            self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                     self.valid_data, rlut1, rlut2,
                                                                                     self.embeddings2['word'],
                                                                                     preproc['mxlen'], False, 0,
                                                                                     num_ex, reverse=False)
        super(EncoderDecoderTask, self).train()
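
This example builds reverse lookup tables with baseline.revlut so that decoded word indices can be printed back as strings by the show_ex callback. As a point of reference, a minimal sketch of the assumed behavior of revlut, namely inverting a token-to-index vocabulary (the demo vocab below is made up):

# Minimal sketch of the assumed revlut behavior: invert a token->index vocab
# so integer ids coming out of the model can be mapped back to tokens.
def revlut(lut):
    return {index: token for token, index in lut.items()}

vocab = {'<PAD>': 0, 'hello': 1, 'world': 2}   # made-up demo vocab
rlut = revlut(vocab)
print([rlut[i] for i in [1, 2]])               # ['hello', 'world']
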
Code Example #2
    def train(self, checkpoint=None):

        num_ex = self.config_params['num_valid_to_show']

        rlut1 = baseline.revlut(self.feat2src[self.primary_key])
        rlut2 = baseline.revlut(self.feat2tgt)
        if num_ex > 0:
            logger.info('Showing examples')
            preproc = self.config_params.get('preproc', {})
            show_ex_fn = preproc['show_ex']
            self.config_params['train']['after_train_fn'] = lambda model: show_ex_fn(model,
                                                                                     self.valid_data, rlut1, rlut2,
                                                                                     self.feat2tgt,
                                                                                     preproc['mxlen'], False, 0,
                                                                                     num_ex, reverse=False)
        self.config_params['train']['tgt_rlut'] = rlut2
        super(EncoderDecoderTask, self).train(checkpoint)
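
Both training examples install the example-printing callback by assigning config_params['train']['after_train_fn'], a hook the trainer is expected to call with the trained model once fitting finishes. A schematic sketch of that callback pattern, with illustrative names rather than the project's actual Trainer API:

# Schematic hook pattern (illustrative names, not the project's Trainer API):
# run training, then invoke an optional after_train_fn callback with the
# trained model, which is how show_ex gets wired in above.
def fit(model, train_one_model, after_train_fn=None):
    train_one_model(model)
    if after_train_fn is not None:
        after_train_fn(model)

fit(model='demo-model',
    train_one_model=lambda m: print(f'training {m}'),
    after_train_fn=lambda m: print(f'showing validation examples for {m}'))
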
Code Example #3
def run(input_files=[], input_pattern='*.txt', codes=None, vocab=None, nctx=256, fmt='json', fields=['x_str', 'y_str'],
        output=None, prefix=None, suffix=None, max_file_size=100, tok_on_eol="<EOS>", cased=True,
        mask_type="mlm", module=None, pad_y=True, extra_tokens=['[CLS]', '[MASK]'], world_size=1, world_offset=0,
        input_field='text', tokenizer_type=None, **kwargs):

    def parse_json_line(x): return json.loads(x)[input_field]

    if module:
        logger.warning("Loading custom user module %s for masking rules and tokenizers", module)
        baseline.import_user_module(module)

    get_line = lambda x: x.strip()
    if os.path.isdir(input_files):
        if '.json' in input_pattern:
            get_line = parse_json_line
        # Derive the default output path before input_files is replaced with the glob list
        if not output:
            output = os.path.join(input_files, 'records')
        input_files = list(glob.glob(os.path.join(input_files, input_pattern)))
    else:
        if '.json' in input_files:
            get_line = parse_json_line
        # Likewise, build the default output name from the original path string
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    if len(input_files) < world_size:
        raise Exception(f"The number of input shards ({len(input_files)}) should be at least the world_size: {world_size}")

    logger.info('Output [%s]', output)
    transform = baseline.lowercase if not cased else lambda x: x
    vectorizer = BPEVectorizer1D(transform_fn=transform, model_file=codes, vocab_file=vocab, mxlen=1024, extra_tokens=extra_tokens)

    lookup_indices = []
    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    tokenizer = create_tokenizer(tokenizer_type)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    if prefix:
        nctx -= 1
        prefix = vectorizer.vocab[prefix]

    if suffix:
        nctx -= 1
        suffix = vectorizer.vocab[suffix]

    fw = create_file_writer(fmt, output, fields, max_file_size, 1000 * world_offset)
    num_samples = 0
    for i, text in enumerate(input_files):

        if i % world_size != world_offset:
            continue

        with TextFile(text) as rf:
            print(f"Reading from {text}...")
            for line in rf:
                to_bpe = tokenizer(get_line(line))
                if not to_bpe:
                    continue
                to_bpe += [tok_on_eol]

                output, available = vectorizer.run(to_bpe, vectorizer.vocab)
                while available > 0:
                    if len(lookup_indices) == nctx:
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    needed = nctx - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += output[:needed].tolist()
                        output = output[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word, prefix, suffix, masking=masking)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += output[:available].tolist()
                        available = 0

    fw.close()
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples}, os.path.join(root_dir, f_name))
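
The while available > 0 loop above packs the stream of BPE indices into fixed-length records of nctx tokens, carrying any leftover ids in lookup_indices across lines and dropping the final partial window. A standalone sketch of that packing pattern on plain Python lists (chunk_stream is an illustrative helper, not part of the project):

# Illustrative packing of per-line token ids into fixed-size windows,
# mirroring the lookup_indices / available bookkeeping in the loop above.
def chunk_stream(lines_of_ids, nctx):
    buf = []
    for ids in lines_of_ids:
        ids = list(ids)
        while ids:
            needed = nctx - len(buf)
            buf += ids[:needed]
            ids = ids[needed:]
            if len(buf) == nctx:
                yield buf
                buf = []
    # like the loop above, a trailing partial window is dropped

print(list(chunk_stream([[1, 2, 3], [4, 5], [6, 7, 8, 9]], nctx=4)))
# [[1, 2, 3, 4], [5, 6, 7, 8]]
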
Code Example #4
File: preproc_tlm.py | Project: blester125/baseline
# Note: this excerpt begins partway through resolving args.output; the opening
# `if os.path.isdir(args.input_files):` branch is truncated in the source listing.
        args.output = os.path.join(args.input_files, 'records')
else:
    input_files = [args.input_files]
    if not args.output:
        args.output = f'{args.input_files}.records'

print(args.output)
transform = baseline.lowercase if not args.cased else lambda x: x
vectorizer = BPEVectorizer1D(transform_fn=transform,
                             model_file=args.codes,
                             vocab_file=args.vocab,
                             mxlen=1024)

lookup_indices = []
words = []
indices2word = baseline.revlut(vectorizer.vocab)
vocab_size = max(vectorizer.vocab.values()) + 1
nctx = args.nctx
mask_value = vectorizer.vocab['[MASK]']
prefix = suffix = None
root_dir = os.path.dirname(args.output)
if not os.path.exists(root_dir):
    os.makedirs(root_dir)

if args.prefix:
    nctx -= 1
    prefix = vectorizer.vocab[args.prefix]

if args.suffix:
    nctx -= 1
    suffix = vectorizer.vocab[args.suffix]
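
When a prefix or suffix marker is configured, both this script and Example #3 shave one slot off nctx for each reserved id before windowing, so the finished record still fits the original context length once the markers are added back (the project's create_record is not shown in these excerpts). A small sketch of that accounting, using a hypothetical build_record helper:

# Illustrative accounting for reserved prefix/suffix markers: each reserved id
# shrinks the per-window token budget by one, so the finished record still fits
# the original context length. build_record is a hypothetical helper, not the
# project's create_record.
def build_record(ids, prefix=None, suffix=None):
    record = list(ids)
    if prefix is not None:
        record = [prefix] + record
    if suffix is not None:
        record = record + [suffix]
    return record

nctx = 256
prefix_id, suffix_id = 101, 102   # made-up vocab ids for markers like '[CLS]'
nctx -= 1                         # reserve a slot for the prefix
nctx -= 1                         # reserve a slot for the suffix
window = list(range(nctx))        # one full window of token ids
assert len(build_record(window, prefix_id, suffix_id)) == 256
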
Code Example #5
def run(input_files=[],
        input_pattern='*.txt',
        codes=None,
        vocab=None,
        nctx=256,
        fmt='json',
        fields=['x_str', 'y_str'],
        output=None,
        x_prefix=None,
        x_suffix=None,
        y_prefix=None,
        y_suffix=None,
        max_file_size=100,
        cased=True,
        mask_type="mlm",
        module=None,
        pad_y=True,
        extra_tokens=['[CLS]', '[MASK]'],
        tgt_nctx=None,
        world_size=1,
        world_offset=0,
        subword_type='bpe',
        **kwargs):
    timer = Timer()

    if module:
        logger.warning("Loading custom user module %s for masking rules",
                       module)
        baseline.import_user_module(module)

    if os.path.isdir(input_files):
        import glob
        # Derive the default output path before input_files is replaced with the glob list
        if not output:
            output = os.path.join(input_files, 'records')
        input_files = list(glob.glob(os.path.join(input_files, input_pattern)))
    else:
        if not output:
            output = f'{input_files}.records'
        input_files = [input_files]

    logger.info('Output [%s]', output)
    if not tgt_nctx:
        tgt_nctx = 64
    transform = baseline.lowercase if not cased else lambda x: x
    Vec1D = get_subword_vec1d(subword_type)
    vectorizer = Vec1D(transform_fn=transform,
                       model_file=codes,
                       vocab_file=vocab,
                       mxlen=1024,
                       extra_tokens=extra_tokens)

    if x_prefix:
        x_prefix = vectorizer.vocab[x_prefix]
    if x_suffix:
        x_suffix = vectorizer.vocab[x_suffix]
    if y_prefix:
        y_prefix = vectorizer.vocab[y_prefix]
    if y_suffix:
        y_suffix = vectorizer.vocab[y_suffix]

    indices2word = baseline.revlut(vectorizer.vocab)
    root_dir = os.path.dirname(output)
    masking = create_masking(mask_type, vectorizer.vocab, pad_y)
    if not os.path.exists(root_dir):
        os.makedirs(root_dir)

    # Create a file writer for this shard
    fw = create_file_writer(fmt, output, fields, max_file_size,
                            1000 * world_offset)
    num_read = -1
    num_samples_this_worker = 0

    for text in input_files:
        with open(text, encoding='utf-8') as rf:
            print(f"Reading from {text}...")
            for line in rf:
                num_read += 1
                if num_read % world_size != world_offset:
                    continue

                to_bpe = line.strip().split()
                if not to_bpe:
                    continue

                output, available = vectorizer.run(to_bpe, vectorizer.vocab)
                x, y = masking(output[:available], False, False)
                if x_prefix:
                    x = [x_prefix] + x
                if y_prefix:
                    y = [y_prefix] + y
                if x_suffix:
                    x += [x_suffix]
                if y_suffix:
                    y += [y_suffix]

                x = x[:nctx]
                y = y[:tgt_nctx]
                x_t = np.zeros(nctx, dtype=output.dtype)
                y_t = np.zeros(tgt_nctx, dtype=output.dtype)
                x_t[:len(x)] = x
                y_t[:len(y)] = y
                record = {
                    'x': x_t,
                    'y': y_t,
                    'x_str': [indices2word[s] for s in x_t],
                    'y_str': [indices2word[s] for s in y_t]
                }
                if masking.is_valid(record):
                    fw.write(record)
                    num_samples_this_worker += 1

    fw.close()
    duration = timer.elapsed()
    print("Processed {:,} samples in {:.2f}s".format(num_samples_this_worker,
                                                     duration))
    f_name = f'md-{world_offset}.yml' if world_size > 1 else 'md.yml'
    write_yaml({'num_samples': num_samples_this_worker},
               os.path.join(root_dir, f_name))
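
Example #5 shards the corpus across workers by line number: each worker keeps only the lines where num_read % world_size == world_offset, while Example #3 applies the same test per input file. A small sketch of that round-robin split, with illustrative names:

# Round-robin sharding of input lines across workers, as in the
# `num_read % world_size != world_offset` check above (names illustrative).
def shard(lines, world_size, world_offset):
    for i, line in enumerate(lines):
        if i % world_size == world_offset:
            yield line

lines = ['l0', 'l1', 'l2', 'l3', 'l4']
print(list(shard(lines, world_size=2, world_offset=0)))  # ['l0', 'l2', 'l4']
print(list(shard(lines, world_size=2, world_offset=1)))  # ['l1', 'l3']
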
Code Example #6
def write_files(annot_files, doc_files, fw, output_dir, pg_name):
    num_samples = 0
    indices2word = baseline.revlut(VECTORIZER.vocab)
    indices2labels = baseline.revlut(LABELS)
    lookup_indices = []
    pg = create_progress_bar(len(annot_files), name=pg_name)

    for annot in pg(annot_files):
        doc = os.path.join(doc_files, annot.name)
        assert os.path.exists(doc)
        td = dict_doc(doc, DOC2WORD)
        ad = annot_doc(annot)
        # For each document
        for doc_id in ad.keys():
            yd = []
            this_doc = td[doc_id]
            for sent in this_doc:
                yd.append(['O'] * len(sent))
            this_annot = ad[doc_id]
            for annotation in this_annot:
                sid, start, end, label = annotation
                label = label2word[label]
                if (start + 1) >= end:
                    yd[sid][start] = f"S-{label}"
                else:
                    yd[sid][start] = f"B-{label}"
                    yd[sid][end - 1] = f"E-{label}"
                    for k in range(start + 1, end - 1):
                        yd[sid][k] = f"I-{label}"

            # For each document, run BPE over the whole thing
            for j, sentence in enumerate(this_doc):
                output = [
                    pair for pair in convert_to_pairs(VECTORIZER, sentence,
                                                      yd[j], LABELS)
                ]
                available = len(output)

                while available > 0:
                    if len(lookup_indices) == NCTX:
                        record = create_record(lookup_indices, indices2word,
                                               indices2labels, PREFIX, SUFFIX)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    needed = NCTX - len(lookup_indices)
                    if available >= needed:
                        lookup_indices += output[:needed]
                        output = output[needed:]
                        available -= needed
                        record = create_record(lookup_indices, indices2word,
                                               indices2labels, PREFIX, SUFFIX)
                        fw.write(record)
                        num_samples += 1
                        lookup_indices = []
                    # The amount available is less than what we need, so read the whole thing
                    else:
                        lookup_indices += output[:available]
                        available = 0
    fw.close()
    write_yaml({'num_samples': num_samples},
               os.path.join(output_dir, 'md.yml'))
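
The labeling loop in write_files expands each (sentence id, start, end, label) annotation into per-token S-/B-/I-/E- tags, with end treated as exclusive. A self-contained sketch of just that conversion step (spans_to_tags is an illustrative helper name):

# Convert exclusive-end span annotations into S/B/I/E tags over one sentence,
# matching the labeling loop in write_files above (helper name is illustrative).
def spans_to_tags(sent_len, spans):
    tags = ['O'] * sent_len
    for start, end, label in spans:           # `end` is exclusive, as above
        if start + 1 >= end:
            tags[start] = f"S-{label}"
        else:
            tags[start] = f"B-{label}"
            tags[end - 1] = f"E-{label}"
            for k in range(start + 1, end - 1):
                tags[k] = f"I-{label}"
    return tags

print(spans_to_tags(6, [(0, 1, 'PER'), (2, 5, 'ORG')]))
# ['S-PER', 'O', 'B-ORG', 'I-ORG', 'E-ORG', 'O']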