def vjp_test(): nonlocal v xs = self.gen_inputs(inputs) if v is not None: v = self.gen_inputs(v) outputs, inputs_grad = vjp(func, xs, v, create_graph=create_graph, allow_unused=allow_unused) else: outputs, inputs_grad = vjp(func, xs, create_graph=create_graph, allow_unused=allow_unused) return outputs, inputs_grad
def test_vjp_nested_no_create_graph(self):
    """Compare vjp against plain gradients for the ``nested`` closure case."""
    captured = self.gen_input('a')
    # nested(captured) is built once here, after the input is generated.
    cases = (
        (nested(captured), 'a'),
    )
    for fn, raw_inputs in cases:
        # Named run_* to avoid shadowing the module-level vjp function.
        run_vjp, run_grad = self.gen_test_pairs(fn, raw_inputs)
        vjp_out = run_vjp()
        grad_out = run_grad()
        self.check_results(grad_out, vjp_out)
def test_vjp_i2o2_omitting_v_no_create_graph(self):
    """Compare vjp against plain gradients for two inputs / two outputs, without v."""
    cases = (
        (o2, ['A', 'A']),
    )
    for fn, raw_inputs in cases:
        # Inputs are converted up front, before being handed to gen_test_pairs.
        prepared = self.gen_inputs(raw_inputs)
        run_vjp, run_grad = self.gen_test_pairs(fn, prepared)
        vjp_out = run_vjp()
        grad_out = run_grad()
        self.check_results(grad_out, vjp_out)
def test_vjp_i2o1_no_create_graph(self):
    """Compare vjp against plain gradients for two-input / one-output functions."""
    cases = (
        (matmul, ['A', 'B']),
        (mul, ['b', 'c']),
    )
    for fn, raw_inputs in cases:
        run_vjp, run_grad = self.gen_test_pairs(fn, raw_inputs)
        # vjp runs before grad, as in the original tuple assignment.
        vjp_out = run_vjp()
        grad_out = run_grad()
        self.check_results(grad_out, vjp_out)
def test_vjp_i1o1_no_create_graph(self):
    """Compare vjp against plain gradients for one-input / one-output functions."""
    cases = (
        (reduce, 'A'),
        (reduce_dim, 'A'),
    )
    for fn, raw_inputs in cases:
        run_vjp, run_grad = self.gen_test_pairs(fn, raw_inputs)
        vjp_out = run_vjp()
        grad_out = run_grad()
        self.check_results(grad_out, vjp_out)
def test_vjp_allowunused_no_create_graph(self):
    """Compare vjp against plain gradients when an input is unused (allow_unused=True)."""
    # Generated in the same order as the original: 'A' first, then 'a'.
    first = self.gen_input('A')
    second = self.gen_input('a')
    run_vjp, run_grad = self.gen_test_pairs(
        unuse, [first, second], allow_unused=True
    )
    vjp_out = run_vjp()
    grad_out = run_grad()
    self.check_results(grad_out, vjp_out)