Ejemplo n.º 1
0
 def test_torch_input(self):
     watch = Watch(inputs='x')
     x = self.torch_data * self.torch_data
     output = watch.forward(data=[self.torch_data, x], state={'tape': None})
     self.assertTrue(is_equal(output, self.torch_output))
Ejemplo n.º 2
0
 def test_tf_input(self):
     watch = Watch(inputs='x')
     with tf.GradientTape(persistent=True) as tape:
         x = self.tf_data * self.tf_data
         output = watch.forward(data=[self.tf_data, x], state={'tape': tape})
     self.assertTrue(is_equal(output, self.tf_output))