Esempio n. 1
0
    def test(self):
        """Train resnet18 via DataParallel on synthetic data across XLA devices.

        The per-step batch size and per-device sample count are read with
        xu.getenv_as from BATCH_SIZE (default 4) and SAMPLE_COUNT (default 10).
        Every training step is timed, and the test asserts that each step's
        loss is below 3.0.
        """
        devices = xm.get_xla_supported_devices()
        batch = xu.getenv_as('BATCH_SIZE', int, defval=4)
        per_device_samples = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
        # Synthetic all-zero inputs: images shaped (batch, 3, 224, 224) and
        # int64 targets shaped (batch,), used as NLLLoss class indices below.
        images = torch.zeros(batch, 3, 224, 224)
        labels = torch.zeros(batch, dtype=torch.int64)
        total_samples = per_device_samples * len(devices)
        train_loader = xu.SampleGenerator(data=(images, labels),
                                          sample_count=total_samples)

        def loop_fn(model, loader, device, context):
            # `device` and `context` are part of the callback signature that
            # DataParallel invokes; they are not needed here.
            criterion = nn.NLLLoss()
            sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

            # The loader yields (something, (inputs, labels)) pairs; the first
            # element is unused here.
            for _, (inputs, targets) in loader:
                with xu.TimedScope(msg='Training loop: ', printfn=None):
                    sgd.zero_grad()
                    predictions = xu.timed(lambda: model(inputs),
                                           msg='Model: ',
                                           printfn=None)
                    step_loss = xu.timed(
                        lambda: criterion(predictions, targets),
                        msg='Loss: ',
                        printfn=None)
                    xu.timed(step_loss.backward,
                             msg='LossBkw: ',
                             printfn=None)
                    xu.timed(lambda: xm.optimizer_step(sgd),
                             msg='Step: ',
                             printfn=None)
                    self.assertLess(step_loss.cpu().item(), 3.0)

        resnet_parallel = dp.DataParallel(torchvision.models.resnet18,
                                          device_ids=devices)
        resnet_parallel(loop_fn, train_loader)
Esempio n. 2
0
def run_gc(debug_gc=None, gc_wait=0, is_final=False):
    """Run Python garbage collection until a pass collects nothing.

    Afterwards, wait (up to ``gc_wait``) for released XLA tensor handles to
    drain. If waiting does not succeed and ``is_final`` is set, force-release
    the remaining tensors. Finally, emit the XLA metrics report through the
    debug print function.

    Args:
      debug_gc: Debug verbosity. When None, it is read via
        xu.getenv_as('XLA_DEBUG_GC', int, 0). Values above 1 additionally
        enable gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE while collecting.
      gc_wait: Passed as ``max_wait`` to _wait_for_released_tensors.
      is_final: When true, call _force_release_tensors if the wait reports
        failure (returns a falsy value).
    """
    if debug_gc is None:
        debug_gc = xu.getenv_as('XLA_DEBUG_GC', int, 0)
    print_fn = xu.get_print_fn(debug_gc)
    gc_flags = gc.get_debug()
    if debug_gc > 1:
        gc.set_debug(gc.DEBUG_STATS | gc.DEBUG_UNCOLLECTABLE)
    try:
        # Run GC so that any XLA resource objects wrapped by
        # std::shared_ptr<> get released. Repeat until a pass collects
        # nothing.
        while True:
            collected = gc.collect()
            print_fn('GC collected %d objects' % collected)
            if collected == 0:
                break
        print_fn('GC found %d uncollectable objects' % len(gc.garbage))
    finally:
        # The gc debug flags are interpreter-wide state. Restore the caller's
        # flags even if collection or printing raised.
        gc.set_debug(gc_flags)
    # Unfortunately the Python GC does not immediately release the objects, but
    # it instead delegates the task to a background thread (like we do for
    # handles). To make things worse, there is no way to flush that work
    # immediately. So we look at the handle counters and we wait, up to a max,
    # until all the created handles have been destroyed.
    if (not _wait_for_released_tensors(max_wait=gc_wait, print_fn=print_fn)
            and is_final):
        _force_release_tensors()
    print_fn(torch_xla._XLAC._xla_metrics_report())
Esempio n. 3
0
 def __init__(self, smooth_factor=None):
     """Start tracking from the current wall-clock time.

     Args:
       smooth_factor: Smoothing factor for the rate estimate. When None, it
         is read via xu.getenv_as('RATE_TRACKER_SMOOTHING', float), falling
         back to 0.8.
     """
     if smooth_factor is None:
         smooth_factor = xu.getenv_as('RATE_TRACKER_SMOOTHING', float, 0.8)
     self._smooth_factor = smooth_factor
     # The overall clock and the partial clock (presumably the start of the
     # current measurement interval) both begin at the same instant.
     now = time.time()
     self._start_time = now
     self._partial_time = now
     self._count = 0.0
     self._partial_count = 0.0
     # No partial rate has been computed yet.
     self._partial_rate = None
Esempio n. 4
0
def mark_step():
    """Issue an XLA step marker on the default device.

    Whether the marker waits is controlled by
    xu.getenv_as('XLA_SYNC_WAIT', bool, False). Metrics are saved only when
    is_master_ordinal() is true.
    """
    default_device = torch_xla._XLAC._xla_get_default_device()
    sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
    torch_xla._XLAC._xla_step_marker(default_device, [], wait=sync_wait)
    # Only emit metrics from the first local device index, to avoid emitting
    # the same values from different threads.
    if not is_master_ordinal():
        return
    ms.save_metrics()
Esempio n. 5
0
def _mark_step(replication):
    """Issue an XLA step marker, optionally across replication devices.

    Args:
      replication: If truthy, ``replication.enter()`` is called first, and
        the marker covers ``replication.replication_devices()``. Otherwise
        the marker is issued with an empty device list.
    """
    if replication:
        replication.enter()
        step_devices = replication.replication_devices()
    else:
        step_devices = []
    default_device = torch_xla._XLAC._xla_get_default_device()
    sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
    torch_xla._XLAC._xla_step_marker(default_device,
                                     step_devices,
                                     wait=sync_wait)
    # Only emit metrics from the first local device index, to avoid emitting
    # the same values from different threads. A missing device_index on _TLS
    # counts as index 0.
    local_index = getattr(_TLS, 'device_index', 0)
    if local_index == 0:
        ms.save_metrics()
Esempio n. 6
0
        xla_model = MNISTComparator().to(xla_device)
        xla_x = x.to(xla_device)
        xla_model(xla_x)

        report = mc.compare(save_dir1.name,
                            save_dir2.name,
                            rtol=1e-03,
                            atol=1e-04)
        if report:
            print(report)
        self.assertEqual(len(report), 0)


class TestGeneric(XlaTestCase):
    """Generic tests that run on CPU tensors."""

    def test_zeros_like_patch(self):
        """torch.zeros_like must honor an explicit dtype and produce zeros.

        The name suggests this covers a patched zeros_like (TODO confirm
        where the patch lives).
        """
        ones = torch.ones(3, 3)
        result = torch.zeros_like(ones, dtype=torch.int8)
        expected_dtype = torch.int8
        self.assertEqual(result.dtype, expected_dtype)
        total = result.sum().item()
        self.assertEqual(total, 0)


if __name__ == '__main__':
    # Global test configuration: CPU float as the default tensor type, a
    # fixed RNG seed, and full matmul precision requested from the XLA
    # backend.
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(42)
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    # exit=False makes unittest.main return instead of exiting, so the
    # metrics report can be printed before the exit code is chosen.
    test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
    if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
        print(torch_xla._XLAC._xla_metrics_report())
    succeeded = test.result.wasSuccessful()
    # Exit code 0 on success, 1 on any failure or error.
    sys.exit(int(not succeeded))
Esempio n. 7
0
 def setup(self):
     """Create the square operands for the add/mul/div benchmark.

     The side length is read via xu.getenv_as('ADD_MUL_DIV_SIZE', int),
     defaulting to 100.
     """
     side = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
     self.size = side
     self.a = torch.rand(side, side)
     # torch.rand samples from [0, 1), so adding 1.0 keeps every element of
     # b >= 1.0 (presumably so division by b is always safe).
     self.b = torch.rand(side, side).abs() + 1.0
Esempio n. 8
0
 def __init__(self, args):
     """Bind the benchmark to the default XLA device and seed the RNG.

     Args:
       args: Stored unchanged as ``self.args`` (presumably parsed
         command-line options; confirm with callers).
     """
     self.args = args
     self.device = xm.xla_device()
     # Run duration read via xu.getenv_as('BENCH_TEST_TIME', float), default
     # 5.0 (presumably seconds; TODO confirm).
     run_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
     self.test_time = run_time
     # Fixed seed so the randomly generated benchmark inputs are reproducible.
     torch.manual_seed(42)
Esempio n. 9
0
def get_ordinal(defval=0):
    """Return the ordinal configured under the ``xenv.ORDINAL`` setting.

    The value is read via xu.getenv_as and converted to int.

    Args:
      defval: Value returned when the setting is absent.
    """
    ordinal = xu.getenv_as(xenv.ORDINAL, int, defval=defval)
    return ordinal
Esempio n. 10
0
def xrt_world_size(defval=1):
    """Return the world size configured under the ``xenv.WORLD_SIZE`` setting.

    The value is read via xu.getenv_as and converted to int.

    Args:
      defval: Value returned when the setting is absent.
    """
    world_size = xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
    return world_size