def test_decimal_friendly_collate_input_has_decimals_in_tuple():
    """Check collation of a tuple of ``[Decimal, int]`` samples.

    The collated result must have two entries: the first compares equal to a
    tuple of the original ``Decimal`` values, and the second exposes
    ``.numpy()`` whose contents equal the integer values.
    """
    samples = ([Decimal('1.0'), 1], [Decimal('1.1'), 2])
    expected_decimals = (Decimal('1.0'), Decimal('1.1'))
    expected_ints = (1, 2)

    collated = decimal_friendly_collate(samples)

    # One output entry per position in each sample.
    assert len(collated) == 2
    assert expected_decimals == collated[0]
    np.testing.assert_equal(expected_ints, collated[1].numpy())
示例#2
0
def test_decimal_friendly_collate_input_has_decimals_in_dictionary():
    """Check collation of a list of dicts holding a ``Decimal`` and an int.

    The collated result must have two keys: ``'decimal'`` compares equal to a
    list of the original ``Decimal`` values, and ``'int'`` exposes ``.numpy()``
    whose contents equal the integer values.
    """
    samples = [
        {'decimal': Decimal('1.0'), 'int': 1},
        {'decimal': Decimal('1.1'), 'int': 2},
    ]
    expected_decimals = [Decimal('1.0'), Decimal('1.1')]
    expected_ints = [1, 2]

    collated = decimal_friendly_collate(samples)

    # One output entry per key of the sample dicts.
    assert len(collated) == 2
    assert expected_decimals == collated['decimal']
    np.testing.assert_equal(expected_ints, collated['int'].numpy())
示例#3
0
 def collate_fn(batch_list: List[Dict]):
     """Collate ``batch_list`` and preprocess it, optionally moving it to GPU.

     The samples are merged with ``decimal_friendly_collate`` and passed to
     ``batch_preprocessor``; when ``use_gpu`` is truthy the preprocessed batch
     is returned via its ``.cuda()`` method.

     NOTE(review): ``batch_preprocessor`` and ``use_gpu`` are free names taken
     from a scope that is not visible here -- confirm against the caller.
     """
     preprocessed = batch_preprocessor(decimal_friendly_collate(batch_list))
     if not use_gpu:
         return preprocessed
     return preprocessed.cuda()
def test_decimal_friendly_collate_empty_input():
    """Collating a single empty dict must yield an empty dict."""
    collated = decimal_friendly_collate([{}])
    assert collated == {}
示例#5
0
文件: utils.py 项目: jayhsieh/ReAgent
 def collate_fn(batch_list: List[Dict]):
     """Collate ``batch_list`` and return the preprocessed batch.

     The samples are merged with ``decimal_friendly_collate`` and the result
     is handed to ``batch_preprocessor``.

     NOTE(review): ``batch_preprocessor`` is a free name taken from a scope
     that is not visible here -- confirm against the caller.
     """
     return batch_preprocessor(decimal_friendly_collate(batch_list))