コード例 #1
0
ファイル: test_misc_op.py プロジェクト: zeta1999/jittor
 def test_make_grid(self):
     arr = np.random.randn(16,3,10,10)
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr)))
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=2), jt.make_grid(jt.array(arr), nrow=2))
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3), jt.make_grid(jt.array(arr), nrow=3))
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
     print('pass make_grid test ...')
コード例 #2
0
 def check(shape):
     arr = np.random.randn(*shape)
     check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr)))