Пример #1
0
 def get_default_padding_callback(
     cls,
     fixed_length_left: typing.Optional[int] = None,
     fixed_length_right: typing.Optional[int] = None,
     pad_value: typing.Union[int, str] = 0,
     pad_mode: str = 'post',
 ):
     """
     Build the default padding callback.

     Every argument is forwarded unchanged, by keyword, to
     `callbacks.BertPadding`; this method adds no logic of its own.

     :param fixed_length_left: Forwarded as `fixed_length_left`; `None`
         (the default) is passed through as-is.
     :param fixed_length_right: Forwarded as `fixed_length_right`; `None`
         (the default) is passed through as-is.
     :param pad_value: Forwarded as `pad_value`.
     :param pad_mode: Forwarded as `pad_mode`.
     :return: Default padding callback.
     """
     # `cls` is not used: the callback is configured only from the arguments.
     return callbacks.BertPadding(
         fixed_length_left=fixed_length_left,
         fixed_length_right=fixed_length_right,
         pad_value=pad_value,
         pad_mode=pad_mode,
     )
Пример #2
0
def test_bert_padding(train_raw):
    """Check batch shapes produced by `callbacks.BertPadding`.

    Two configurations are exercised on the same preprocessed dataset:
    fixed lengths with `pad_mode='pre'`, and default lengths with
    `pad_mode='post'`.
    """
    batch_size = 5

    bert_preprocessor = preprocessors.BertPreprocessor()
    processed = bert_preprocessor.transform(train_raw, verbose=0)
    dataset = Dataset(processed, mode='point')

    # Fixed lengths: every batch must come out with the same width.
    # NOTE(review): widths are 2 / 1 larger than the requested length of 5,
    # presumably for BERT special tokens - confirm against BertPadding.
    fixed_pre_padding = callbacks.BertPadding(
        fixed_length_left=5,
        fixed_length_right=5,
        pad_mode='pre',
    )
    fixed_loader = DataLoader(
        dataset, batch_size=batch_size, callback=fixed_pre_padding)
    for batch in fixed_loader:
        inputs = batch[0]
        assert inputs['text_left'].shape == (batch_size, 7)
        assert inputs['text_right'].shape == (batch_size, 6)

    # No fixed lengths: width follows the longest sample in each batch.
    dynamic_post_padding = callbacks.BertPadding(pad_mode='post')
    dynamic_loader = DataLoader(
        dataset, batch_size=batch_size, callback=dynamic_post_padding)
    for batch in dynamic_loader:
        inputs = batch[0]
        left_lengths = inputs['length_left'].detach().cpu().numpy()
        right_lengths = inputs['length_right'].detach().cpu().numpy()
        longest_left = max(left_lengths)
        longest_right = max(right_lengths)
        assert inputs['text_left'].shape == (batch_size, longest_left + 2)
        assert inputs['text_right'].shape == (batch_size, longest_right + 1)