Example #1
0
File: xla_model.py  Project: yyht/xla
 def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
   """Evaluates the wrapped XLA model over an entire sample loader.

   Args:
     samples_loader: The loader yielding the evaluation samples.
     eval_fn: Callable taking (model_output, target) tensors and returning
       a (loss, correct_count) pair for the batch.
     batch_size: Number of samples per batch, used for sample accounting.
     log_fn: Where to send the TestStepMetrics record; pass None to
       disable logging.

   Returns:
     The accuracy as a percentage (100.0 * correct / count).
   """
   loader = LoaderWrapper(
       samples_loader,
       self._loader_prefetch,
       batch_size,
       num_cores=self._num_cores,
       devices=self._devices,
       fused_mode=True)
   # Keep a reference to the cleaner so the loader is closed when this
   # frame is torn down.
   loader_cleaner = xu.Cleaner(loader.close)
   tracker = RateTracker()
   total_loss = 0
   total_correct = 0
   total_count = 0
   for batch_number, (inputs, targets) in loader:
     replica_outputs = xla_run_model(
         self._xla_model, inputs, devices=self._devices)
     for replica, outputs in enumerate(replica_outputs):
       # Ordinal 0 of the returned tuple is the loss; ordinal 1 is the
       # original model output.
       model_output = outputs[1].to_tensor()
       # With fused_mode=True each replica input is [input, target], so
       # the target lives at index 1.
       batch_loss, batch_correct = eval_fn(model_output,
                                           inputs[replica][1].to_tensor())
       total_loss += batch_loss
       total_correct += batch_correct
       total_count += batch_size
   total_loss /= total_count
   accuracy = 100.0 * total_correct / total_count
   if log_fn is not None:
     log_fn(
         TestStepMetrics(total_loss, total_correct, total_count,
                         tracker.update(total_count), self._step))
   return accuracy
Example #2
0
 def train(self,
           samples_loader,
           optimizer,
           batch_size,
           log_interval=1,
           log_fn=print,
           metrics_debug=False):
     """Runs one training epoch over the given sample loader.

     Args:
       samples_loader: The loader yielding the training samples.
       optimizer: The optimizer stepped after every backward pass.
       batch_size: Number of samples per batch.
       log_interval: Log every `log_interval` batches; None disables.
       log_fn: Where progress lines are sent; None disables logging.
       metrics_debug: If True, also emits the XLA metrics report on each
         log step.

     Returns:
       The last loss computed at a log step, or None if never logged.
     """
     loader = LoaderWrapper(samples_loader,
                            self._loader_prefetch,
                            batch_size,
                            num_cores=self._num_cores,
                            devices=self._devices,
                            fused_mode=True)
     # Keep a reference to the cleaner so the loader is closed when this
     # frame is torn down.
     loader_cleaner = xu.Cleaner(loader.close)
     processed_samples = 0
     loss = None
     start_time = time.time()
     self._epoch += 1
     for batch_number, (inputs, targets) in loader:
         optimizer.zero_grad()
         xla_outputs = run_xla_model(self._xla_model,
                                     inputs,
                                     devices=self._devices)
         xla_run_grad(self._xla_model,
                      self._get_backward_grads(xla_outputs),
                      devices=self._devices)
         optimizer.step()
         processed_samples += self._num_cores * batch_size
         should_log = (log_fn is not None and log_interval is not None
                       and batch_number % log_interval == 0)
         if should_log:
             if metrics_debug:
                 log_fn(torch_xla._XLAC._xla_metrics_report())
             loss = self._compute_loss(xla_outputs)
             log_fn('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
                    'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                        self._epoch, processed_samples,
                        len(samples_loader) * batch_size,
                        100. * batch_number * self._num_cores /
                        len(samples_loader), loss,
                        processed_samples / (time.time() - start_time)))
     return loss
Example #3
0
File: xla_model.py  Project: yyht/xla
 def train(self,
           samples_loader,
           optimizer,
           batch_size,
           log_interval=1,
           log_fn=print,
           metrics_debug=False):
   """Runs one training epoch over the given sample loader.

   Args:
     samples_loader: The loader yielding the training samples.
     optimizer: The optimizer stepped after every backward pass.
     batch_size: Number of samples per batch.
     log_interval: Log every `log_interval` batches; None disables.
     log_fn: Where the TrainStepMetrics records are sent; None disables
       logging.
     metrics_debug: If True, also emits the XLA metrics report on each
       log step.

   Returns:
     The last loss computed at a log step, or None if never logged.
   """
   loader = LoaderWrapper(
       samples_loader,
       self._loader_prefetch,
       batch_size,
       num_cores=self._num_cores,
       devices=self._devices,
       fused_mode=True)
   # Keep a reference to the cleaner so the loader is closed when this
   # frame is torn down.
   loader_cleaner = xu.Cleaner(loader.close)
   optimizer.zero_grad()
   loss = None
   tracker = RateTracker()
   self._epoch += 1
   for batch_number, (inputs, targets) in loader:
     self._step += 1
     xla_outputs = xla_run_model(
         self._xla_model, inputs, devices=self._devices)
     xla_run_grad(
         self._xla_model,
         self._get_backward_grads(xla_outputs),
         devices=self._devices)
     optimizer.step()
     should_log = (log_fn is not None and log_interval is not None and
                   batch_number % log_interval == 0)
     if should_log:
       if metrics_debug:
         log_fn(torch_xla._XLAC._xla_metrics_report())
       loss = self._compute_loss(xla_outputs)
       tracker.update(self._num_cores * batch_size * (batch_number + 1))
       log_fn(
           TrainStepMetrics(self._epoch, self._num_cores, batch_number,
                            len(samples_loader), batch_size, loss,
                            tracker.rate(), self._step))
   return loss
Example #4
0
 def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
     """Evaluates the wrapped XLA model over an entire sample loader.

     Args:
       samples_loader: The loader yielding the evaluation samples.
       eval_fn: Callable taking (model_output, target) tensors and
         returning a (loss, correct_count) pair for the batch.
       batch_size: Number of samples per batch, used for accounting.
       log_fn: Where the summary line is sent; None disables logging.

     Returns:
       The accuracy as a percentage (100.0 * correct / count).
     """
     loader = LoaderWrapper(samples_loader,
                            self._loader_prefetch,
                            batch_size,
                            num_cores=self._num_cores,
                            devices=self._devices,
                            fused_mode=True)
     # Keep a reference to the cleaner so the loader is closed when this
     # frame is torn down.
     loader_cleaner = xu.Cleaner(loader.close)
     total_loss = 0
     total_count = 0
     total_correct = 0
     start_time = time.time()
     with torch.no_grad():
         for batch_number, (inputs, targets) in loader:
             replica_outputs = run_xla_model(self._xla_model,
                                             inputs,
                                             devices=self._devices)
             for replica, outputs in enumerate(replica_outputs):
                 # Ordinal 0 of the returned tuple is the loss; ordinal 1
                 # is the original model output.
                 model_output = outputs[1].to_tensor()
                 # With fused_mode=True each replica input is
                 # [input, target], so the target lives at index 1.
                 batch_loss, batch_correct = eval_fn(
                     model_output, inputs[replica][1].to_tensor())
                 total_loss += batch_loss
                 total_correct += batch_correct
                 total_count += batch_size
     total_loss /= total_count
     accuracy = 100.0 * total_correct / total_count
     if log_fn is not None:
         log_fn(
             '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
             'Samples/sec: {:.1f}\n'.format(
                 total_loss, total_correct, total_count, accuracy,
                 total_count / (time.time() - start_time)))
     return accuracy