Example #1
def compareReplicated(self, model, inputs, xla_outputs):
    # Check that each replica's XLA output matches the reference model
    # evaluated on the corresponding input tuple, element by element.
    self.assertEqual(len(inputs), len(xla_outputs))
    for i, input in enumerate(inputs):
        expected = xu.as_list(model(*input))
        xla_output = xu.as_list(xla_outputs[i])
        self.assertEqual(len(expected), len(xla_output))
        for j, expected_tensor in enumerate(expected):
            self.assertEqualDbg(xla_output[j], expected_tensor)
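
All three examples rely on xu.as_list to normalize a single tensor or a tuple of tensors into a plain list before comparing element-wise. A minimal sketch of that helper, assuming the usual torch_xla test-utility behavior (not necessarily the exact library implementation):

def as_list(t):
    # Pass sequences through unchanged; wrap a single value in a list.
    return t if isinstance(t, (tuple, list)) else [t]
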
Example #2
def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
  # Run fn on the CPU tensors and on copies moved to the XLA device,
  # then compare the two sets of results within the given tolerances.
  if device is None:
    device = xm.xla_device()
  tensors = xu.as_list(tensors)
  xla_tensors = [x.to(device) for x in tensors]
  results = xu.as_list(fn(*tensors))
  xla_results = xu.as_list(fn(*xla_tensors))
  for at, xt in zip(results, xla_results):
    self.assertEqualRel(
        self.makeComparable(xt), at, rel_err=rel_err, abs_err=abs_err)
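
A hypothetical call site (the test name and inputs below are illustrative, not taken from the source) passes one or more input tensors together with a callable that exercises the operator under test:

def test_add(self):
  # Verify that elementwise addition gives matching results on the
  # native device and on the XLA device for a random input tensor.
  self.runAtenTest(torch.rand(4, 4), lambda x: x + x)
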
Example #3
    def checkGrad(self,
                  model,
                  inputs,
                  grad_outputs='random',
                  xla=True,
                  rel_err=1e-2,
                  abs_err=1e-5):
        # Trace and symbolically differentiate
        traced_model = torch.jit.trace(model, *inputs)
        fwd = traced_model._get_method('forward')
        xm.forward_passes(fwd.graph)

        inputs_params = inputs + list(model.parameters())
        inputs_params_buffers = inputs + list(fwd.params())

        gradient = torch._C._jit_differentiate(fwd.graph)
        xm.forward_passes(gradient.f)
        xm.backward_passes(gradient.df)

        ##############################################################
        # Run forward and backward graphs via the JIT interpreter
        exec_f = torch._C.GraphExecutor(gradient.f, False)
        exec_df = torch._C.GraphExecutor(gradient.df, False)

        # forward function
        raw_outputs = exec_f(*inputs_params_buffers)
        raw_outputs = xu.as_list(raw_outputs)
        intermediate_outputs = [
            raw_output for raw_output in raw_outputs[gradient.f_real_outputs:]
            if isinstance(raw_output, torch.Tensor)
        ]
        outputs = raw_outputs[:gradient.f_real_outputs]

        if grad_outputs == 'random':
            grad_outputs = _random_like(outputs) + _zeros_like(
                intermediate_outputs)

        raw_grad_outputs = []
        raw_grad_outputs += grad_outputs
        raw_grad_outputs += [
            inputs_params_buffers[i] for i in gradient.df_input_captured_inputs
        ]
        raw_grad_outputs += [
            raw_outputs[i] for i in gradient.df_input_captured_outputs
        ]

        grad_inputs = exec_df(*raw_grad_outputs)
        grad_inputs = xu.as_list(grad_inputs)

        ##############################################################
        # backward with XLA
        if xla:
            xla_model = torch_xla._XLAC.XlaModule(traced_model,
                                                  use_full_conv_precision=True)
            inputs_xla = [torch_xla._XLAC.XLATensor(input) for input in inputs]
            xla_model((tuple(inputs_xla)))
            grads_output_xla = [
                torch_xla._XLAC.XLATensor(grad_output)
                for grad_output in grad_outputs[:gradient.f_real_outputs]
            ]
            xla_model.backward((tuple(grads_output_xla)))
            grad_inputs_xla = [
                input_xla.grad.to_tensor() for input_xla in inputs_xla
            ]
            grad_inputs_xla.extend(
                [p.grad.to_tensor() for p in xla_model.parameters()[0]])
        ##############################################################
        # forward + backward with regular autograd / torch
        outputs_gt = model(*inputs)
        outputs_gt = xu.as_list(outputs_gt)
        grad_inputs_gt = torch.autograd.grad(outputs_gt,
                                             inputs_params,
                                             grad_outputs,
                                             only_inputs=True)
        for out_jit, out_autograd in zip(outputs, outputs_gt):
            self.assertEqualRel(out_jit,
                                out_autograd,
                                rel_err=rel_err,
                                abs_err=abs_err)

        for grad_input_jit, grad_input_autograd in zip(grad_inputs,
                                                       grad_inputs_gt):
            self.assertEqualRel(grad_input_jit,
                                grad_input_autograd,
                                rel_err=rel_err,
                                abs_err=abs_err)

        # TODO: test buffers as well (running_mean, etc.)
        if xla:
            for i, (grad_input_jit, grad_input_xla) in enumerate(
                    zip(grad_inputs, grad_inputs_xla)):
                self.assertEqualRel(grad_input_jit, grad_input_xla, rel_err,
                                    abs_err)
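
A hypothetical invocation, assuming the legacy torch_xla._XLAC.XlaModule API used above is available, compares gradients for a small linear model:

model = torch.nn.Linear(4, 3)
inputs = [torch.randn(2, 4, requires_grad=True)]
# Compare gradients from the JIT-differentiated graph, the XLA backward
# pass, and plain autograd for the same random grad_outputs.
self.checkGrad(model, inputs, grad_outputs='random', xla=True)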