def add_args(cls, parser: ArgumentParser):
    TableBertConfig.add_args(parser)

    parser.add_argument("--num_vertical_attention_heads", type=int, default=6)
    parser.add_argument("--num_vertical_layers", type=int, default=3)
    parser.add_argument("--sample_row_num", type=int, default=3)
    parser.add_argument("--predict_cell_tokens", action='store_true', dest='predict_cell_tokens')
    parser.add_argument("--no_predict_cell_tokens", action='store_false', dest='predict_cell_tokens')
    parser.set_defaults(predict_cell_tokens=False)
    parser.add_argument("--initialize_from", type=Path, default=None)
def get_table_bert_model_deprecated(config: Dict, use_proxy=False, master=None):
    tb_path = config.get('table_bert_model_or_config')
    if tb_path is None or tb_path == '':
        tb_path = config.get('table_bert_config_file')
    if tb_path is None or tb_path == '':
        tb_path = config.get('table_bert_model')

    tb_path = Path(tb_path)
    assert tb_path.exists()

    if tb_path.suffix == '.json':
        tb_config_file = tb_path
        tb_path = None
    else:
        print(f'Loading table BERT model {tb_path}', file=sys.stderr)
        tb_config_file = tb_path.parent / 'tb_config.json'

    if use_proxy:
        from nsm.parser_module.table_bert_proxy import TableBertProxy

        tb_config = TableBertConfig.from_file(tb_config_file)
        table_bert_model = TableBertProxy(actor_id=master, table_bert_config=tb_config)
    else:
        table_bert_extra_config = config.get('table_bert_extra_config', dict())

        # if it is not a pre-trained model, we use the default parameters
        if tb_path is None:
            table_bert_cls = TableBertConfig.infer_model_class_from_config_file(tb_config_file)
            print(f'Creating a default {table_bert_cls.__name__} without pre-trained parameters!', file=sys.stderr)

            table_bert_model = table_bert_cls(
                config=table_bert_cls.CONFIG_CLASS.from_file(
                    tb_config_file, **table_bert_extra_config
                )
            )
        else:
            table_bert_model = TableBertModel.from_pretrained(
                tb_path,
                **table_bert_extra_config
            )

        if type(table_bert_model) == VanillaTableBert:
            table_bert_model.config.column_representation = config.get('column_representation', 'mean_pool_column_name')

        print('Table Bert Config', file=sys.stderr)
        print(json.dumps(vars(table_bert_model.config), indent=2), file=sys.stderr)

    return table_bert_model
def __init__(
    self,
    num_vertical_attention_heads=6,
    num_vertical_layers=3,
    sample_row_num=3,
    table_mask_strategy='column',
    predict_cell_tokens=False,
    # vertical_layer_use_intermediate_transform=True,
    initialize_from=None,
    **kwargs,
):
    TableBertConfig.__init__(self, **kwargs)

    self.num_vertical_attention_heads = num_vertical_attention_heads
    self.num_vertical_layers = num_vertical_layers
    self.sample_row_num = sample_row_num
    self.table_mask_strategy = table_mask_strategy
    self.predict_cell_tokens = predict_cell_tokens
    # self.vertical_layer_use_intermediate_transform = vertical_layer_use_intermediate_transform
    self.initialize_from = initialize_from
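# Usage sketch (illustrative, not part of the library): wiring the config's CLI
# flags into an ArgumentParser and building a config from the parsed namespace.
# Assumes the imports above, and that `from_dict` is inherited from
# TableBertConfig (it is used that way in the data generator below).
#
#     parser = ArgumentParser()
#     VerticalAttentionTableBertConfig.add_args(parser)
#     args = parser.parse_args(['--num_vertical_layers', '2', '--predict_cell_tokens'])
#     config = VerticalAttentionTableBertConfig.from_dict(vars(args))
#     assert config.num_vertical_layers == 2 and config.predict_cell_tokens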
def load(cls, model_path: Union[str, Path], config_file: Optional[Union[str, Path]] = None, **override_config: Dict):
    if model_path in ('bert-base-uncased', 'bert-large-uncased'):
        from table_bert.vanilla_table_bert import VanillaTableBert, TableBertConfig

        config = TableBertConfig(**override_config)
        model = VanillaTableBert(config)
        return model

    if model_path and isinstance(model_path, str):
        model_path = Path(model_path)

    if config_file is None:
        config_file = model_path.parent / 'tb_config.json'
    elif isinstance(config_file, str):
        config_file = Path(config_file)

    if model_path:
        state_dict = torch.load(str(model_path), map_location='cpu')
    else:
        state_dict = None

    with open(config_file) as f:
        config_dict = json.load(f)

    if cls == TableBertModel:
        if 'num_vertical_attention_heads' in config_dict:
            from table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert, VerticalAttentionTableBertConfig

            table_bert_cls = VerticalAttentionTableBert
            config_cls = VerticalAttentionTableBertConfig
        else:
            from table_bert.vanilla_table_bert import VanillaTableBert
            from table_bert.config import TableBertConfig

            table_bert_cls = VanillaTableBert
            config_cls = TableBertConfig
    else:
        table_bert_cls = cls
        config_cls = table_bert_cls.CONFIG_CLASS

    config = config_cls.from_file(config_file, **override_config)
    model = table_bert_cls(config)

    # old table_bert format
    if state_dict is not None:
        # fix the name of the weight `cls.predictions.decoder.bias`
        # to make it compatible with the latest version of `transformers`
        from table_bert.utils import hf_flag
        if hf_flag == 'new':
            old_key_to_new_key_names: List[Tuple[str, str]] = []
            for key in state_dict:
                if key.endswith('.predictions.bias'):
                    old_key_to_new_key_names.append(
                        (key, key.replace('.predictions.bias', '.predictions.decoder.bias'))
                    )

            for old_key, new_key in old_key_to_new_key_names:
                state_dict[new_key] = state_dict[old_key]

        if not any(key.startswith('_bert_model') for key in state_dict):
            print('warning: loading model from an old version', file=sys.stderr)
            bert_model = BertForMaskedLM.from_pretrained(
                config.base_model_name,
                state_dict=state_dict
            )
            model._bert_model = bert_model
        else:
            model.load_state_dict(state_dict, strict=True)

    return model
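# Usage sketch (illustrative; paths are placeholders): `load` resolves the
# concrete model class from the checkpoint's sibling `tb_config.json` unless a
# config file is given explicitly.
#
#     model = TableBertModel.load('runs/vanilla_tabert/model.bin')
#     model = TableBertModel.load('runs/vertical_tabert/model.bin',
#                                 config_file='runs/vertical_tabert/tb_config.json')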
def from_pretrained(cls,
                    model_name_or_path: Optional[Union[str, Path]] = None,
                    config_file: Optional[Union[str, Path]] = None,
                    config: Optional[TableBertConfig] = None,
                    state_dict: Optional[Dict] = None,
                    **kwargs) -> 'TableBertModel':
    # Avoid cyclic import.
    # TODO: a better way to import these dependencies?
    from table_bert.vertical.vertical_attention_table_bert import (
        VerticalAttentionTableBert, VerticalAttentionTableBertConfig)
    from table_bert.vanilla_table_bert import VanillaTableBert

    if model_name_or_path in {'bert-base-uncased', 'bert-large-uncased'}:
        config = TableBertConfig(base_model_name=model_name_or_path)
        overriding_config = config.extract_args(kwargs, pop=True)
        if len(overriding_config) > 0:
            config = config.with_new_args(**overriding_config)

        model = VanillaTableBert(config)
        return model

    if not isinstance(config, TableBertConfig):
        if config_file:
            config_file = Path(config_file)
        else:
            assert model_name_or_path, 'model path is None'
            config_file = Path(model_name_or_path).parent / 'tb_config.json'

        assert config_file.exists(), f'Unable to find TaBERT config file at {config_file}'

        # Identify from the JSON config file whether the model uses vertical self-attention (TaBERT(K > 1))
        if cls == TableBertModel and VerticalAttentionTableBertConfig.is_valid_config_file(config_file):
            config_cls = VerticalAttentionTableBertConfig
        else:
            config_cls = TableBertConfig

        config = config_cls.from_file(config_file)

    overriding_config = config.extract_args(kwargs, pop=True)
    if len(overriding_config) > 0:
        config = config.with_new_args(**overriding_config)

    model_kwargs = kwargs
    model_cls = (
        cls
        # If the current class is not the base generic class, then we assume the user wants to
        # load a pre-trained instance of that specific model class. Otherwise, we infer the
        # model class from its config class.
        if cls != TableBertModel
        else {
            TableBertConfig.__name__: VanillaTableBert,
            VerticalAttentionTableBertConfig.__name__: VerticalAttentionTableBert
        }[config.__class__.__name__]
    )

    model = model_cls(config, **model_kwargs)

    if state_dict is None:
        state_dict = torch.load(model_name_or_path, map_location="cpu")

    # fix the name of the weight `cls.predictions.decoder.bias`
    # to make it compatible with the latest version of HuggingFace `transformers`
    if TRANSFORMER_VERSION == TransformerVersion.TRANSFORMERS:
        old_key_to_new_key_names: List[Tuple[str, str]] = []
        for key in state_dict:
            if key.endswith('.predictions.bias'):
                old_key_to_new_key_names.append(
                    (key, key.replace('.predictions.bias', '.predictions.decoder.bias'))
                )

        for old_key, new_key in old_key_to_new_key_names:
            state_dict[new_key] = state_dict[old_key]

        # Problem: missing key(s) in state_dict: "span_based_prediction.prediction.decoder.bias".
        # Inspecting the code shows that `prediction.bias` and `prediction.decoder.bias`
        # share the same values, so copy one to the other.
        state_dict['span_based_prediction.prediction.decoder.bias'] = \
            state_dict['span_based_prediction.prediction.bias']

        # Problem: missing key(s) in state_dict: "_bert_model.bert.embeddings.position_ids".
        # Older versions did not save this buffer, and loading without it raises a
        # missing-key error, so recreate it here the same way the model initializes it
        # internally.
        state_dict["_bert_model.bert.embeddings.position_ids"] = torch.arange(
            config.max_position_embeddings).expand(1, -1)

    model.load_state_dict(state_dict, strict=True)

    return model
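# Usage sketch (illustrative; paths are placeholders): `from_pretrained` follows
# the same config-resolution logic, and keyword arguments that match config
# fields (extracted via `config.extract_args`) override the stored config before
# the model is instantiated.
#
#     model = TableBertModel.from_pretrained('runs/vertical_tabert/model.bin')
#     model = TableBertModel.from_pretrained(
#         'runs/vertical_tabert/model.bin',
#         sample_row_num=1,  # config override; assumes a vertical-attention config
#     )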
        info.update({
            'num_column_tokens_to_mask': num_column_tokens_to_mask,
            'num_context_tokens_to_mask': num_context_tokens_to_mask,
        })

        return tokens, masked_indices, masked_token_labels, info

    def remove_unecessary_instance_entries(self, instance: Dict):
        del instance['tokens']
        del instance['masked_lm_labels']
        del instance['info']


if __name__ == '__main__':
    config = TableBertConfig()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    input_formatter = VanillaTableBertInputFormatter(config, tokenizer)

    header = []
    for i in range(1000):
        header.append(
            Column(
                name='test',
                type='text',
                name_tokens=['test'] * 3,
                sample_value='ha ha ha yay',
                sample_value_tokens=['ha', 'ha', 'ha', 'yay']
            )
        )
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of data to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    ## YS
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # parser.add_argument('--acoustic_confusion_type', type=str, choices=['random', 'gpt2'], default='random')
    parser.add_argument('--word_confusion_path', type=Path, default='')

    TableBertConfig.add_args(parser)

    args = parser.parse_args()
    args.is_master = args.global_rank == 0

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    ## YS
    if args.use_acoustic_confusion:
        assert args.word_confusion_path != ''
        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(args.word_confusion_path,
                                                                   default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    total_tables_num = int(subprocess.check_output(f"wc -l {args.train_corpus}", shell=True).split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # seed the RNG to make sure each process follows the same splitting
    rng = np.random.RandomState(seed=5783287)

    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory',
                                  tokenizer=tokenizer, indices=local_indices) as table_db:
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        args.output_dir.mkdir(exist_ok=True, parents=True)
        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        (args.output_dir / 'train').mkdir(exist_ok=True)
        (args.output_dir / 'dev').mkdir(exist_ok=True)

        # generate dev data first
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)
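# Example invocation (illustrative; the script name and paths are placeholders).
# Under SLURM, --global_rank and --world_size default to SLURM_PROCID and
# SLURM_NTASKS and can be omitted:
#
#     python generate_training_data.py \
#         --train_corpus data/tables.jsonl \
#         --output_dir data/pretrain \
#         --epochs_to_generate 3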