def add_args(cls, parser: ArgumentParser):
    """Register this config's command-line options on ``parser``.

    The base ``TableBertConfig`` options are added first. After them come the
    vertical-attention integer options, the ``--predict_cell_tokens`` /
    ``--no_predict_cell_tokens`` switch pair (off unless requested), and
    ``--initialize_from``.

    Args:
        parser: the ``ArgumentParser`` to extend in place.
    """
    TableBertConfig.add_args(parser)

    integer_options = (
        ("--num_vertical_attention_heads", 6),
        ("--num_vertical_layers", 3),
        ("--sample_row_num", 3),
    )
    for flag_name, default_value in integer_options:
        parser.add_argument(flag_name, type=int, default=default_value)

    # Both switches write the same destination; the explicit default below
    # makes cell-token prediction opt-in.
    toggle_dest = 'predict_cell_tokens'
    toggle_switches = (
        ("--predict_cell_tokens", 'store_true'),
        ("--no_predict_cell_tokens", 'store_false'),
    )
    for switch_name, switch_action in toggle_switches:
        parser.add_argument(switch_name, action=switch_action, dest=toggle_dest)
    parser.set_defaults(**{toggle_dest: False})

    parser.add_argument("--initialize_from", type=Path, default=None)
def main():
    """Pre-generate dev and train data shards for this worker's slice of the corpus.

    Steps:
      1. Parse the command line (including ``TableBertConfig`` options).
      2. Count the lines of ``--train_corpus`` and deterministically shuffle
         the line indices. The fixed seed gives every rank the same split.
      3. Hold out ``min(10%, 100000)`` indices as dev and keep the rest as train.
      4. Take this rank's round-robin slice (``[global_rank::world_size]``).
      5. Load those tables and write ``dev/epoch_0.shard{rank}.h5``, plus one
         ``train/epoch_{i}.shard{rank}.h5`` per epoch, under ``--output_dir``.
         The writing itself is done by ``generate_for_epoch``.

    The master rank (``global_rank == 0``) also writes ``config.json`` with
    the parsed arguments.
    """
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of preprocess to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    # When the SLURM variables are set, these defaults are strings. argparse
    # runs string defaults through ``type``, so they still arrive as ints.
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    # Optional acoustic-confusion augmentation.
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # Default is None, not '': with ``type=Path`` a '' default is converted
    # to Path('.'), so an "is it empty?" check could never fire.
    parser.add_argument('--word_confusion_path', type=Path, default=None,
                        help="Word confusion file; required with --use_acoustic_confusion")

    TableBertConfig.add_args(parser)

    args = parser.parse_args()
    # Validate before any expensive setup. Use parser.error rather than
    # ``assert``, because ``python -O`` strips assertions.
    if args.use_acoustic_confusion and args.word_confusion_path is None:
        parser.error('--word_confusion_path is required when --use_acoustic_confusion is set')
    args.is_master = args.global_rank == 0

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    if args.use_acoustic_confusion:
        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(
            args.word_confusion_path, default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(
            table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    # Count corpus lines with ``wc -l``. An argv list (no shell) keeps paths
    # containing spaces or shell metacharacters intact.
    wc_output = subprocess.check_output(['wc', '-l', str(args.train_corpus)])
    total_tables_num = int(wc_output.split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # seed the RNG to make sure each process follows the same spliting
    rng = np.random.RandomState(seed=5783287)

    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    # Round-robin assignment of tables to ranks.
    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory', tokenizer=tokenizer,
                                  indices=local_indices) as table_db:
        # Keep only the indices that are present in the loaded database.
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        # Every rank creates the directories it writes into. mkdir with
        # exist_ok is safe when ranks race. Previously only the master
        # created dev/ and train/, so another rank could reach
        # generate_for_epoch before those directories existed.
        args.output_dir.mkdir(exist_ok=True, parents=True)
        (args.output_dir / 'train').mkdir(exist_ok=True)
        (args.output_dir / 'dev').mkdir(exist_ok=True)

        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                # default=str serializes Path values.
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        # generate dev data first
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)