def test_decimal_friendly_collate_input_has_decimals_in_tuple():
    """Collating sequence rows of (Decimal, int) transposes them by position.

    The Decimal column is expected back unchanged as a tuple of Decimals; the
    int column is expected as an object exposing ``.numpy()`` whose values
    equal the original ints.
    """
    row_a = [Decimal('1.0'), 1]
    row_b = [Decimal('1.1'), 2]
    expected_decimal_column = (Decimal('1.0'), Decimal('1.1'))
    expected_int_column = (1, 2)

    collated = decimal_friendly_collate((row_a, row_b))

    assert len(collated) == 2
    # Decimals must survive collation untouched (compared by value).
    assert expected_decimal_column == collated[0]
    # The int column is compared as an array via its `.numpy()` view.
    np.testing.assert_equal(expected_int_column, collated[1].numpy())
def test_decimal_friendly_collate_input_has_decimals_in_dictionary():
    """Collating dict rows produces one entry per key.

    The 'decimal' key is expected to yield the list of original Decimals;
    the 'int' key is expected to yield an object exposing ``.numpy()`` whose
    values equal the original ints.
    """
    rows = [
        {'decimal': Decimal('1.0'), 'int': 1},
        {'decimal': Decimal('1.1'), 'int': 2},
    ]
    expected_by_key = {
        'decimal': [Decimal('1.0'), Decimal('1.1')],
        'int': [1, 2],
    }

    collated = decimal_friendly_collate(rows)

    assert len(collated) == 2
    # Decimal values are passed through as a plain list.
    assert expected_by_key['decimal'] == collated['decimal']
    # Int values are compared through their `.numpy()` representation.
    np.testing.assert_equal(expected_by_key['int'], collated['int'].numpy())
def collate_fn(batch_list: List[Dict]):
    """Collate a list of row dicts, preprocess the batch, and optionally call ``.cuda()``.

    ``batch_preprocessor`` and ``use_gpu`` are read from the enclosing scope
    (not visible here -- presumably the function that defines this closure).
    """
    collated = decimal_friendly_collate(batch_list)
    processed = batch_preprocessor(collated)
    # Only move to the GPU when the enclosing scope requested it.
    return processed.cuda() if use_gpu else processed
def test_decimal_friendly_collate_empty_input():
    """A batch holding a single empty dict collates to an empty dict."""
    single_empty_row = [{}]
    assert decimal_friendly_collate(single_empty_row) == {}
def collate_fn(batch_list: List[Dict]):
    """Collate a list of row dicts and return the preprocessed batch.

    ``batch_preprocessor`` is read from the enclosing scope (not visible here).
    """
    return batch_preprocessor(decimal_friendly_collate(batch_list))