예제 #1
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_slicing(test_case):
     size = flow.Size([2, 3, 4, 5])
     test_case.assertTrue(size[1:3] == flow.Size((3, 4)))
     test_case.assertTrue(size[1:] == flow.Size((3, 4, 5)))
     test_case.assertTrue(size[:2] == (2, 3))
     test_case.assertTrue(size[-3:] == flow.Size((3, 4, 5)))
     test_case.assertTrue(size[-3:-1] == flow.Size((3, 4)))
예제 #2
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_equal(test_case):
     size = flow.Size((2, 3))
     test_case.assertEqual(size == (2, 3), True)
     test_case.assertEqual(size == (3, 2), False)
     test_case.assertEqual(size == flow.Size((2, 3)), True)
     test_case.assertEqual(size == flow.Size((3, 2)), False)
     test_case.assertEqual(size == [2, 3], False)
     test_case.assertEqual(size == dict(), False)
예제 #3
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_size(test_case):
     size = flow.Size((4, 3, 10, 5))
     test_case.assertTrue(size[0] == 4)
     test_case.assertTrue(size[2] == 10)
     test_case.assertTrue(len(size) == 4)
     size = flow.Size([4, 3, 10, 5])
     test_case.assertTrue(size[0] == 4)
     test_case.assertTrue(size[2] == 10)
     test_case.assertTrue(len(size) == 4)
     size = flow.Size(size)
     test_case.assertTrue(size[0] == 4)
     test_case.assertTrue(size[2] == 10)
     test_case.assertTrue(len(size) == 4)
     test_case.assertTrue(size[-1] == 5)
     test_case.assertTrue(size[-4] == 4)
예제 #4
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_index(test_case):
     size = flow.Size((2, 3, 2, 4, 4))
     test_case.assertEqual(size.index(2), 0)
     test_case.assertEqual(size.index(2, 0), 0)
     test_case.assertEqual(size.index(2, 0, 20), 0)
     test_case.assertEqual(size.index(2, 1, 20), 2)
     test_case.assertEqual(size.index(4), 3)
     test_case.assertEqual(size.index(4, 4), 4)
     with test_case.assertRaises(ValueError):
         size.index(4, 0, 3)
     with test_case.assertRaises(ValueError):
         size.index(5)
     with test_case.assertRaises(ValueError):
         size.index(2, 3)
예제 #5
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_count(test_case):
     size = flow.Size((2, 2, 3, 4))
     test_case.assertEqual(size.count(1), 0)
     test_case.assertEqual(size.count(2), 2)
     test_case.assertEqual(size.count(3), 1)
     test_case.assertEqual(size.count(4), 1)
예제 #6
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_numel(test_case):
     size = flow.Size((1, 2, 3, 4))
     test_case.assertEqual(size.numel(), 24)
예제 #7
0
파일: test_size.py 프로젝트: zzk0/oneflow
 def test_unpack(test_case):
     (one, two, three, four) = flow.Size((1, 2, 3, 4))
     test_case.assertEqual(one, 1)
     test_case.assertEqual(two, 2)
     test_case.assertEqual(three, 3)
     test_case.assertEqual(four, 4)
예제 #8
0
파일: test_size.py 프로젝트: zzk0/oneflow
def _compare_with_np(test_case, x_shape, dtype):
    x = np.random.randn(*x_shape).astype(type_name_to_np_type[dtype])
    ret = flow.Size(x_shape)
    for idx in range(0, len(ret)):
        test_case.assertEqual(ret[idx], x.shape[idx])