def tokenize(self, tokenizer: BertTokenizer):
    # Tokenize column names and (optional) sample values in the header.
    for column in self.header:
        column.name_tokens = tokenizer.tokenize(column.name)
        if column.sample_value is not None:
            column.sample_value_tokens = tokenizer.tokenize(str(column.sample_value))

    # Tokenize every cell; rows may be stored either as dicts (column name -> value)
    # or as plain lists of values.
    tokenized_rows = [
        {k: tokenizer.tokenize(str(v)) for k, v in row.items()}
        if isinstance(row, dict)
        else [tokenizer.tokenize(str(v)) for v in row]
        for row in self.data
    ]
    self.data = tokenized_rows
    setattr(self, 'tokenized', True)

    return self
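# Illustrative usage sketch for `tokenize` (not part of the original module).
# Assumptions: the enclosing class is the project's table/example container,
# constructed below as `Table(id=..., header=..., data=...)` (an assumed signature);
# `Column` takes the keyword arguments used elsewhere in this code base; rows are
# dicts mapping column names to raw values.
_demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
_demo_table = Table(
    id='demo',
    header=[Column(name='Team', type='text', sample_value='Real Madrid')],
    data=[{'Team': 'Real Madrid'}, {'Team': 'Barcelona'}],
)
_demo_table.tokenize(_demo_tokenizer)
# After the call every cell holds word pieces, e.g. _demo_table.data[0]['Team']
# is ['real', 'madrid'], and each column carries name_tokens / sample_value_tokens.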
def __init__(
    self,
    base_model_name: str = 'bert-base-uncased',
    column_delimiter: str = '[SEP]',
    context_first: bool = True,
    cell_input_template: str = 'column | type | value',
    column_representation: str = 'mean_pool',
    max_cell_len: int = 5,
    max_sequence_len: int = 512,
    max_context_len: int = 256,
    masked_context_prob: float = 0.15,
    masked_column_prob: float = 0.2,
    max_predictions_per_seq: int = 100,
    context_sample_strategy: str = 'nearest',
    table_mask_strategy: str = 'column',
    do_lower_case: bool = True,
    **kwargs
):
    super(TableBertConfig, self).__init__()

    self.base_model_name = base_model_name
    self.column_delimiter = column_delimiter
    self.context_first = context_first
    self.column_representation = column_representation

    self.max_cell_len = max_cell_len
    self.max_sequence_len = max_sequence_len
    self.max_context_len = max_context_len
    self.do_lower_case = do_lower_case

    # tokenizer = BertTokenizer.from_pretrained(self.base_model_name)
    # The cell input template may be given as a whitespace-delimited string
    # (e.g. 'column | type | value'); otherwise fall back to tokenizing it with
    # the base model's tokenizer.
    if isinstance(cell_input_template, str):
        if ' ' in cell_input_template:
            cell_input_template = cell_input_template.split(' ')
        else:
            print(
                f'WARNING: cell_input_template is outdated: {cell_input_template}',
                file=sys.stderr)
            cell_input_template = BertTokenizer.from_pretrained(
                self.base_model_name).tokenize(cell_input_template)

    self.cell_input_template = cell_input_template

    self.masked_context_prob = masked_context_prob
    self.masked_column_prob = masked_column_prob
    self.max_predictions_per_seq = max_predictions_per_seq
    self.context_sample_strategy = context_sample_strategy
    self.table_mask_strategy = table_mask_strategy

    # Copy the fields of the underlying BERT config onto this object if they are
    # not already present.
    if not hasattr(self, 'vocab_size_or_config_json_file'):
        bert_config = BERT_CONFIGS[self.base_model_name]
        for k, v in vars(bert_config).items():
            setattr(self, k, v)
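# Illustrative configuration sketch: the keyword arguments mirror the signature
# above, and the values are arbitrary examples rather than recommended settings.
_demo_config = TableBertConfig(
    base_model_name='bert-base-uncased',
    cell_input_template='column | value',   # whitespace-delimited, split into a template list
    max_cell_len=5,
    masked_column_prob=0.2,
    table_mask_strategy='column',
)
# After __init__, _demo_config.cell_input_template == ['column', '|', 'value'],
# and the fields of the underlying BERT config (hidden size, vocabulary size, ...)
# have been copied onto the object via BERT_CONFIGS.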
def __init__(self, config: TableBertConfig, **kwargs):
    nn.Module.__init__(self)

    bert_model: Union[BertForPreTraining, BertModel] = kwargs.pop('bert_model', None)
    if bert_model is not None:
        logging.warning(
            'using `bert_model` to initialize `TableBertModel` is deprecated. '
            'I will still set `self._bert_model` this time.')
        self._bert_model = bert_model

    self.tokenizer = BertTokenizer.from_pretrained(config.base_model_name)
    self.config = config
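# Illustrative sketch: this base constructor only records the config and builds a
# matching word-piece tokenizer; concrete subclasses are expected to construct the
# actual encoder. Instantiating the base class directly, as done below, is for
# illustration only.
_base_config = TableBertConfig(base_model_name='bert-base-uncased')
_base_model = TableBertModel(_base_config)
print(_base_model.tokenizer.tokenize('hello world'))   # -> ['hello', 'world']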
    info.update({
        'num_column_tokens_to_mask': num_column_tokens_to_mask,
        'num_context_tokens_to_mask': num_context_tokens_to_mask,
    })

    return tokens, masked_indices, masked_token_labels, info

def remove_unecessary_instance_entries(self, instance: Dict):
    del instance['tokens']
    del instance['masked_lm_labels']
    del instance['info']


if __name__ == '__main__':
    config = TableBertConfig()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    input_formatter = VanillaTableBertInputFormatter(config, tokenizer)

    header = []
    for i in range(1000):
        header.append(
            Column(
                name='test',
                type='text',
                name_tokens=['test'] * 3,
                sample_value='ha ha ha yay',
                sample_value_tokens=['ha', 'ha', 'ha', 'yay']
            )
        )

    print(
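# Hypothetical helper (not part of the formatter above) sketching how the
# per-instance masking budget follows from the config: a fraction of column and
# context tokens is masked, capped by `max_predictions_per_seq`. The exact
# sampling logic lives in the formatter itself.
def _example_masking_budget(num_column_tokens, num_context_tokens, config):
    num_column_to_mask = min(
        round(config.masked_column_prob * num_column_tokens),
        config.max_predictions_per_seq)
    num_context_to_mask = min(
        round(config.masked_context_prob * num_context_tokens),
        config.max_predictions_per_seq - num_column_to_mask)
    return num_column_to_mask, num_context_to_mask

# e.g. 60 column tokens and 100 context tokens under the default config:
#   _example_masking_budget(60, 100, TableBertConfig()) -> (12, 15)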
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of preprocessed data to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    ## YS
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # parser.add_argument('--acoustic_confusion_type', type=str, choices=['random', 'gpt2'], default='random')
    parser.add_argument('--word_confusion_path', type=Path, default='')

    TableBertConfig.add_args(parser)

    args = parser.parse_args()
    args.is_master = args.global_rank == 0

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    ## YS
    if args.use_acoustic_confusion:
        assert str(args.word_confusion_path) != ''
        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(
            args.word_confusion_path, default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(
            table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    total_tables_num = int(subprocess.check_output(f"wc -l {args.train_corpus}", shell=True).split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # Seed the RNG to make sure each process follows the same train/dev split.
    rng = np.random.RandomState(seed=5783287)
    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    # Each worker processes an interleaved slice of the dev and train tables.
    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory',
                                  tokenizer=tokenizer, indices=local_indices) as table_db:
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        args.output_dir.mkdir(exist_ok=True, parents=True)
        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)
            (args.output_dir / 'train').mkdir(exist_ok=True)
            (args.output_dir / 'dev').mkdir(exist_ok=True)

        # Generate dev data first.
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)
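# Example invocation (illustrative; the script name and paths below are placeholders,
# the flags are the ones defined in main() above):
#
#   python generate_training_data.py \
#       --train_corpus data/tables.jsonl \
#       --output_dir data/pregenerated \
#       --epochs_to_generate 3 \
#       --use_acoustic_confusion \
#       --acoustic_confusion_prob 0.15 \
#       --word_confusion_path data/word_confusions.json
#
# Each worker (rank and world size taken from SLURM_PROCID / SLURM_NTASKS by default)
# writes its own shards: dev/epoch_0.shard{rank}.h5 and train/epoch_{epoch}.shard{rank}.h5.

# Assumed entry point: run main() when the script is executed directly.
if __name__ == '__main__':
    main()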