예제 #1
0
    def add_args(cls, parser: ArgumentParser):
        """Register this config's command-line options on ``parser``.

        The options declared by ``TableBertConfig.add_args`` are added first,
        followed by the options specific to this config:

        * ``--num_vertical_attention_heads`` (int, default 6)
        * ``--num_vertical_layers`` (int, default 3)
        * ``--sample_row_num`` (int, default 3)
        * ``--predict_cell_tokens`` / ``--no_predict_cell_tokens``: paired
          switches writing to ``predict_cell_tokens`` (default ``False``)
        * ``--initialize_from`` (``Path``, default ``None``)
        """
        TableBertConfig.add_args(parser)

        # Plain integer options: (flag, default value).
        integer_options = (
            ("--num_vertical_attention_heads", 6),
            ("--num_vertical_layers", 3),
            ("--sample_row_num", 3),
        )
        for flag, default_value in integer_options:
            parser.add_argument(flag, type=int, default=default_value)

        # An on/off pair of switches that share one destination attribute.
        switch_actions = (
            ("--predict_cell_tokens", 'store_true'),
            ("--no_predict_cell_tokens", 'store_false'),
        )
        for flag, action in switch_actions:
            parser.add_argument(flag, action=action, dest='predict_cell_tokens')
        parser.set_defaults(predict_cell_tokens=False)

        parser.add_argument("--initialize_from", type=Path, default=None)
예제 #2
0
def main():
    """Pre-generate sharded dev and training data for the local worker.

    Steps:

    1. Parse command-line options (including those registered by
       ``TableBertConfig.add_args``).
    2. Count the lines of ``--train_corpus`` and split the line indices into
       a dev set (10% of the corpus, capped at 100000) and a train set using
       a fixed-seed shuffle, so every worker computes the same split.
    3. Keep this worker's share of each split: every ``world_size``-th index
       starting at ``global_rank``.
    4. Load those tables and write one dev shard, then one train shard per
       epoch, under ``--output_dir``. The master worker (rank 0) also writes
       ``config.json`` with the parsed arguments.

    Exits via ``parser.error`` when ``--use_acoustic_confusion`` is given
    without ``--word_confusion_path``.
    """
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of preprocess to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    # Environment values are strings; argparse runs string defaults through
    # `type`, so these end up as ints either way.
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    # Acoustic-confusion options.
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # parser.add_argument('--acoustic_confusion_type', type=str, choices=['random', 'gpt2'], default='random')
    # The default must be None rather than '': argparse converts a string
    # default with `type`, which would turn '' into Path('.') and make a
    # "was it provided?" check impossible.
    parser.add_argument('--word_confusion_path', type=Path, default=None)

    TableBertConfig.add_args(parser)

    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped under -O.
    if args.use_acoustic_confusion and args.word_confusion_path is None:
        parser.error('--word_confusion_path is required when --use_acoustic_confusion is set')

    args.is_master = args.global_rank == 0

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    # Pick the input formatter, optionally wrapping it with acoustic confusion.
    if args.use_acoustic_confusion:
        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(args.word_confusion_path, default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    # Pass an argument list (no shell) so corpus paths containing spaces or
    # shell metacharacters are handled safely. The first field of the
    # `wc -l` output is the line count.
    wc_output = subprocess.check_output(['wc', '-l', str(args.train_corpus)])
    total_tables_num = int(wc_output.split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # seed the RNG to make sure each process follows the same splitting
    rng = np.random.RandomState(seed=5783287)

    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    # Each worker takes every `world_size`-th index, offset by its rank.
    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory', tokenizer=tokenizer, indices=local_indices) as table_db:
        # Keep only the indices that are actually present in the database.
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        args.output_dir.mkdir(exist_ok=True, parents=True)
        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        # Only the master worker records the run configuration.
        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        (args.output_dir / 'train').mkdir(exist_ok=True)
        (args.output_dir / 'dev').mkdir(exist_ok=True)

        # generate dev data first
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            # Free garbage from the previous epoch before generating the next.
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)