예제 #1
0
 def test_sign_torch_value(self):
     """tensor_round on the torch fixture must equal the expected torch output.

     NOTE(review): despite ``sign`` in the name, this test exercises
     ``tensor_round``; the name is left unchanged so existing test IDs and
     test selectors keep working.
     """
     # Pass a message so a failure explains what was compared instead of the
     # bare "False is not true" produced by assertTrue on a boolean helper.
     self.assertTrue(
         is_equal(tensor_round(self.test_torch), self.test_output_torch),
         'tensor_round output must equal test_output_torch')
예제 #2
0
 def test_tensor_round_torch_type(self):
     """tensor_round applied to the torch fixture should yield a torch.Tensor."""
     rounded = tensor_round(self.test_torch)
     self.assertIsInstance(rounded, torch.Tensor,
                           msg='Output type must be torch.Tensor')
예제 #3
0
 def test_tensor_round_tf_type(self):
     """tensor_round applied to the TensorFlow fixture should yield a tf.Tensor."""
     rounded = tensor_round(self.test_tf)
     self.assertIsInstance(rounded, tf.Tensor,
                           msg='Output type must be tf.Tensor')
예제 #4
0
 def test_tensor_round_tf_value(self):
     """tensor_round on the TensorFlow fixture must equal the expected tf output."""
     # Pass a message so a failure explains what was compared instead of the
     # bare "False is not true" produced by assertTrue on a boolean helper.
     self.assertTrue(
         is_equal(tensor_round(self.test_tf), self.test_output_tf),
         'tensor_round output must equal test_output_tf')
예제 #5
0
 def test_tensor_round_np_type(self):
     """tensor_round applied to the NumPy fixture should yield an np.ndarray."""
     rounded = tensor_round(self.test_np)
     self.assertIsInstance(rounded, np.ndarray,
                           msg='Output type must be NumPy array')