Exemplo n.º 1
0
 def threadfn(i):
     """Send `send_list[i]` to `devices[i]` `args.test_count` times, timing each send.

     Each iteration calls `torch_xla._XLAC._xla_tensors_from_aten` on the
     whole tensor list (presumably converting ATen tensors into XLA tensors on
     the target device) inside an `xu.TimedScope` labelled
     `Send[<i>][<iteration>]: ` with `printfn=print`.

     Args:
       i: Index selecting both the target device and the tensor list.
     """
     # One device entry per tensor: every tensor in the list goes to the same device.
     target_devices = [devices[i]] * len(send_list[i])
     for step in range(args.test_count):
         label = f'Send[{i}][{step}]: '
         # printfn=print presumably reports the elapsed time on stdout; confirm in xu.
         with xu.TimedScope(msg=label, printfn=print):
             # Bind the result so the converted tensors stay referenced past the
             # end of the timed scope (until the next iteration rebinds `_`).
             _ = torch_xla._XLAC._xla_tensors_from_aten(
                 send_list[i], target_devices)
Exemplo n.º 2
0
    def loop_fn(model, loader, device, context):
      """Train `model` over every batch in `loader`, timing each training phase.

      Uses NLL loss and SGD (lr=0.01, momentum=0.5). The whole iteration and
      each phase (forward, loss, backward, optimizer step) are wrapped in xu
      timing helpers with `printfn=None` (presumably suppressing timing output).
      After each step the loss is copied to CPU and checked to be below 3.0 via
      the enclosing test case's `self.assertLess`.

      Args:
        model: Module being trained; its parameters are handed to SGD.
        loader: Iterable of pairs whose second element unpacks to
          `(data, target)`; the first element is ignored.
        device: Unused in this body.
        context: Unused in this body.
      """

      def _quiet_timed(fn, label):
        # Every phase uses the same timing call; only the label differs.
        return xu.timed(fn, msg=label, printfn=None)

      criterion = nn.NLLLoss()
      sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

      for _, (data, target) in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
          sgd.zero_grad()
          output = _quiet_timed(lambda: model(data), 'Model: ')
          loss = _quiet_timed(lambda: criterion(output, target), 'Loss: ')
          _quiet_timed(loss.backward, 'LossBkw: ')
          _quiet_timed(lambda: xm.optimizer_step(sgd), 'Step: ')
          # Copy the loss to CPU and extract a Python scalar for the bound check.
          self.assertLess(loss.cpu().item(), 3.0)