def assertEqual(self, x, y, *args, **kwargs):
  """Compare x and y, optionally dumping XLA graphs first.

  Mirrors the dual PyTorch assertEqual() API: the first positional argument
  may be either a precision (numeric) or a message (string). When the
  TEST_PRINT_GRAPH environment variable is 'text' or 'hlo', the lazy-trace
  graph of any plain XLA torch.Tensor argument is printed to stderr before
  the comparison.

  Raises:
    RuntimeError: if TEST_PRINT_GRAPH is set to an unrecognized value.
  """
  # HACK: Handle the dual nature of the assertEqual() PyTorch API whose first
  # argument can be prec (floating) or msg (string).
  if not args or isinstance(args[0], str):
    kwargs = self._rewrite_compare_args(kwargs)
  elif isinstance(args[0], (float, int)):
    # Never compare tighter than the test class precision.
    args = [max(args[0], self.precision)] + list(args[1:])
  gmode = os.environ.get('TEST_PRINT_GRAPH', '').lower()
  # Resolve the graph dumper once instead of duplicating the print logic
  # per mode.
  dump_fn = None
  if gmode == 'text':
    dump_fn = torch_xla._XLAC._get_xla_tensors_text
  elif gmode == 'hlo':
    dump_fn = torch_xla._XLAC._get_xla_tensors_hlo
  elif gmode:
    raise RuntimeError('Invalid TEST_PRINT_GRAPH value: {}'.format(gmode))
  if dump_fn is not None:
    # Deliberate exact-type check (not isinstance): only plain torch.Tensor
    # instances are dumped, matching the original behavior.
    if type(x) == torch.Tensor and xm.is_xla_tensor(x):
      print('\nTest Graph (x):\n{}'.format(dump_fn([x])), file=sys.stderr)
    if type(y) == torch.Tensor and xm.is_xla_tensor(y):
      print('\nTest Graph (y):\n{}'.format(dump_fn([y])), file=sys.stderr)
  x, y = self.prepare_for_compare(x, y)
  return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
def makeComparable(self, value):
  """Return a host-side (CPU) view of *value* suitable for comparison.

  Non-tensor values are returned untouched. Bool tensors are widened to
  uint8 first; XLA tensors are materialized to CPU via .cpu().
  """
  if not isinstance(value, torch.Tensor):
    return value
  tensor = value
  if tensor.dtype == torch.bool:
    # Widen bools so downstream numeric comparisons behave uniformly.
    tensor = tensor.to(dtype=torch.uint8)
  data = tensor.data
  return data.cpu() if xm.is_xla_tensor(data) else data
def check_fn(v):
  """Recursively verify that every selected value is an XLA tensor.

  A value matched by select_fn must satisfy xm.is_xla_tensor(); containers
  (list/tuple/set and dict keys+values) are checked recursively. Anything
  else passes vacuously.
  """
  if select_fn(v):
    return xm.is_xla_tensor(v)
  if isinstance(v, (list, tuple, set)):
    return all(check_fn(item) for item in v)
  if isinstance(v, dict):
    # Both keys and values must pass.
    return all(check_fn(k) and check_fn(val) for k, val in v.items())
  return True
def assertEqual(self, x, y, prec=None, message='', allow_inf=False, **kwargs):
  """Compare x and y at a clamped precision, optionally dumping XLA graphs.

  The effective precision is never tighter than self.precision. When the
  TEST_PRINT_GRAPH environment variable is 'text' or 'hlo', the lazy-trace
  graph of any plain XLA torch.Tensor argument is printed to stderr before
  the comparison; tensors are moved to CPU before delegating to the base
  class assertEqual.

  Raises:
    RuntimeError: if TEST_PRINT_GRAPH is set to an unrecognized value.
  """
  prec = self.precision if prec is None else max(self.precision, prec)
  gmode = os.environ.get('TEST_PRINT_GRAPH', '').lower()
  # Pick the graph dumper up front; both modes share the same print shape.
  dump_fn = None
  if gmode == 'text':
    dump_fn = torch_xla._XLAC._get_xla_tensors_text
  elif gmode == 'hlo':
    dump_fn = torch_xla._XLAC._get_xla_tensors_hlo
  elif gmode:
    raise RuntimeError('Invalid TEST_PRINT_GRAPH value: {}'.format(gmode))
  if dump_fn is not None:
    # Deliberate exact-type check (not isinstance): only plain torch.Tensor
    # instances are dumped, matching the original behavior.
    if type(x) == torch.Tensor and xm.is_xla_tensor(x):
      print('\nTest Graph (x):\n{}'.format(dump_fn([x])), file=sys.stderr)
    if type(y) == torch.Tensor and xm.is_xla_tensor(y):
      print('\nTest Graph (y):\n{}'.format(dump_fn([y])), file=sys.stderr)
  if type(x) == torch.Tensor:
    x = x.cpu()
  if type(y) == torch.Tensor:
    y = y.cpu()
  return DeviceTypeTestBase.assertEqual(self, x, y, prec, message, allow_inf,
                                        **kwargs)