Exemple #1
0
    def __init__(self,
                 base_model_name: str = 'bert-base-uncased',
                 column_delimiter: str = '[SEP]',
                 context_first: bool = True,
                 cell_input_template: str = 'column | type | value',
                 column_representation: str = 'mean_pool',
                 max_cell_len: int = 5,
                 max_sequence_len: int = 512,
                 max_context_len: int = 256,
                 masked_context_prob: float = 0.15,
                 masked_column_prob: float = 0.2,
                 max_predictions_per_seq: int = 100,
                 context_sample_strategy: str = 'nearest',
                 table_mask_strategy: str = 'column',
                 do_lower_case: bool = True,
                 **kwargs):
        """Store the configuration values on the instance.

        ``cell_input_template`` may be given as a string or as an already
        split sequence. A string containing spaces is split on single
        spaces; a string without spaces is treated as an outdated format,
        reported on stderr, and tokenized with the base model's tokenizer.
        Any other value is stored unchanged.

        Extra keyword arguments are accepted and ignored.

        Unless the instance already has a ``vocab_size_or_config_json_file``
        attribute after the parent initializer ran, every attribute of the
        matching entry in ``BERT_CONFIGS`` is copied onto this instance last,
        so those values win over anything assigned earlier in this method.
        """
        super(TableBertConfig, self).__init__()

        # Base model and table linearization options.
        self.base_model_name = base_model_name
        self.column_delimiter = column_delimiter
        self.context_first = context_first
        self.column_representation = column_representation

        # Length limits.
        self.max_cell_len = max_cell_len
        self.max_sequence_len = max_sequence_len
        self.max_context_len = max_context_len

        self.do_lower_case = do_lower_case

        # Normalize the cell template into a token sequence.
        if not isinstance(cell_input_template, str):
            template_tokens = cell_input_template
        elif ' ' in cell_input_template:
            template_tokens = cell_input_template.split(' ')
        else:
            print(
                f'WARNING: cell_input_template is outdated: {cell_input_template}',
                file=sys.stderr)
            legacy_tokenizer = BertTokenizer.from_pretrained(self.base_model_name)
            template_tokens = legacy_tokenizer.tokenize(cell_input_template)
        self.cell_input_template = template_tokens

        # Masking and sampling options.
        self.masked_context_prob = masked_context_prob
        self.masked_column_prob = masked_column_prob
        self.max_predictions_per_seq = max_predictions_per_seq
        self.context_sample_strategy = context_sample_strategy
        self.table_mask_strategy = table_mask_strategy

        if hasattr(self, 'vocab_size_or_config_json_file'):
            return

        # Inherit the underlying BERT hyper-parameters for this base model.
        base_bert_config = BERT_CONFIGS[self.base_model_name]
        for attr_name, attr_value in vars(base_bert_config).items():
            setattr(self, attr_name, attr_value)
Exemple #2
0
    def __init__(self, config: TableBertConfig, **kwargs):
        """Set up the module with its config, tokenizer and optional BERT model.

        Args:
            config: Configuration object; its ``base_model_name`` selects the
                tokenizer to load.
            **kwargs: Only ``bert_model`` is read (and removed). Passing it is
                deprecated and triggers a warning, but the value is still
                stored on ``self._bert_model``.
        """
        nn.Module.__init__(self)

        legacy_bert_model: Union[BertForPreTraining, BertModel] = kwargs.pop(
            'bert_model', None)
        if legacy_bert_model is not None:
            logging.warning(
                'using `bert_model` to initialize `TableBertModel` is deprecated. '
                'I will still set `self._bert_model` this time.')

        # Stays None when no model was handed in.
        self._bert_model = legacy_bert_model
        self.tokenizer = BertTokenizer.from_pretrained(config.base_model_name)
        self.config = config
Exemple #3
0
        info.update({
            'num_column_tokens_to_mask': num_column_tokens_to_mask,
            'num_context_tokens_to_mask': num_context_tokens_to_mask,
        })

        return tokens, masked_indices, masked_token_labels, info

    def remove_unecessary_instance_entries(self, instance: Dict):
        """Delete the 'tokens', 'masked_lm_labels' and 'info' keys in place.

        Keys are removed in that order; a missing key raises ``KeyError``
        and leaves the keys removed before it deleted.
        """
        for entry_name in ('tokens', 'masked_lm_labels', 'info'):
            del instance[entry_name]


if __name__ == '__main__':
    config = TableBertConfig()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    input_formatter = VanillaTableBertInputFormatter(config, tokenizer)

    header = []
    for i in range(1000):
        header.append(
            Column(
                name='test',
                type='text',
                name_tokens=['test'] * 3,
                sample_value='ha ha ha yay',
                sample_value_tokens=['ha', 'ha', 'ha', 'yay']
            )
        )

    print(
Exemple #4
0
def main():
    """Pre-generate sharded dev and train data files from a table corpus.

    Steps:
      1. Parse command-line arguments (including those registered by
         ``TableBertConfig.add_args``) and build the input formatter,
         optionally wrapped with an acoustic confuser.
      2. Count the corpus lines, shuffle the line indices with a fixed seed
         and split them into dev (10% of the corpus, capped at 100000) and
         train.
      3. Keep only every ``world_size``-th index starting at ``global_rank``
         so each worker handles its own shard.
      4. Write ``config.json`` (master rank only), then one dev file and
         ``epochs_to_generate`` train files under ``output_dir``.

    Exits through ``parser.error`` when ``--use_acoustic_confusion`` is set
    without a ``--word_confusion_path``.
    """
    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of preprocess to pregenerate")
    parser.add_argument('--no_wiki_tables_from_common_crawl', action='store_true', default=False)
    # String defaults taken from the environment are converted by `type=int`.
    parser.add_argument('--global_rank', type=int, default=os.environ.get('SLURM_PROCID', 0))
    parser.add_argument('--world_size', type=int, default=os.environ.get('SLURM_NTASKS', 1))

    # Acoustic confusion options.
    parser.add_argument('--use_acoustic_confusion', action='store_true', default=False)
    parser.add_argument('--acoustic_confusion_prob', type=float, default=0.15,
                        help="Probability of replacing a token with a confused one")
    # parser.add_argument('--acoustic_confusion_type', type=str, choices=['random', 'gpt2'], default='random')
    parser.add_argument('--word_confusion_path', type=Path, default='')

    TableBertConfig.add_args(parser)

    args = parser.parse_args()
    args.is_master = args.global_rank == 0

    # argparse passes string defaults through `type`, so an omitted
    # --word_confusion_path arrives as Path('') and never compares equal to
    # ''. Compare against Path('') and fail with a usage error instead of an
    # assert. Note Path('') == Path('.'), so an explicit '.' is rejected too.
    if args.use_acoustic_confusion and args.word_confusion_path == Path(''):
        parser.error('--word_confusion_path is required when --use_acoustic_confusion is set')

    logger = logging.getLogger('DataGenerator')
    handler = logging.StreamHandler(sys.stderr)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    logger.info(f'Rank {args.global_rank} out of {args.world_size}')
    sys.stderr.flush()

    table_bert_config = TableBertConfig.from_dict(vars(args))
    tokenizer = BertTokenizer.from_pretrained(table_bert_config.base_model_name)

    if args.use_acoustic_confusion:
        acoustic_confuser = SentenceAcousticConfuser_RandomReplace(
            args.word_confusion_path, default_p=args.acoustic_confusion_prob)
        input_formatter = VanillaTableBertInputFormatterWithConfusion(
            table_bert_config, tokenizer, acoustic_confuser)
    else:
        input_formatter = VanillaTableBertInputFormatter(table_bert_config, tokenizer)

    # Pass the path as a separate argument (no shell) so that spaces or shell
    # metacharacters in the file name cannot alter the command.
    wc_output = subprocess.check_output(['wc', '-l', str(args.train_corpus)])
    total_tables_num = int(wc_output.split()[0])
    dev_table_num = min(int(total_tables_num * 0.1), 100000)
    train_table_num = total_tables_num - dev_table_num

    # seed the RNG to make sure each process follows the same splitting
    rng = np.random.RandomState(seed=5783287)

    corpus_table_indices = list(range(total_tables_num))
    rng.shuffle(corpus_table_indices)
    dev_table_indices = corpus_table_indices[:dev_table_num]
    train_table_indices = corpus_table_indices[dev_table_num:]

    # Strided slicing gives every rank its own portion of each split.
    local_dev_table_indices = dev_table_indices[args.global_rank::args.world_size]
    local_train_table_indices = train_table_indices[args.global_rank::args.world_size]
    local_indices = local_dev_table_indices + local_train_table_indices

    logger.info(f'total tables: {total_tables_num}')
    logger.debug(f'local dev table indices: {local_dev_table_indices[:1000]}')
    logger.debug(f'local train table indices: {local_train_table_indices[:1000]}')

    with TableDatabase.from_jsonl(args.train_corpus, backend='memory', tokenizer=tokenizer, indices=local_indices) as table_db:
        # Drop indices that are not present in the loaded database.
        local_indices = {idx for idx in local_indices if idx in table_db}
        local_dev_table_indices = [idx for idx in local_dev_table_indices if idx in local_indices]
        local_train_table_indices = [idx for idx in local_train_table_indices if idx in local_indices]

        args.output_dir.mkdir(exist_ok=True, parents=True)
        print(f'Num tables to be processed by local worker: {len(table_db)}', file=sys.stdout)

        if args.is_master:
            with (args.output_dir / 'config.json').open('w') as f:
                json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        (args.output_dir / 'train').mkdir(exist_ok=True)
        (args.output_dir / 'dev').mkdir(exist_ok=True)

        # generate dev data first
        dev_file = args.output_dir / 'dev' / f'epoch_0.shard{args.global_rank}.h5'
        generate_for_epoch(table_db, local_dev_table_indices, dev_file, input_formatter, args)

        for epoch in trange(args.epochs_to_generate, desc='Epoch'):
            gc.collect()
            epoch_filename = args.output_dir / 'train' / f"epoch_{epoch}.shard{args.global_rank}.h5"
            generate_for_epoch(table_db, local_train_table_indices, epoch_filename, input_formatter, args)