Esempio n. 1
0
    def assertEqual(self, x, y, *args, **kwargs):
        """Compare ``x`` and ``y`` after XLA-specific preprocessing.

        PyTorch's ``assertEqual()`` accepts either a precision (number) or a
        message (string) as its first extra positional argument. Both are
        handled here:

        * no extra positional arguments, or a string first: ``kwargs`` is
          passed through ``self._rewrite_compare_args``;
        * a float/int first: it is raised to at least ``self.precision``.

        When the ``TEST_PRINT_GRAPH`` environment variable is ``text`` or
        ``hlo`` (case-insensitive), the graph of each operand whose type is
        exactly ``torch.Tensor`` and which ``xm.is_xla_tensor`` accepts is
        printed to stderr in that format. Any other non-empty value raises
        ``RuntimeError``.

        Both operands then go through ``self.prepare_for_compare`` and the
        comparison is delegated to ``DeviceTypeTestBase.assertEqual``.
        """
        # HACK: the first extra positional argument may be prec or msg.
        if args and not isinstance(args[0], str):
            if isinstance(args[0], (float, int)):
                args = [max(args[0], self.precision), *args[1:]]
        else:
            kwargs = self._rewrite_compare_args(kwargs)

        gmode = os.environ.get('TEST_PRINT_GRAPH', '').lower()
        if gmode:
            if gmode == 'text':
                dump_graph = torch_xla._XLAC._get_xla_tensors_text
            elif gmode == 'hlo':
                dump_graph = torch_xla._XLAC._get_xla_tensors_hlo
            else:
                raise RuntimeError(
                    'Invalid TEST_PRINT_GRAPH value: {}'.format(gmode))
            labelled = (('\nTest Graph (x):\n{}', x),
                        ('\nTest Graph (y):\n{}', y))
            for template, operand in labelled:
                # Exact type match (not isinstance): Tensor subclasses are
                # not printed.
                if type(operand) == torch.Tensor and xm.is_xla_tensor(operand):
                    print(template.format(dump_graph([operand])),
                          file=sys.stderr)

        x, y = self.prepare_for_compare(x, y)
        return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
Esempio n. 2
0
 def makeComparable(self, value):
   """Return a form of ``value`` suitable for equality comparison.

   Non-tensor values are returned unchanged. A ``torch.Tensor`` is unwrapped
   to its ``.data``. Boolean tensors are first cast to ``torch.uint8``, and
   tensors that ``xm.is_xla_tensor`` accepts are copied to the CPU.
   """
   if not isinstance(value, torch.Tensor):
     return value
   # NOTE(review): bool -> uint8 cast presumably works around comparison
   # support for bool tensors downstream; confirm before changing.
   if value.dtype == torch.bool:
     value = value.to(dtype=torch.uint8)
   data = value.data
   return data.cpu() if xm.is_xla_tensor(data) else data
Esempio n. 3
0
 def check_fn(v):
   """Recursively check that every selected leaf of ``v`` is an XLA tensor.

   ``select_fn`` is not defined in this block; presumably it is a closure
   variable from the enclosing function. Values it accepts are checked with
   ``xm.is_xla_tensor``. Lists, tuples and sets are checked element by
   element. Dicts are checked key first, then value, for each item. Checking
   stops at the first failure. Any other value is accepted.
   """
   if select_fn(v):
     return xm.is_xla_tensor(v)
   if isinstance(v, (list, tuple, set)):
     return all(check_fn(item) for item in v)
   if isinstance(v, dict):
     # `and` short-circuits so the value is skipped when the key fails.
     return all(check_fn(key) and check_fn(val) for key, val in v.items())
   return True
Esempio n. 4
0
 def assertEqual(self, x, y, prec=None, message='', allow_inf=False,
                 **kwargs):
     """Compare ``x`` and ``y`` with XLA-aware preprocessing.

     Args:
       x: first value to compare.
       y: second value to compare.
       prec: tolerance. ``None`` means ``self.precision``; any other value is
         raised to at least ``self.precision``.
       message: forwarded unchanged.
       allow_inf: forwarded unchanged.
       **kwargs: forwarded unchanged.

     When the ``TEST_PRINT_GRAPH`` environment variable is ``text`` or
     ``hlo`` (case-insensitive), the graph of each operand whose type is
     exactly ``torch.Tensor`` and which ``xm.is_xla_tensor`` accepts is
     written to stderr. Any other non-empty value raises ``RuntimeError``.
     Operands whose type is exactly ``torch.Tensor`` are moved to the CPU
     before delegating to ``DeviceTypeTestBase.assertEqual``. Its result
     is returned.
     """
     prec = self.precision if prec is None else max(self.precision, prec)

     gmode = os.environ.get('TEST_PRINT_GRAPH', '').lower()
     if gmode:
         if gmode == 'text':
             dump_graph = torch_xla._XLAC._get_xla_tensors_text
         elif gmode == 'hlo':
             dump_graph = torch_xla._XLAC._get_xla_tensors_hlo
         else:
             raise RuntimeError(
                 'Invalid TEST_PRINT_GRAPH value: {}'.format(gmode))
         labelled = (('\nTest Graph (x):\n{}', x),
                     ('\nTest Graph (y):\n{}', y))
         for template, operand in labelled:
             if type(operand) == torch.Tensor and xm.is_xla_tensor(operand):
                 print(template.format(dump_graph([operand])),
                       file=sys.stderr)

     # Exact type match (not isinstance): Tensor subclasses are left as-is.
     x, y = (t.cpu() if type(t) == torch.Tensor else t for t in (x, y))
     return DeviceTypeTestBase.assertEqual(self, x, y, prec, message,
                                           allow_inf, **kwargs)