Example no. 1
0
def _xla_run(model, input, device='TPU'):
    """Execute `model` through XLA and return the result as torch tensors.

    If `input` is a tuple/list, each element is dispatched to its own core
    (one device string per element, e.g. 'TPU:0', 'TPU:1', ...) and the
    replicated outputs are converted back to tensors.  Otherwise the model
    is run on a single core and the sole output is returned directly.
    """
    if not isinstance(input, (tuple, list)):
        # Single-input path: wrap in a list for XlaModel, unwrap the output.
        xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
        return xla_model(input)[0]
    devices = ['{}:{}'.format(device, core) for core in range(len(input))]
    replicated = xm.XlaModel(model,
                             input[0],
                             num_cores=len(input),
                             devices=devices,
                             full_conv_precision=True)
    return xm.convert_to_tensors(replicated(*input))
Example no. 2
0
 def compareModel(self, model, input, rel_err=0.05, abs_err=1e-4):
     """Assert that XLA and native execution of `model` agree.

     Runs both the forward pass (outputs compared within `rel_err` /
     `abs_err`) and the backward pass (per-parameter gradients compared
     within the same tolerances).
     """
     xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
     xla_out = xla_model(input)
     reference_out = model(input)
     # Forward outputs must match within tolerance.
     self.assertEqualRel(reference_out,
                         xm.convert_to_tensors(xla_out)[0],
                         rel_err=rel_err,
                         abs_err=abs_err)
     grad_output = _gen_tensor(*reference_out.shape)  # random gradients
     grad_output.grad = grad_output.data
     reference_out.backward(grad_output)
     xla_model.backward([grad_output])
     xla_grads = [
         p.grad.to_tensor() for p in xla_model.parameters()[0]
     ]
     native_grads = [p.grad for p in model.parameters()]
     # Gradients must match pairwise, in parameter order.
     self.assertEqual(len(xla_grads), len(native_grads))
     for xla_grad, native_grad in zip(xla_grads, native_grads):
         self.assertEqualRel(xla_grad,
                             native_grad,
                             rel_err=rel_err,
                             abs_err=abs_err)