def test_torch_input(self):
    """Run Watch(inputs='x').forward on torch data with no gradient tape and
    check the result equals the expected ``self.torch_output`` fixture."""
    watch_op = Watch(inputs='x')
    squared = self.torch_data * self.torch_data
    # The torch path is exercised with an explicit ``None`` tape in the state dict.
    result = watch_op.forward(data=[self.torch_data, squared], state={'tape': None})
    self.assertTrue(is_equal(result, self.torch_output))
def test_tf_input(self):
    """Run Watch(inputs='x').forward on TensorFlow data while a persistent
    GradientTape is active and check the result equals ``self.tf_output``."""
    watch_op = Watch(inputs='x')
    with tf.GradientTape(persistent=True) as grad_tape:
        squared = self.tf_data * self.tf_data
        # forward is invoked inside the tape context and receives the tape via
        # the state dict, so any watching it does happens while recording.
        result = watch_op.forward(data=[self.tf_data, squared], state={'tape': grad_tape})
    self.assertTrue(is_equal(result, self.tf_output))