Example #1
    def add_args(cls, parser: ArgumentParser):
        TableBertConfig.add_args(parser)

        parser.add_argument("--num_vertical_attention_heads",
                            type=int,
                            default=6)
        parser.add_argument("--num_vertical_layers", type=int, default=3)
        parser.add_argument("--sample_row_num", type=int, default=3)
        parser.add_argument("--predict_cell_tokens",
                            action='store_true',
                            dest='predict_cell_tokens')
        parser.add_argument("--no_predict_cell_tokens",
                            action='store_false',
                            dest='predict_cell_tokens')
        parser.set_defaults(predict_cell_tokens=False)

        parser.add_argument("--initialize_from", type=Path, default=None)
Example #2
def get_table_bert_model_deprecated(config: Dict, use_proxy=False, master=None):
    tb_path = config.get('table_bert_model_or_config')
    if tb_path is None or tb_path == '':
        tb_path = config.get('table_bert_config_file')
    if tb_path is None or tb_path == '':
        tb_path = config.get('table_bert_model')

    tb_path = Path(tb_path)
    assert tb_path.exists()

    if tb_path.suffix == '.json':
        tb_config_file = tb_path
        tb_path = None
    else:
        print(f'Loading table BERT model {tb_path}', file=sys.stderr)
        tb_config_file = tb_path.parent / 'tb_config.json'

    if use_proxy:
        from nsm.parser_module.table_bert_proxy import TableBertProxy
        tb_config = TableBertConfig.from_file(tb_config_file)
        table_bert_model = TableBertProxy(actor_id=master, table_bert_config=tb_config)
    else:
        table_bert_extra_config = config.get('table_bert_extra_config', dict())
        # if it is not a pre-trained model, we use the default parameters
        if tb_path is None:
            table_bert_cls = TableBertConfig.infer_model_class_from_config_file(tb_config_file)
            print(f'Creating a default {table_bert_cls.__name__} without pre-trained parameters!', file=sys.stderr)

            table_bert_model = table_bert_cls(
                config=table_bert_cls.CONFIG_CLASS.from_file(
                    tb_config_file, **table_bert_extra_config
                )
            )
        else:
            table_bert_model = TableBertModel.from_pretrained(
                tb_path,
                **table_bert_extra_config
            )

        if type(table_bert_model) == VanillaTableBert:
            table_bert_model.config.column_representation = config.get('column_representation', 'mean_pool_column_name')

        print('Table Bert Config', file=sys.stderr)
        print(json.dumps(vars(table_bert_model.config), indent=2), file=sys.stderr)

    return table_bert_model
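A hedged usage sketch for the deprecated loader above; the checkpoint path is hypothetical:

config = {
    # Pointing at a '.json' file instead would create an untrained model
    # from the config alone, as handled above.
    'table_bert_model_or_config': 'runs/vanilla_tabert/pytorch_model.bin',
    'column_representation': 'mean_pool_column_name',
}
table_bert = get_table_bert_model_deprecated(config)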
Example #3
    def __init__(
        self,
        num_vertical_attention_heads=6,
        num_vertical_layers=3,
        sample_row_num=3,
        table_mask_strategy='column',
        predict_cell_tokens=False,
        # vertical_layer_use_intermediate_transform=True,
        initialize_from=None,
        **kwargs,
    ):
        TableBertConfig.__init__(self, **kwargs)

        self.num_vertical_attention_heads = num_vertical_attention_heads
        self.num_vertical_layers = num_vertical_layers
        self.sample_row_num = sample_row_num
        self.table_mask_strategy = table_mask_strategy
        self.predict_cell_tokens = predict_cell_tokens
        # self.vertical_layer_use_intermediate_transform = vertical_layer_use_intermediate_transform
        self.initialize_from = initialize_from
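Assuming the constructor above belongs to `VerticalAttentionTableBertConfig`, a minimal instantiation sketch; `base_model_name` is forwarded to `TableBertConfig.__init__` through `**kwargs`, as in Example #5:

config = VerticalAttentionTableBertConfig(
    num_vertical_layers=2,
    sample_row_num=5,
    predict_cell_tokens=True,
    base_model_name='bert-base-uncased',  # consumed by TableBertConfig.__init__
)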
Example #4
    def load(cls,
             model_path: Union[str, Path],
             config_file: Optional[Union[str, Path]] = None,
             **override_config: Dict):
        if model_path in ('bert-base-uncased', 'bert-large-uncased'):
            from table_bert.vanilla_table_bert import VanillaTableBert, TableBertConfig
            config = TableBertConfig(**override_config)
            model = VanillaTableBert(config)

            return model

        if model_path and isinstance(model_path, str):
            model_path = Path(model_path)

        if config_file is None:
            config_file = model_path.parent / 'tb_config.json'
        elif isinstance(config_file, str):
            config_file = Path(config_file)

        if model_path:
            state_dict = torch.load(str(model_path), map_location='cpu')
        else:
            state_dict = None

        config_dict = json.load(open(config_file))

        if cls == TableBertModel:
            if 'num_vertical_attention_heads' in config_dict:
                from table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert, VerticalAttentionTableBertConfig
                table_bert_cls = VerticalAttentionTableBert
                config_cls = VerticalAttentionTableBertConfig
            else:
                from table_bert.vanilla_table_bert import VanillaTableBert
                from table_bert.config import TableBertConfig
                table_bert_cls = VanillaTableBert
                config_cls = TableBertConfig
        else:
            table_bert_cls = cls
            config_cls = table_bert_cls.CONFIG_CLASS

        config = config_cls.from_file(config_file, **override_config)
        model = table_bert_cls(config)

        # old table_bert format
        if state_dict is not None:
            # fix the name for weight `cls.predictions.decoder.bias`,
            # to make it compatible with the latest version of `transformers`

            from table_bert.utils import hf_flag
            if hf_flag == 'new':
                old_key_to_new_key_names: List[Tuple[str, str]] = []
                for key in state_dict:
                    if key.endswith('.predictions.bias'):
                        old_key_to_new_key_names.append(
                            (key,
                             key.replace('.predictions.bias',
                                         '.predictions.decoder.bias')))

                for old_key, new_key in old_key_to_new_key_names:
                    state_dict[new_key] = state_dict[old_key]

            if not any(key.startswith('_bert_model') for key in state_dict):
                print('warning: loading model from an old version',
                      file=sys.stderr)
                bert_model = BertForMaskedLM.from_pretrained(
                    config.base_model_name, state_dict=state_dict)
                model._bert_model = bert_model
            else:
                model.load_state_dict(state_dict, strict=True)

        return model
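A hedged sketch of calling the `load` classmethod above, assuming it is defined on `TableBertModel`; the checkpoint path is hypothetical, and 'tb_config.json' is looked up next to it when `config_file` is omitted:

# Load a saved checkpoint.
model = TableBertModel.load('runs/vanilla_tabert/pytorch_model.bin')

# Or build an untrained BERT-backed model directly from a base model name.
model = TableBertModel.load('bert-base-uncased')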
Example #5
    def from_pretrained(cls,
                        model_name_or_path: Optional[Union[str, Path]] = None,
                        config_file: Optional[Union[str, Path]] = None,
                        config: Optional[TableBertConfig] = None,
                        state_dict: Optional[Dict] = None,
                        **kwargs) -> 'TableBertModel':
        # Avoid cyclic import.
        # TODO: a better way to import these dependencies?
        from table_bert.vertical.vertical_attention_table_bert import (
            VerticalAttentionTableBert, VerticalAttentionTableBertConfig)
        from table_bert.vanilla_table_bert import VanillaTableBert

        if model_name_or_path in {'bert-base-uncased', 'bert-large-uncased'}:
            config = TableBertConfig(base_model_name=model_name_or_path)
            overriding_config = config.extract_args(kwargs, pop=True)
            if len(overriding_config) > 0:
                config = config.with_new_args(**overriding_config)

            model = VanillaTableBert(config)

            return model

        if not isinstance(config, TableBertConfig):
            if config_file:
                config_file = Path(config_file)
            else:
                assert model_name_or_path, f'model path is None'
                config_file = Path(
                    model_name_or_path).parent / 'tb_config.json'

            assert config_file.exists(
            ), f'Unable to find TaBERT config file at {config_file}'

            # Identify from the json config file whether the model uses vertical self-attention (TaBERT(K>1))
            if cls == TableBertModel and VerticalAttentionTableBertConfig.is_valid_config_file(
                    config_file):
                config_cls = VerticalAttentionTableBertConfig
            else:
                config_cls = TableBertConfig

            config = config_cls.from_file(config_file)

        overriding_config = config.extract_args(kwargs, pop=True)
        if len(overriding_config) > 0:
            config = config.with_new_args(**overriding_config)

        model_kwargs = kwargs

        model_cls = (
            cls  # If the current class is not the base generic class, then we assume the user wants to
            # load a pre-trained instance of that specific model class. Otherwise, we infer the model
            # class from its config class
            if cls != TableBertModel else {
                TableBertConfig.__name__:
                VanillaTableBert,
                VerticalAttentionTableBertConfig.__name__:
                VerticalAttentionTableBert
            }[config.__class__.__name__])

        model = model_cls(config, **model_kwargs)

        if state_dict is None:
            state_dict = torch.load(model_name_or_path, map_location="cpu")

        # fix the name for weight `cls.predictions.decoder.bias`,
        # to make it compatible with the latest version of HuggingFace `transformers`
        if TRANSFORMER_VERSION == TransformerVersion.TRANSFORMERS:
            old_key_to_new_key_names: List[Tuple[str, str]] = []
            for key in state_dict:
                if key.endswith('.predictions.bias'):
                    old_key_to_new_key_names.append(
                        (key,
                         key.replace('.predictions.bias',
                                     '.predictions.decoder.bias')))

            for old_key, new_key in old_key_to_new_key_names:
                state_dict[new_key] = state_dict[old_key]

            # Problem: Missing key(s) in state_dict: "span_based_prediction.prediction.decoder.bias"
            # Inspecting the code shows that prediction.bias and prediction.decoder.bias share the same value.
            state_dict[
                'span_based_prediction.prediction.decoder.bias'] = state_dict[
                    'span_based_prediction.prediction.bias']

            # Problem: Missing key(s) in state_dict: "_bert_model.bert.embeddings.position_ids"
            # This buffer was not saved by older versions of the code; leaving it out (or setting it to None)
            # raises a missing-key error, so recreate it here the same way the model initializes it internally.
            state_dict[
                "_bert_model.bert.embeddings.position_ids"] = torch.arange(
                    config.max_position_embeddings).expand(1, -1)

        model.load_state_dict(state_dict, strict=True)
        #model.load_state_dict(state_dict, strict=False)

        return model
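A hedged usage sketch for the `from_pretrained` classmethod above; the paths are hypothetical, and when called on the base `TableBertModel` the concrete model class is inferred from the JSON config:

model = TableBertModel.from_pretrained(
    'runs/vertical_tabert/pytorch_model.bin',
    config_file='runs/vertical_tabert/tb_config.json',
)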
Example #6
        info.update({
            'num_column_tokens_to_mask': num_column_tokens_to_mask,
            'num_context_tokens_to_mask': num_context_tokens_to_mask,
        })

        return tokens, masked_indices, masked_token_labels, info

    def remove_unecessary_instance_entries(self, instance: Dict):
        del instance['tokens']
        del instance['masked_lm_labels']
        del instance['info']


if __name__ == '__main__':
    config = TableBertConfig()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    input_formatter = VanillaTableBertInputFormatter(config, tokenizer)

    header = []
    for i in range(1000):
        header.append(
            Column(
                name='test',
                type='text',
                name_tokens=['test'] * 3,
                sample_value='ha ha ha yay',
                sample_value_tokens=['ha', 'ha', 'ha', 'yay']
            )
        )
Example #7
def main():
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of preprocess to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    ## YS
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # parser.add_argument('--acoustic_confusion_type', type=str, choices=['random', 'gpt2'], default='random')
    parser.add_argument('--word_confusion_path', type=Path, default='')
    
    TableBertConfig.add_args(parser)

    args = parser.parse_args()
    args.is_master = args.global_rank == 0

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    ## YS
    if args.use_acoustic_confusion:
        assert args.word_confusion_path != ''

        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(args.word_confusion_path, default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    total_tables_num = int(subprocess.check_output(f"wc -l {args.train_corpus}", shell=True).split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # seed the RNG to make sure each process follows the same splitting
    rng = np.random.RandomState(seed=5783287)

    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory', tokenizer=tokenizer, indices=local_indices) as table_db:
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        args.output_dir.mkdir(exist_ok=True, parents=True)
        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        (args.output_dir / 'train').mkdir(exist_ok=True)
        (args.output_dir / 'dev').mkdir(exist_ok=True)

        # generate dev data first
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)
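The excerpt above ends inside main(); when run as a script, a standard entry-point guard (not shown in the excerpt) would invoke it. A minimal sketch with hypothetical arguments:

import sys

if __name__ == '__main__':
    # Hypothetical local-test invocation; real runs would pass these flags on
    # the command line (--train_corpus and --output_dir are required).
    sys.argv += [
        '--train_corpus', 'data/train_tables.jsonl',
        '--output_dir', 'data/pregen',
        '--epochs_to_generate', '1',
    ]
    main()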