Exemple #1
0
 def vjp_test():
     nonlocal v
     xs = self.gen_inputs(inputs)
     if v is not None:
         v = self.gen_inputs(v)
         outputs, inputs_grad = vjp(func,
                                    xs,
                                    v,
                                    create_graph=create_graph,
                                    allow_unused=allow_unused)
     else:
         outputs, inputs_grad = vjp(func,
                                    xs,
                                    create_graph=create_graph,
                                    allow_unused=allow_unused)
     return outputs, inputs_grad
Exemple #2
0
 def test_vjp_nested_no_create_graph(self):
     """Check that vjp and grad results agree for ``nested(x)``.

     ``x`` is produced by ``self.gen_input('a')`` and the function under
     test is ``nested(x)``, evaluated on the input named ``'a'``.
     ``create_graph`` is not passed, so ``gen_test_pairs``' default applies.
     """
     tensor = self.gen_input('a')
     cases = (
         (nested(tensor), 'a'),
     )
     for func, input_spec in cases:
         run_vjp, run_grad = self.gen_test_pairs(func, input_spec)
         # Keep the original evaluation order: vjp first, then grad.
         vjp_out = run_vjp()
         grad_out = run_grad()
         self.check_results(grad_out, vjp_out)
Exemple #3
0
 def test_vjp_i2o2_omitting_v_no_create_graph(self):
     """Check that vjp and grad results agree for ``o2`` when ``v`` is omitted.

     The input spec ``['A', 'A']`` is first expanded with
     ``self.gen_inputs`` and the generated inputs (not the names) are handed
     to ``gen_test_pairs``. Neither ``v`` nor ``create_graph`` is passed, so
     ``gen_test_pairs``' defaults apply.
     """
     cases = (
         (o2, ['A', 'A']),
     )
     for func, input_spec in cases:
         generated = self.gen_inputs(input_spec)
         run_vjp, run_grad = self.gen_test_pairs(func, generated)
         # Keep the original evaluation order: vjp first, then grad.
         vjp_out = run_vjp()
         grad_out = run_grad()
         self.check_results(grad_out, vjp_out)
Exemple #4
0
 def test_vjp_i2o1_no_create_graph(self):
     """Check that vjp and grad results agree for two-input functions.

     Covers ``matmul`` on inputs ``['A', 'B']`` and ``mul`` on inputs
     ``['b', 'c']``. ``create_graph`` is not passed, so ``gen_test_pairs``'
     default applies.
     """
     cases = (
         (matmul, ['A', 'B']),
         (mul, ['b', 'c']),
     )
     for func, input_spec in cases:
         run_vjp, run_grad = self.gen_test_pairs(func, input_spec)
         # Keep the original evaluation order: vjp first, then grad.
         vjp_out = run_vjp()
         grad_out = run_grad()
         self.check_results(grad_out, vjp_out)
Exemple #5
0
 def test_vjp_i1o1_no_create_graph(self):
     """Check that vjp and grad results agree for single-input functions.

     Covers ``reduce`` and ``reduce_dim``, each on the input named ``'A'``.
     ``create_graph`` is not passed, so ``gen_test_pairs``' default applies.
     """
     cases = (
         (reduce, 'A'),
         (reduce_dim, 'A'),
     )
     for func, input_spec in cases:
         run_vjp, run_grad = self.gen_test_pairs(func, input_spec)
         # Keep the original evaluation order: vjp first, then grad.
         vjp_out = run_vjp()
         grad_out = run_grad()
         self.check_results(grad_out, vjp_out)
Exemple #6
0
 def test_vjp_allowunused_no_create_graph(self):
     """Check that vjp and grad results agree for ``unuse`` with allow_unused.

     Two inputs are generated up front via ``self.gen_input`` (``'A'`` then
     ``'a'``) and passed to ``gen_test_pairs`` with ``allow_unused=True``.
     ``create_graph`` is not passed, so ``gen_test_pairs``' default applies.
     """
     first = self.gen_input('A')
     second = self.gen_input('a')
     run_vjp, run_grad = self.gen_test_pairs(unuse, [first, second],
                                             allow_unused=True)
     # Keep the original evaluation order: vjp first, then grad.
     vjp_out = run_vjp()
     grad_out = run_grad()
     self.check_results(grad_out, vjp_out)