Example #1
    def make_all(lang, vocab):
        for l in os.listdir(args['preprocess']['trainpref'].split('*')[0]):
            # copy the shared dict into each language's directory
            out_dir = os.path.join(args['preprocess']['destdir'], l)
            PathManager.mkdir(out_dir)
            dst_dict = os.path.join(out_dir, f'{lang}.dict.jsonl')
            PathManager.copy(dict_path(lang), dst_dict)

            if args['preprocess']['trainpref']:
                out_file = os.path.join(out_dir, f"train.{lang}")
                make_dataset(vocab,
                             args['preprocess']['trainpref'].replace('*', l),
                             "train",
                             lang,
                             out_file=out_file,
                             num_workers=args['preprocess']['workers'])
            if args['preprocess']['validpref']:
                out_file = os.path.join(out_dir, f"valid.{lang}")
                make_dataset(vocab,
                             args['preprocess']['validpref'].replace('*', l),
                             'valid',
                             lang,
                             out_file=out_file,
                             num_workers=args['preprocess']['workers'])
            if args['preprocess']['testpref']:
                out_file = os.path.join(out_dir, f"test.{lang}")
                make_dataset(vocab,
                             args['preprocess']['testpref'].replace('*', l),
                             'test',
                             lang,
                             out_file=out_file,
                             num_workers=args['preprocess']['workers'])
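
Not part of the original snippet: a minimal sketch of the '*' placeholder convention that make_all relies on. trainpref holds a path containing a '*' wildcard, each language directory is substituted for it, and the outputs land in a per-language subfolder of destdir. The demo_args dict, the paths, and the languages list below are made up for illustration.

import os

# Hypothetical config mirroring the args layout used above (illustrative only).
demo_args = {
    'preprocess': {
        'trainpref': '/data/code_trans/*/train',  # '*' stands for a language directory
        'destdir': '/data/code_trans/processed',
        'workers': 4,
    }
}

languages = ['java', 'csharp']  # assumed language subdirectories

for l in languages:
    train_prefix = demo_args['preprocess']['trainpref'].replace('*', l)
    out_dir = os.path.join(demo_args['preprocess']['destdir'], l)
    out_file = os.path.join(out_dir, 'train.code')
    print(f"{train_prefix} -> {out_file}")
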
Example #2
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    from ncc import meters
    from ncc.utils import distributed_utils
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args['checkpoint'][
            'maximize_best_checkpoint_metric'] else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args['checkpoint']['no_save'] or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args['checkpoint'][
            'maximize_best_checkpoint_metric'] else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args['checkpoint']['no_epoch_checkpoints']
        and epoch % args['checkpoint']['save_interval'] == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch and args['checkpoint']['save_interval_updates'] > 0
        and updates % args['checkpoint']['save_interval_updates'] == 0)
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and args['checkpoint']['keep_best_checkpoints'] > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args['checkpoint']['best_checkpoint_metric'],
            val_loss)] = (not hasattr(save_checkpoint, "best")
                          or is_better(val_loss, save_checkpoint.best))
    checkpoint_conds[
        "checkpoint_last.pt"] = not args['checkpoint']['no_last_checkpoints']

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args['checkpoint']['save_dir'], fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp)

        write_timer.stop()
        LOGGER.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {:.6f} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and args['checkpoint']['keep_interval_updates'] > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args['checkpoint']['save_dir'],
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[
                args['checkpoint']['keep_interval_updates']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args['checkpoint']['keep_last_epochs'] > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args['checkpoint']['save_dir'],
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args['checkpoint']['keep_last_epochs']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args['checkpoint']['keep_best_checkpoints'] > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args['checkpoint']['save_dir'],
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args['checkpoint']['best_checkpoint_metric']))
        if not args['checkpoint']['maximize_best_checkpoint_metric']:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[
                args['checkpoint']['keep_best_checkpoints']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
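
The heart of save_checkpoint is an ordered mapping from checkpoint filename to a boolean condition: the first filename whose condition holds gets the actual save, and the remaining ones are copies of it. A self-contained sketch of that selection step, with invented epoch/update values and flags:

import collections

# Illustrative state (not taken from the snippet above).
epoch, updates, end_of_epoch = 3, 1500, True
save_interval, save_interval_updates = 1, 500
no_epoch_checkpoints, no_last_checkpoints = False, False

checkpoint_conds = collections.OrderedDict()
checkpoint_conds[f"checkpoint{epoch}.pt"] = (
    end_of_epoch and not no_epoch_checkpoints and epoch % save_interval == 0)
checkpoint_conds[f"checkpoint_{epoch}_{updates}.pt"] = (
    not end_of_epoch and save_interval_updates > 0
    and updates % save_interval_updates == 0)
checkpoint_conds["checkpoint_last.pt"] = not no_last_checkpoints

to_write = [fn for fn, cond in checkpoint_conds.items() if cond]
# The real code saves to_write[0] once and copies it to the remaining names.
print(to_write)  # ['checkpoint3.pt', 'checkpoint_last.pt']
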
Example #3
def main(args):
    # task = tasks.get_task(args['preprocess']['task'])
    LOGGER.info('mkdir {} for {} task'.format(args['preprocess']['destdir'],
                                              args['preprocess']['task']))
    PathManager.mkdir(args['preprocess']['destdir'])
    vocab = TransformersDictionary.from_pretrained('microsoft/codebert-base',
                                                   do_lower_case=False)

    max_source_length, max_target_length = args['preprocess']['max_source_length'], \
                                           args['preprocess']['max_target_length']

    # 2. ***************build dataset********************
    # dump into pkl file
    # transform a language's code into src format and tgt format simultaneously
    def parse_source_input(code):
        code_tokens = vocab.tokenize(code)
        # truncating
        code_tokens = code_tokens[:max_source_length - 2]
        source_tokens = [vocab.cls_token] + code_tokens + [vocab.sep_token]
        source_ids = vocab.convert_tokens_to_ids(source_tokens)
        source_size = len(source_tokens)
        source_mask = [1] * source_size
        padding_length = max_source_length - len(source_ids)
        source_ids += [vocab.pad()] * padding_length
        source_mask += [0] * padding_length
        return [source_ids, source_mask, source_size]

    def parse_target_input(code):
        target_tokens = vocab.tokenize(code)[:max_target_length - 2]
        target_tokens = [vocab.cls_token] + target_tokens + [vocab.sep_token]
        target_ids = vocab.convert_tokens_to_ids(target_tokens)
        target_size = len(target_ids)
        target_mask = [1] * target_size
        padding_length = max_target_length - len(target_ids)
        target_ids += [vocab.pad_token_id] * padding_length
        target_mask += [0] * padding_length
        return [target_ids, target_mask, target_size]

    src_lang, tgt_lang = args['preprocess']['src_lang'], args['preprocess'][
        'tgt_lang']
    for lang, mode in itertools.product([src_lang, tgt_lang], MODES):
        # cp id
        src_id = args['preprocess'][f'{mode}pref'].replace('*', '') + ".id"
        tgt_id = os.path.join(args['preprocess']['destdir'], f"{mode}.id")
        PathManager.copy(src_id, tgt_id)

        src_file = args['preprocess'][f'{mode}pref'].replace('*',
                                                             lang) + ".code"
        dst_file = os.path.join(args['preprocess']['destdir'], lang,
                                f"{mode}.pkl")
        PathManager.mkdir(os.path.dirname(dst_file))
        with file_io.open(src_file, 'r') as reader:
            keys = [
                'code', 'src_tokens', 'src_masks', 'src_sizes', 'tgt_tokens',
                'tgt_masks', 'tgt_sizes'
            ]
            data = {key: [] for key in keys}
            for line in reader:
                src_code = json_io.json_loads(line)
                # src_code = SPACE_SPLITTER.sub(" ", line)
                # source_ids, source_mask
                src_line = parse_source_input(src_code)
                # target_ids, target_mask
                tgt_line = parse_target_input(src_code)
                for key, src in zip(keys, [src_code] + src_line + tgt_line):
                    data[key].append(src)
            file_io.open(dst_file, mode='wb', data=data)
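
parse_source_input and parse_target_input follow the usual truncate-then-pad recipe: reserve two slots for the special tokens, wrap the sequence with CLS/SEP, convert to ids, and pad both the ids and the attention mask out to a fixed length. A toy, self-contained version of that recipe (toy_vocab and the token names are illustrative, not the CodeBERT tokenizer):

# Toy vocabulary; the real code uses TransformersDictionary ('microsoft/codebert-base').
CLS, SEP, PAD = '<s>', '</s>', '<pad>'
toy_vocab = {CLS: 0, SEP: 2, PAD: 1, 'def': 4, 'foo': 5, '(': 6, ')': 7, ':': 8}
max_source_length = 8

def encode(tokens, max_len):
    tokens = tokens[:max_len - 2]         # leave room for CLS/SEP
    tokens = [CLS] + tokens + [SEP]
    ids = [toy_vocab[t] for t in tokens]
    mask = [1] * len(ids)
    pad_len = max_len - len(ids)
    ids += [toy_vocab[PAD]] * pad_len     # pad ids to the fixed length
    mask += [0] * pad_len                 # mask marks the real positions
    return ids, mask

ids, mask = encode(['def', 'foo', '(', ')', ':'], max_source_length)
print(ids)   # [0, 4, 5, 6, 7, 8, 2, 1]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 0]
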
Example #4
def main(args):
    LOGGER.info('mkdir {} for {} task'.format(args['preprocess']['destdir'],
                                              args['preprocess']['task']))
    PathManager.mkdir(args['preprocess']['destdir'])

    SPM_VOCAB_FILE = os.path.join(BPE_DIR, 'plbart', "sentencepiece.bpe.model")
    vocab = spm.SentencePieceProcessor()
    vocab.load(SPM_VOCAB_FILE)

    def save_token_dict():
        src_file = os.path.join(os.path.dirname(SPM_VOCAB_FILE), 'dict.txt')
        tgt_file = os.path.join(args['preprocess']['destdir'], 'dict.jsonl')
        # Dictionary.text_to_jsonl(src_file, tgt_file)
        vocab = Dictionary()
        with file_io.open(src_file, 'r') as reader:
            for line in reader:
                token, num = line.strip().split()
                vocab.add_symbol(token, eval(num))
        vocab.save(tgt_file)
        return vocab

    token_dict = save_token_dict()

    # 2. ***************build dataset********************
    # dump into mmap/idx binary files
    # transform a language's code into src format and tgt format simultaneously
    num_workers = args['preprocess']['workers']
    src_lang, tgt_lang = args['preprocess']['src_lang'], args['preprocess'][
        'tgt_lang']

    for lang, mode in itertools.product([src_lang, tgt_lang], MODES):
        # cp id
        src_id = args['preprocess'][f'{mode}pref'].replace('*', '') + ".id"
        tgt_id = os.path.join(args['preprocess']['destdir'], f"{mode}.id")
        PathManager.copy(src_id, tgt_id)

        src_file = args['preprocess'][f'{mode}pref'].replace('*',
                                                             lang) + ".code"
        dst_file = os.path.join(args['preprocess']['destdir'], lang,
                                f"{mode}.code_tokens")
        PathManager.mkdir(os.path.dirname(dst_file))

        offsets = find_offsets(src_file, num_workers)
        pool = None
        if num_workers > 1:
            # workers 1..N-1 each write a temporary (bin, idx) shard; they are merged below
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(dst_file, worker_id)
                pool.apply_async(
                    binarize,
                    (args, src_file, prefix, vocab, token_dict,
                     offsets[worker_id], offsets[worker_id + 1]),
                )
            pool.close()

        ds = indexed_dataset.make_builder(f"{dst_file}.mmap",
                                          impl='mmap',
                                          vocab_size=len(vocab))
        end = offsets[1]

        with file_io.open(src_file, 'r') as reader:
            reader.seek(0)
            line = file_io.safe_readline(reader)
            while line:
                if end > 0 and reader.tell() > end:
                    break
                line = json_io.json_loads(line)
                code_tokens = vocab.encode(line, out_type=str)
                code_tokens = torch.IntTensor(
                    [token_dict.index(token) for token in code_tokens])
                ds.add_item(code_tokens)
                line = reader.readline()

        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                temp_file_path = "{}{}".format(dst_file, worker_id)
                ds.merge_file_(temp_file_path)
                # idx, txt
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))
        ds.finalize(f"{dst_file}.idx")
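
The multi-worker branch relies on find_offsets-style byte offsets: the input file is split into roughly equal chunks aligned to line boundaries, each worker binarizes its [start, end) byte range, and the per-worker shards are merged afterwards. A standalone sketch of that chunking idea; find_line_offsets below is an illustrative stand-in, not the ncc helper:

import os
import tempfile

def find_line_offsets(path, num_chunks):
    """Split a file into num_chunks byte ranges aligned to line boundaries (illustrative)."""
    size = os.path.getsize(path)
    offsets = [0]
    with open(path, 'rb') as f:
        for i in range(1, num_chunks):
            f.seek(size * i // num_chunks)
            f.readline()               # advance to the next full line
            offsets.append(f.tell())
    offsets.append(size)
    return offsets

# Demo on a small temporary file.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write(''.join(f"line {i}\n" for i in range(10)))
    path = tmp.name

offsets = find_line_offsets(path, num_chunks=3)
with open(path, 'rb') as f:
    for start, end in zip(offsets, offsets[1:]):
        f.seek(start)
        chunk = f.read(end - start).decode()
        print(start, end, chunk.count('\n'), 'lines')
os.remove(path)
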
Example #5
def main(args):
    # task = tasks.get_task(args['preprocess']['task'])
    LOGGER.info('mkdir {} for {} task'.format(args['preprocess']['destdir'],
                                              args['preprocess']['task']))
    PathManager.mkdir(args['preprocess']['destdir'])
    vocab = TransformersDictionary.from_pretrained(
        'microsoft/graphcodebert-base')

    max_source_length, max_target_length = args['preprocess']['max_source_length'], \
                                           args['preprocess']['max_target_length']

    # 2. ***************build dataset********************
    # dump into pkl file
    # transform a language's code into src format and tgt format simultaneously
    def parse_source_input(code, lang):
        code_tokens, dfg = extract_dataflow(code, parsers[lang], lang)
        code_tokens = vocab.subtokenize(code_tokens)

        ori2cur_pos = {}
        ori2cur_pos[-1] = (0, 0)
        for i in range(len(code_tokens)):
            ori2cur_pos[i] = (ori2cur_pos[i - 1][1],
                              ori2cur_pos[i - 1][1] + len(code_tokens[i]))

        # truncating
        code_tokens = code_tokens[:max_source_length - 3][:512 - 3]
        source_tokens = [vocab.cls_token] + code_tokens + [vocab.sep_token]
        source_ids = vocab.convert_tokens_to_ids(source_tokens)
        position_idx = [i + vocab.pad() + 1 for i in range(len(source_tokens))]
        dfg = dfg[:max_source_length - len(source_tokens)]
        source_tokens += [x[0] for x in dfg]
        position_idx += [0 for _ in dfg]
        source_ids += [vocab.unk() for _ in dfg]
        padding_length = max_source_length - len(source_ids)
        position_idx += [vocab.pad()] * padding_length
        source_ids += [vocab.pad()] * padding_length
        source_mask = [1] * (len(source_tokens))
        source_mask += [0] * padding_length

        # reindex
        reverse_index = {}
        for idx, x in enumerate(dfg):
            reverse_index[x[1]] = idx
        for idx, x in enumerate(dfg):
            dfg[idx] = x[:-1] + (
                [reverse_index[i] for i in x[-1] if i in reverse_index], )
        dfg_to_dfg = [x[-1] for x in dfg]
        dfg_to_code = [ori2cur_pos[x[1]] for x in dfg]
        length = len([vocab.cls()])
        dfg_to_code = [(x[0] + length, x[1] + length) for x in dfg_to_code]
        return [source_ids, position_idx, dfg_to_code, dfg_to_dfg, source_mask]

    def parse_target_input(code):
        target_tokens = vocab.tokenize(code)[:max_target_length - 2]
        target_tokens = [vocab.cls_token] + target_tokens + [vocab.sep_token]
        target_ids = vocab.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = max_target_length - len(target_ids)
        target_ids += [vocab.pad_token_id] * padding_length
        target_mask += [0] * padding_length
        return [target_ids, target_mask]

    src_lang, tgt_lang = args['preprocess']['src_lang'], args['preprocess'][
        'tgt_lang']
    for lang, mode in itertools.product([src_lang, tgt_lang], MODES):
        # cp id
        src_id = args['preprocess'][f'{mode}pref'].replace('*', '') + ".id"
        tgt_id = os.path.join(args['preprocess']['destdir'], f"{mode}.id")
        PathManager.copy(src_id, tgt_id)

        src_file = args['preprocess'][f'{mode}pref'].replace('*',
                                                             lang) + ".code"
        dst_file = os.path.join(args['preprocess']['destdir'], lang,
                                f"{mode}.pkl")
        PathManager.mkdir(os.path.dirname(dst_file))
        with file_io.open(src_file, 'r') as reader:
            keys = [
                'code',
                'src_tokens',
                'src_positions',
                'dfg2code',
                'dfg2dfg',
                'src_masks',
                'tgt_tokens',
                'tgt_masks',
            ]
            data = {key: [] for key in keys}
            for line in reader:
                src_code = json_io.json_loads(line)
                # src_code = SPACE_SPLITTER.sub(" ", line)
                # source_ids, position_idx, dfg_to_code, dfg_to_dfg, source_mask
                src_line = parse_source_input(src_code, lang)
                # target_ids, target_mask
                tgt_line = parse_target_input(src_code)
                for key, src in zip(keys, [src_code] + src_line + tgt_line):
                    data[key].append(src)
            file_io.open(dst_file, mode='wb', data=data)
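
ori2cur_pos in parse_source_input accumulates, for each original code token, the (start, end) span its pieces occupy in the subtokenized sequence, and dfg_to_code later shifts those spans by one to account for the leading CLS token. A standalone sketch of that bookkeeping, with an invented subtokenization:

# Each original token maps to a list of subtokens (illustrative subtokenization).
subtokenized = [['def'], ['my', '_', 'func'], ['('], [')']]

ori2cur_pos = {-1: (0, 0)}
for i, pieces in enumerate(subtokenized):
    start = ori2cur_pos[i - 1][1]
    ori2cur_pos[i] = (start, start + len(pieces))  # (start, end) in the flat sequence

flat = [piece for pieces in subtokenized for piece in pieces]
print(ori2cur_pos)  # {-1: (0, 0), 0: (0, 1), 1: (1, 4), 2: (4, 5), 3: (5, 6)}
print(flat)         # ['def', 'my', '_', 'func', '(', ')']

# A data-flow node anchored at original token 1 covers flat positions 1..3;
# adding 1 for the CLS token gives the final dfg_to_code span (2, 5).
cls_offset = 1
span = tuple(p + cls_offset for p in ori2cur_pos[1])
print(span)  # (2, 5)
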