forked from luizgh/datasetCreator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
testBatchBuilder.py
69 lines (53 loc) · 2.92 KB
/
testBatchBuilder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import unittest
from mock import Mock, call
from BatchBuilder import BatchBuilder
class testBatchBuilder(unittest.TestCase):
    """Tests for BatchBuilder with mocked per-batch and metadata builders.

    BatchBuilder receives two collaborators: one whose ``build`` turns a
    piece of the dataset into a single batch, and one whose ``build``
    produces the metadata entry. Both are mocks here. The tests therefore
    observe only three things:
    - which pieces BatchBuilder hands to the per-batch builder;
    - the order in which it hands them over;
    - the keys under which the results appear in the returned dict.
    """

    def setUp(self):
        # Fresh collaborators for every test so call records never leak.
        self.singleBatchBuilder = Mock()
        self.metaBatchBuilder = Mock()

    def _buildWith(self, dataset, nTrainingBatches):
        """Run BatchBuilder over *dataset* using one class named 'class'.

        Stores the builder on ``self.target``. Returns
        ``(result, classes, classNames)`` so the caller can assert on the
        exact arguments forwarded to the mocked collaborators.
        """
        classes = [0]
        classNames = ['class']
        self.target = BatchBuilder(self.singleBatchBuilder,
                                   self.metaBatchBuilder,
                                   nTrainingBatches=nTrainingBatches)
        result = self.target.build(dataset, classes, classNames)
        return result, classes, classNames

    @staticmethod
    def _expectedOutput(batches, meta):
        """Key each batch as 'data_batch_<n>' (1-based, in order) plus the meta entry."""
        expected = {'data_batch_{}'.format(position): batch
                    for position, batch in enumerate(batches, 1)}
        expected['batches.meta'] = meta
        return expected

    def testSingleTrainingBatch(self):
        """With one training batch, train, valid and test become batches 1-3."""
        built = [Mock() for _ in range(3)]
        meta = Mock()
        self.singleBatchBuilder.build.side_effect = built
        self.metaBatchBuilder.build.return_value = meta

        train = [Mock()]
        valid = Mock()
        test = Mock()
        dataset = (train, valid, test)

        result, classes, classNames = self._buildWith(dataset, 1)

        # The whole training list, then valid, then test, in that order.
        self.singleBatchBuilder.build.assert_has_calls(
            [call(part, classes) for part in dataset])
        self.metaBatchBuilder.build.assert_called_with(dataset, classes, classNames)
        self.assertEqual(self._expectedOutput(built, meta), result)

    def testMultipleTrainingBatches(self):
        """Eleven training items in 3 training batches, followed by valid and test.

        The expected training slices are train[:3], train[3:6] and train[6:].
        The last slice therefore holds the remaining five items.
        """
        built = [Mock() for _ in range(5)]
        meta = Mock()
        self.singleBatchBuilder.build.side_effect = built
        self.metaBatchBuilder.build.return_value = meta

        train = [Mock() for _ in range(11)]
        valid, test = Mock(), Mock()
        dataset = (train, valid, test)

        result, classes, classNames = self._buildWith(dataset, 3)

        expectedParts = [train[:3], train[3:6], train[6:], valid, test]
        self.singleBatchBuilder.build.assert_has_calls(
            [call(part, classes) for part in expectedParts])
        self.metaBatchBuilder.build.assert_called_with(dataset, classes, classNames)
        self.assertEqual(self._expectedOutput(built, meta), result)
if __name__ == '__main__':
    # Running this file directly loads and runs the TestCase classes defined in it.
    unittest.main(module='__main__')