def __init__(self, config: VerticalAttentionTableBertConfig, **kwargs):
        # Call VanillaTableBert's parent initializer directly, skipping
        # VanillaTableBert.__init__ (which builds the vanilla BERT model and
        # input formatter that this constructor replaces below).
        super(VanillaTableBert, self).__init__(config, **kwargs)

        self._bert_model = BertForMaskedLM.from_pretrained(
            config.base_model_name)

        self.input_formatter = VerticalAttentionTableBertInputFormatter(
            self.config, self.tokenizer)

        if config.predict_cell_tokens:
            self.span_based_prediction = SpanBasedPrediction(
                config, self._bert_model.cls.predictions)

        # Modules added on top of the base BERT encoder for vertical self-attention.
        self.vertical_embedding_layer = VerticalEmbeddingLayer()
        self.vertical_transformer_layers = nn.ModuleList([
            BertVerticalLayer(self.config)
            for _ in range(self.config.num_vertical_layers)
        ])

        if config.initialize_from:
            print(f'Loading initial parameters from {config.initialize_from}',
                  file=sys.stderr)
            initial_state_dict = torch.load(config.initialize_from,
                                            map_location='cpu')
            # Old-format checkpoints store raw BERT weights without the
            # `_bert_model.` prefix; load them into the BERT sub-module directly.
            if not any(
                    key.startswith('_bert_model')
                    for key in initial_state_dict):
                print('warning: loading model from an old version',
                      file=sys.stderr)
                bert_model = BertForMaskedLM.from_pretrained(
                    config.base_model_name, state_dict=initial_state_dict)
                self._bert_model = bert_model
            else:
                load_result = self.load_state_dict(initial_state_dict,
                                                   strict=False)
                if load_result.missing_keys:
                    print(f'warning: missing keys: {load_result.missing_keys}',
                          file=sys.stderr)
                if load_result.unexpected_keys:
                    print(
                        f'warning: unexpected keys: {load_result.unexpected_keys}',
                        file=sys.stderr)

        added_modules = [
            self.vertical_embedding_layer, self.vertical_transformer_layers
        ]
        if config.predict_cell_tokens:
            added_modules.extend([
                self.span_based_prediction.dense1,
                self.span_based_prediction.dense2,
                self.span_based_prediction.layer_norm1,
                self.span_based_prediction.layer_norm2
            ])

        # Initialize the newly added modules with BERT's weight-initialization
        # scheme; the method name depends on the installed library version.
        for module in added_modules:
            if TRANSFORMER_VERSION == TransformerVersion.TRANSFORMERS:
                module.apply(self._bert_model._init_weights)
            else:
                module.apply(self._bert_model.init_bert_weights)
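
A hedged usage sketch of the constructor above, assuming it belongs to VerticalAttentionTableBert (Example 3 imports that class from table_bert.vertical.vertical_attention_table_bert); the config keyword arguments are illustrative assumptions about VerticalAttentionTableBertConfig, not values taken from the snippet.

from table_bert.vertical.vertical_attention_table_bert import (
    VerticalAttentionTableBert,
    VerticalAttentionTableBertConfig,
)

# Hypothetical configuration; the field names mirror the attributes read in
# __init__ above (base_model_name, num_vertical_layers, predict_cell_tokens,
# initialize_from), but the values are only examples.
config = VerticalAttentionTableBertConfig(
    base_model_name='bert-base-uncased',
    num_vertical_layers=3,
    predict_cell_tokens=False,
    initialize_from=None,
)
model = VerticalAttentionTableBert(config)
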
Example 2
    def __init__(self, config: TableBertConfig, **kwargs):
        super(VanillaTableBert, self).__init__(config, **kwargs)

        self._bert_model = BertForMaskedLM.from_pretrained(
            config.base_model_name)
        self.input_formatter = VanillaTableBertInputFormatter(
            self.config, self.tokenizer)
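
For comparison, a minimal sketch of constructing the vanilla variant above; TableBertConfig's keyword argument is an assumption, and the checkpoint name is only an example.

from table_bert.vanilla_table_bert import VanillaTableBert
from table_bert.config import TableBertConfig

# Build the vanilla model: __init__ above loads the underlying BertForMaskedLM
# and creates a VanillaTableBertInputFormatter.
config = TableBertConfig(base_model_name='bert-base-uncased')  # assumed keyword argument
model = VanillaTableBert(config)
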
Example 3
    @classmethod
    def load(cls,
             model_path: Union[str, Path],
             config_file: Optional[Union[str, Path]] = None,
             **override_config: Dict):
        if model_path in ('bert-base-uncased', 'bert-large-uncased'):
            from table_bert.vanilla_table_bert import VanillaTableBert, TableBertConfig
            # A bare BERT name means there is no TaBERT checkpoint to load;
            # point the config at the requested pretrained model and wrap it.
            config = TableBertConfig(base_model_name=model_path, **override_config)
            model = VanillaTableBert(config)

            return model

        if model_path and isinstance(model_path, str):
            model_path = Path(model_path)

        if config_file is None:
            config_file = model_path.parent / 'tb_config.json'
        elif isinstance(config_file, str):
            config_file = Path(config_file)

        if model_path:
            state_dict = torch.load(str(model_path), map_location='cpu')
        else:
            state_dict = None

        with open(config_file) as f:
            config_dict = json.load(f)

        if cls == TableBertModel:
            if 'num_vertical_attention_heads' in config_dict:
                from table_bert.vertical.vertical_attention_table_bert import VerticalAttentionTableBert, VerticalAttentionTableBertConfig
                table_bert_cls = VerticalAttentionTableBert
                config_cls = VerticalAttentionTableBertConfig
            else:
                from table_bert.vanilla_table_bert import VanillaTableBert
                from table_bert.config import TableBertConfig
                table_bert_cls = VanillaTableBert
                config_cls = TableBertConfig
        else:
            table_bert_cls = cls
            config_cls = table_bert_cls.CONFIG_CLASS

        config = config_cls.from_file(config_file, **override_config)
        model = table_bert_cls(config)

        # Load weights from the checkpoint, handling both old- and new-format
        # table_bert state dicts.
        if state_dict is not None:
            # Rename `cls.predictions.bias` keys to `cls.predictions.decoder.bias`
            # so the state dict stays compatible with newer versions of `transformers`.

            from table_bert.utils import hf_flag
            if hf_flag == 'new':
                old_key_to_new_key_names: List[Tuple[str, str]] = []
                for key in state_dict:
                    if key.endswith('.predictions.bias'):
                        old_key_to_new_key_names.append(
                            (key,
                             key.replace('.predictions.bias',
                                         '.predictions.decoder.bias')))

                for old_key, new_key in old_key_to_new_key_names:
                    state_dict[new_key] = state_dict[old_key]

            if not any(key.startswith('_bert_model') for key in state_dict):
                print('warning: loading model from an old version',
                      file=sys.stderr)
                bert_model = BertForMaskedLM.from_pretrained(
                    config.base_model_name, state_dict=state_dict)
                model._bert_model = bert_model
            else:
                model.load_state_dict(state_dict, strict=True)

        return model
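
A hedged usage sketch of the load method above; the checkpoint path is a placeholder, and the top-level import is an assumption (the snippet only shows that dispatch happens on the TableBertModel class).

from table_bert import TableBertModel  # assumed public import path

# Reads tb_config.json from the checkpoint's directory by default and dispatches
# to VerticalAttentionTableBert or VanillaTableBert depending on whether the
# config contains `num_vertical_attention_heads`.
model = TableBertModel.load('models/tabert/model.bin')  # placeholder path

# A bare BERT model name skips checkpoint loading and returns a VanillaTableBert
# wrapping the pretrained BERT weights.
bert_only = TableBertModel.load('bert-base-uncased')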