示例#1
0
def test_cdssm_padding(train_raw):
    """Check the batch shapes produced by ``callbacks.CDSSMPadding``.

    Two configurations are exercised on the same point-mode dataset, both
    with a batch size of 5:

    * ``pad_mode='pre'`` with both fixed lengths set to 5: each text field
      must have shape ``(5, 5, vocab_size)``.
    * ``pad_mode='post'`` with no fixed lengths: the second dimension of
      each text field must equal the largest ``length_left`` /
      ``length_right`` value found in that batch.

    :param train_raw: Raw training data handed to the CDSSM preprocessor.
    """
    cdssm_prep = preprocessors.CDSSMPreprocessor()
    dataset = Dataset(
        cdssm_prep.fit_transform(train_raw, verbose=0),
        mode='point',
    )

    # Case 1: fixed-length padding on both sides, pad_mode='pre'.
    fixed_pre_pad = callbacks.CDSSMPadding(
        fixed_length_left=5,
        fixed_length_right=5,
        pad_mode='pre',
    )
    loader = DataLoader(dataset, batch_size=5, callback=fixed_pre_pad)
    for batch in loader:
        expected_shape = (5, 5, cdssm_prep.context['vocab_size'])
        features = batch[0]
        assert features['text_left'].shape == expected_shape
        assert features['text_right'].shape == expected_shape

    # Case 2: no fixed lengths, pad_mode='post'; the padded length is
    # compared against the longest length recorded in the batch.
    dynamic_post_pad = callbacks.CDSSMPadding(pad_mode='post')
    loader = DataLoader(dataset, batch_size=5, callback=dynamic_post_pad)
    for batch in loader:
        features = batch[0]
        longest_left = max(features['length_left'].numpy())
        longest_right = max(features['length_right'].numpy())
        n_vocab = cdssm_prep.context['vocab_size']
        assert features['text_left'].shape == (5, longest_left, n_vocab)
        assert features['text_right'].shape == (5, longest_right, n_vocab)
示例#2
0
 def get_default_preprocessor(cls):
     """
     Build the preprocessor used by default.

     :return: The object created by ``preprocessors.CDSSMPreprocessor()``.
     """
     default_preprocessor = preprocessors.CDSSMPreprocessor()
     return default_preprocessor