def compareReplicated(self, model, inputs, xla_outputs):
  self.assertEqual(len(inputs), len(xla_outputs))
  for i, input in enumerate(inputs):
    expected = xu.as_list(model(*input))
    xla_output = xu.as_list(xla_outputs[i])
    self.assertEqual(len(expected), len(xla_output))
    for j, expected_tensor in enumerate(expected):
      self.assertEqualDbg(xla_output[j], expected_tensor)
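# A minimal usage sketch for compareReplicated (the model, replica count and
# shapes are illustrative, not from the original tests): `inputs` holds one
# argument tuple per replica, and `xla_outputs` holds the matching per-replica
# outputs collected from a replicated XLA execution. Here the reference
# model's own outputs stand in for them, purely to show the call shape.
def test_replicated_linear_sketch(self):
  model = torch.nn.Linear(4, 2)
  inputs = [(torch.randn(8, 4),) for _ in range(2)]  # one input tuple per replica
  xla_outputs = [model(*input) for input in inputs]
  self.compareReplicated(model, inputs, xla_outputs)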
def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
  if device is None:
    device = xm.xla_device()
  tensors = xu.as_list(tensors)
  xla_tensors = [x.to(device) for x in tensors]
  results = xu.as_list(fn(*tensors))
  xla_results = xu.as_list(fn(*xla_tensors))
  for at, xt in zip(results, xla_results):
    self.assertEqualRel(
        self.makeComparable(xt), at, rel_err=rel_err, abs_err=abs_err)
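# A minimal usage sketch for runAtenTest (the op and shapes are illustrative):
# the same function is evaluated on the CPU tensors and on their XLA-device
# copies, and the results are compared within the default tolerances.
def test_add_mul_sketch(self):

  def fn(a, b):
    return a + 2.0 * b

  self.runAtenTest([torch.randn(4, 3), torch.randn(4, 3)], fn)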
def checkGrad(self,
              model,
              inputs,
              grad_outputs='random',
              xla=True,
              rel_err=1e-2,
              abs_err=1e-5):
  # Trace and symbolically differentiate
  traced_model = torch.jit.trace(model, *inputs)
  fwd = traced_model._get_method('forward')
  xm.forward_passes(fwd.graph)

  inputs_params = inputs + list(model.parameters())
  inputs_params_buffers = inputs + list(fwd.params())

  gradient = torch._C._jit_differentiate(fwd.graph)
  xm.forward_passes(gradient.f)
  xm.backward_passes(gradient.df)

  ##############################################################
  # Run forward and backward graphs via jit interpreter
  exec_f = torch._C.GraphExecutor(gradient.f, False)
  exec_df = torch._C.GraphExecutor(gradient.df, False)

  # forward function
  raw_outputs = exec_f(*inputs_params_buffers)
  raw_outputs = xu.as_list(raw_outputs)
  intermediate_outputs = [
      raw_output for raw_output in raw_outputs[gradient.f_real_outputs:]
      if isinstance(raw_output, torch.Tensor)
  ]
  outputs = raw_outputs[:gradient.f_real_outputs]

  if grad_outputs == 'random':
    grad_outputs = _random_like(outputs) + _zeros_like(intermediate_outputs)

  raw_grad_outputs = []
  raw_grad_outputs += grad_outputs
  raw_grad_outputs += [
      inputs_params_buffers[i] for i in gradient.df_input_captured_inputs
  ]
  raw_grad_outputs += [
      raw_outputs[i] for i in gradient.df_input_captured_outputs
  ]

  grad_inputs = exec_df(*raw_grad_outputs)
  grad_inputs = xu.as_list(grad_inputs)

  ##############################################################
  # forward + backward with XLA
  if xla:
    xla_model = torch_xla._XLAC.XlaModule(
        traced_model, use_full_conv_precision=True)
    inputs_xla = [torch_xla._XLAC.XLATensor(input) for input in inputs]
    xla_model(tuple(inputs_xla))
    grads_output_xla = [
        torch_xla._XLAC.XLATensor(grad_output)
        for grad_output in grad_outputs[:gradient.f_real_outputs]
    ]
    xla_model.backward(tuple(grads_output_xla))
    grad_inputs_xla = [input_xla.grad.to_tensor() for input_xla in inputs_xla]
    grad_inputs_xla.extend(
        [p.grad.to_tensor() for p in xla_model.parameters()[0]])

  ##############################################################
  # forward + backward with regular autograd / torch
  outputs_gt = model(*inputs)
  outputs_gt = xu.as_list(outputs_gt)
  grad_inputs_gt = torch.autograd.grad(
      outputs_gt, inputs_params, grad_outputs, only_inputs=True)
  for out_jit, out_autograd in zip(outputs, outputs_gt):
    self.assertEqualRel(
        out_jit, out_autograd, rel_err=rel_err, abs_err=abs_err)

  for grad_input_jit, grad_input_autograd in zip(grad_inputs, grad_inputs_gt):
    self.assertEqualRel(
        grad_input_jit, grad_input_autograd, rel_err=rel_err, abs_err=abs_err)

  # TODO: test buffers as well (running_mean, etc.)
  if xla:
    for grad_input_jit, grad_input_xla in zip(grad_inputs, grad_inputs_xla):
      self.assertEqualRel(grad_input_jit, grad_input_xla, rel_err, abs_err)
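# A minimal usage sketch for checkGrad (model and shapes illustrative): the
# helper traces the model, symbolically differentiates the traced graph, and
# cross-checks the JIT-interpreter gradients against plain autograd and, with
# xla=True, against the XLA backward pass. Inputs must require grad so that
# torch.autograd.grad can differentiate with respect to them.
def test_linear_grad_sketch(self):
  model = torch.nn.Linear(3, 2)
  inputs = [torch.randn(5, 3, requires_grad=True)]
  self.checkGrad(model, inputs, xla=True)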