def test(self):
    """Train ResNet-18 on zero-filled samples through ``dp.DataParallel``.

    The batch size is read from ``BATCH_SIZE`` (default 4) and the per-device
    sample count from ``SAMPLE_COUNT`` (default 10). Every training step
    asserts that the loss is below 3.0.
    """
    xla_devices = xm.get_xla_supported_devices()
    batch = xu.getenv_as('BATCH_SIZE', int, defval=4)
    samples_per_device = xu.getenv_as('SAMPLE_COUNT', int, defval=10)

    # One fixed, all-zero (images, labels) pair is handed to the generator;
    # the total sample count scales with the number of devices.
    images = torch.zeros(batch, 3, 224, 224)
    labels = torch.zeros(batch, dtype=torch.int64)
    generator = xu.SampleGenerator(
        data=(images, labels),
        sample_count=samples_per_device * len(xla_devices))

    def loop_fn(model, loader, device, context):
        # NOTE(review): `device` and `context` are unused in this body;
        # presumably they are required by the callback signature that
        # DataParallel invokes -- confirm against its implementation.
        criterion = nn.NLLLoss()
        sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

        for _, (data, target) in loader:
            with xu.TimedScope(msg='Training loop: ', printfn=None):
                sgd.zero_grad()

                def run_model():
                    return model(data)

                output = xu.timed(run_model, msg='Model: ', printfn=None)

                def run_loss():
                    return criterion(output, target)

                loss = xu.timed(run_loss, msg='Loss: ', printfn=None)
                xu.timed(loss.backward, msg='LossBkw: ', printfn=None)

                def run_step():
                    return xm.optimizer_step(sgd)

                xu.timed(run_step, msg='Step: ', printfn=None)
                self.assertLess(loss.cpu().item(), 3.0)

    runner = dp.DataParallel(torchvision.models.resnet18,
                             device_ids=xla_devices)
    runner(loop_fn, generator)
def run_gc(debug_gc=None, gc_wait=0, is_final=False):
    """Run the Python garbage collector until a pass reclaims nothing.

    Args:
      debug_gc: Debug level. When ``None`` it is read from the
        ``XLA_DEBUG_GC`` setting (int, default 0). It selects the print
        function via ``xu.get_print_fn`` and, when greater than 1, turns on
        ``gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE`` for the duration of the
        collection.
      gc_wait: Forwarded as ``max_wait`` to ``_wait_for_released_tensors``.
      is_final: When true and the wait returns a falsy result,
        ``_force_release_tensors`` is called.
    """
    if debug_gc is None:
        debug_level = xu.getenv_as('XLA_DEBUG_GC', int, 0)
    else:
        debug_level = debug_gc
    log = xu.get_print_fn(debug_level)

    saved_flags = gc.get_debug()
    if debug_level > 1:
        gc.set_debug(gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE)

    # Keep collecting until a pass frees nothing, so that XLA resource
    # objects wrapped by std::shared_ptr<> eventually get released.
    reclaimed = None
    while reclaimed != 0:
        reclaimed = gc.collect()
        log('GC collected %d objects' % reclaimed)
    log('GC found %d uncollectable objects' % len(gc.garbage))
    gc.set_debug(saved_flags)

    # The Python GC does not release the objects immediately; it hands the
    # work to a background thread (as is done for handles), and there is no
    # way to flush that work on demand. So the handle counters are polled,
    # up to a maximum wait, until all the created handles are destroyed.
    released = _wait_for_released_tensors(max_wait=gc_wait, print_fn=log)
    if is_final and not released:
        _force_release_tensors()
    log(torch_xla._XLAC._xla_metrics_report())
def __init__(self, smooth_factor=None):
    """Initialise the tracker's counters and timestamps.

    Args:
      smooth_factor: Smoothing factor to store. When ``None`` the value is
        read from the ``RATE_TRACKER_SMOOTHING`` setting (float, default
        0.8).
    """
    if smooth_factor is None:
        smooth_factor = xu.getenv_as('RATE_TRACKER_SMOOTHING', float, 0.8)
    self._smooth_factor = smooth_factor

    # Both timestamps start from the same instant.
    now = time.time()
    self._start_time = now
    self._partial_time = now

    self._partial_count = 0.0
    self._partial_rate = None
    self._count = 0.0
def mark_step():
    """Issue an XLA step marker for the default device.

    The marker is called with an empty device list, and its ``wait`` flag is
    read from the ``XLA_SYNC_WAIT`` setting (bool, default ``False``).
    Metrics are saved afterwards only when ``is_master_ordinal()`` is true.
    """
    default_device = torch_xla._XLAC._xla_get_default_device()
    sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
    torch_xla._XLAC._xla_step_marker(default_device, [], wait=sync_wait)

    # Only the first local device index emits metrics, so the same values
    # are not emitted from different threads.
    if not is_master_ordinal():
        return
    ms.save_metrics()
def _mark_step(replication):
    """Issue an XLA step marker, optionally across replication devices.

    Args:
      replication: Falsy for no replication. Otherwise an object whose
        ``enter()`` is called first and whose ``replication_devices()``
        result is passed to the step marker as the device list.
    """
    if replication:
        replication.enter()
        target_devices = replication.replication_devices()
    else:
        target_devices = []

    default_device = torch_xla._XLAC._xla_get_default_device()
    sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
    torch_xla._XLAC._xla_step_marker(
        default_device, target_devices, wait=sync_wait)

    # Only the first local device index emits metrics, so the same values
    # are not emitted from different threads.
    if getattr(_TLS, 'device_index', 0) != 0:
        return
    ms.save_metrics()
# NOTE(review): the statements down to the `class TestGeneric` line are the
# tail of a definition whose header is outside this view; they are kept
# unchanged. They run a MNISTComparator on the XLA device and then compare
# the contents of the two save directories.
xla_model = MNISTComparator().to(xla_device)
xla_x = x.to(xla_device)
xla_model(xla_x)
report = mc.compare(save_dir1.name, save_dir2.name, rtol=1e-03, atol=1e-04)
# Print the report before asserting so a failure shows what differed.
if report:
    print(report)
self.assertEqual(len(report), 0)


class TestGeneric(XlaTestCase):
    """Generic tensor-operation checks."""

    def test_zeros_like_patch(self):
        """Check that ``zeros_like`` honours an explicit ``dtype`` override."""
        a = torch.ones(3, 3)
        b = torch.zeros_like(a, dtype=torch.int8)
        self.assertEqual(b.dtype, torch.int8)
        self.assertEqual(b.sum().item(), 0)


if __name__ == '__main__':
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(42)
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    # exit=False keeps the process alive after the tests, so the optional
    # metrics report below can be printed before exiting explicitly.
    test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
    if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
        print(torch_xla._XLAC._xla_metrics_report())
    # Exit status is 0 when every test passed and 1 otherwise.
    sys.exit(0 if test.result.wasSuccessful() else 1)
def setup(self):
    """Create the two square random operands ``a`` and ``b``.

    The side length is read from the ``ADD_MUL_DIV_SIZE`` setting (int,
    default 100) and stored in ``self.size``.
    """
    side = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
    self.size = side

    shape = (side, side)
    self.a = torch.rand(*shape)
    # abs() + 1.0 keeps every element of b at or above 1.0 -- presumably so
    # that b can be used as a divisor; confirm against the callers.
    self.b = torch.rand(*shape).abs() + 1.0
def __init__(self, args):
    """Store ``args``, acquire the XLA device and seed the torch RNG.

    ``self.test_time`` is read from the ``BENCH_TEST_TIME`` setting (float,
    default 5.0). The RNG seed is fixed to 42.

    Args:
      args: Stored unchanged as ``self.args``.
    """
    self.args = args
    self.device = xm.xla_device()

    test_time_setting = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
    self.test_time = test_time_setting

    torch.manual_seed(42)
def get_ordinal(defval=0):
    """Look up the ``xenv.ORDINAL`` setting as an ``int``.

    Args:
      defval: Default handed to ``xu.getenv_as`` for the lookup.

    Returns:
      The result of ``xu.getenv_as`` for ``xenv.ORDINAL``.
    """
    ordinal = xu.getenv_as(xenv.ORDINAL, int, defval=defval)
    return ordinal
def xrt_world_size(defval=1):
    """Look up the ``xenv.WORLD_SIZE`` setting as an ``int``.

    Args:
      defval: Default handed to ``xu.getenv_as`` for the lookup.

    Returns:
      The result of ``xu.getenv_as`` for ``xenv.WORLD_SIZE``.
    """
    world_size = xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
    return world_size