def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
  """Run an evaluation pass over `samples_loader` and return the accuracy.

  Args:
    samples_loader: Iterable yielding (inputs, targets) batches.
    eval_fn: Callable `(output, target) -> (loss, correct_count)` applied to
      each replica's output tensor.
    batch_size: Per-replica batch size, used for sample accounting.
    log_fn: Callable receiving a `TestStepMetrics` object, or None to
      disable logging.

  Returns:
    Accuracy as a percentage over all evaluated samples (0.0 when the
    loader yields no batches).
  """
  wloader = LoaderWrapper(
      samples_loader,
      self._loader_prefetch,
      batch_size,
      num_cores=self._num_cores,
      devices=self._devices,
      fused_mode=True)
  # Cleaner guarantees the loader's resources are released when this
  # frame goes away.
  wloader_cleaner = xu.Cleaner(wloader.close)
  test_loss = 0
  count = 0
  correct = 0
  rate_tracker = RateTracker()
  # Evaluation needs no autograd graph; mirror the no_grad() guard used
  # by the other test loop in this file.
  with torch.no_grad():
    for batch_number, (inputs, targets) in wloader:
      xla_outputs = xla_run_model(
          self._xla_model, inputs, devices=self._devices)
      for i, replica_xla_outputs in enumerate(xla_outputs):
        # The original model output is ordinal 1 of the returned
        # tuple (loss is ordinal 0).
        output = replica_xla_outputs[1].to_tensor()
        # Inputs [1] is the model target, as inputs with
        # fused_mode=True are [input, target].
        closs, ccorrect = eval_fn(output, inputs[i][1].to_tensor())
        test_loss += closs
        correct += ccorrect
        count += batch_size
  # Guard the empty-loader case, which would otherwise raise
  # ZeroDivisionError below.
  if count == 0:
    return 0.0
  test_loss /= count
  accuracy = 100.0 * correct / count
  if log_fn is not None:
    log_fn(
        TestStepMetrics(test_loss, correct, count, rate_tracker.update(count),
                        self._step))
  return accuracy
def train(self,
          samples_loader,
          optimizer,
          batch_size,
          log_interval=1,
          log_fn=print,
          metrics_debug=False):
  """Run one training epoch over `samples_loader`.

  Args:
    samples_loader: Iterable yielding (inputs, targets) batches.
    optimizer: Optimizer stepped once per batch (gradients are zeroed at
      the start of every batch).
    batch_size: Per-replica batch size, used for sample accounting.
    log_interval: Log every `log_interval` batches; None disables logging.
    log_fn: Callable receiving the formatted progress string, or None to
      disable logging.
    metrics_debug: When True, also emit the XLA metrics report at each
      logging step.

  Returns:
    The loss computed at the most recent logging step, or None if no
    logging step occurred.
  """
  loader = LoaderWrapper(
      samples_loader,
      self._loader_prefetch,
      batch_size,
      num_cores=self._num_cores,
      devices=self._devices,
      fused_mode=True)
  # Cleaner guarantees the loader's resources are released when this
  # frame goes away.
  loader_cleaner = xu.Cleaner(loader.close)
  samples_seen = 0
  last_loss = None
  epoch_start = time.time()
  self._epoch += 1
  for step_index, (inputs, targets) in loader:
    optimizer.zero_grad()
    replica_outputs = run_xla_model(
        self._xla_model, inputs, devices=self._devices)
    # Run the backward pass seeded with the gradients for the forward
    # outputs, then apply the optimizer update.
    xla_run_grad(
        self._xla_model,
        self._get_backward_grads(replica_outputs),
        devices=self._devices)
    optimizer.step()
    samples_seen += self._num_cores * batch_size
    should_log = (log_fn is not None and log_interval is not None and
                  step_index % log_interval == 0)
    if not should_log:
      continue
    if metrics_debug:
      log_fn(torch_xla._XLAC._xla_metrics_report())
    last_loss = self._compute_loss(replica_outputs)
    epoch_samples = len(samples_loader) * batch_size
    progress_pct = 100. * step_index * self._num_cores / len(samples_loader)
    samples_per_sec = samples_seen / (time.time() - epoch_start)
    log_fn('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
           'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
               self._epoch, samples_seen, epoch_samples, progress_pct,
               last_loss, samples_per_sec))
  return last_loss
def train(self,
          samples_loader,
          optimizer,
          batch_size,
          log_interval=1,
          log_fn=print,
          metrics_debug=False):
  """Run one training epoch over `samples_loader`.

  Args:
    samples_loader: Iterable yielding (inputs, targets) batches.
    optimizer: Optimizer stepped once per batch. NOTE: gradients are
      zeroed only once, before the loop, matching this trainer's existing
      accumulation behavior.
    batch_size: Per-replica batch size, used for sample accounting.
    log_interval: Log every `log_interval` batches; None disables logging.
    log_fn: Callable receiving a `TrainStepMetrics` object, or None to
      disable logging.
    metrics_debug: When True, also emit the XLA metrics report at each
      logging step.

  Returns:
    The loss computed at the most recent logging step, or None if no
    logging step occurred.
  """
  loader = LoaderWrapper(
      samples_loader,
      self._loader_prefetch,
      batch_size,
      num_cores=self._num_cores,
      devices=self._devices,
      fused_mode=True)
  # Cleaner guarantees the loader's resources are released when this
  # frame goes away.
  loader_cleaner = xu.Cleaner(loader.close)
  optimizer.zero_grad()
  last_loss = None
  rate_tracker = RateTracker()
  self._epoch += 1
  for batch_index, (inputs, targets) in loader:
    self._step += 1
    replica_outputs = xla_run_model(
        self._xla_model, inputs, devices=self._devices)
    # Backward pass seeded with the gradients of the forward outputs,
    # followed by the optimizer update.
    xla_run_grad(
        self._xla_model,
        self._get_backward_grads(replica_outputs),
        devices=self._devices)
    optimizer.step()
    should_log = (log_fn is not None and log_interval is not None and
                  batch_index % log_interval == 0)
    if not should_log:
      continue
    if metrics_debug:
      log_fn(torch_xla._XLAC._xla_metrics_report())
    last_loss = self._compute_loss(replica_outputs)
    # Feed the tracker the cumulative sample count across all cores.
    rate_tracker.update(self._num_cores * batch_size * (batch_index + 1))
    log_fn(
        TrainStepMetrics(self._epoch, self._num_cores, batch_index,
                         len(samples_loader), batch_size, last_loss,
                         rate_tracker.rate(), self._step))
  return last_loss
def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
  """Run an evaluation pass over `samples_loader` and return the accuracy.

  Args:
    samples_loader: Iterable yielding (inputs, targets) batches.
    eval_fn: Callable `(output, target) -> (loss, correct_count)` applied
      to each replica's output tensor.
    batch_size: Per-replica batch size, used for sample accounting.
    log_fn: Callable receiving the formatted summary string, or None to
      disable logging.

  Returns:
    Accuracy as a percentage over all evaluated samples.
  """
  loader = LoaderWrapper(
      samples_loader,
      self._loader_prefetch,
      batch_size,
      num_cores=self._num_cores,
      devices=self._devices,
      fused_mode=True)
  # Cleaner guarantees the loader's resources are released when this
  # frame goes away.
  loader_cleaner = xu.Cleaner(loader.close)
  total_loss = 0
  sample_count = 0
  num_correct = 0
  eval_start = time.time()
  with torch.no_grad():
    for batch_index, (inputs, targets) in loader:
      replica_outputs = run_xla_model(
          self._xla_model, inputs, devices=self._devices)
      for replica_idx, replica_result in enumerate(replica_outputs):
        # The original model output is ordinal 1 of the returned
        # tuple (loss is ordinal 0).
        model_output = replica_result[1].to_tensor()
        # Inputs [1] is the model target, as inputs with
        # fused_mode=True are [input, target].
        target_tensor = inputs[replica_idx][1].to_tensor()
        batch_loss, batch_correct = eval_fn(model_output, target_tensor)
        total_loss += batch_loss
        num_correct += batch_correct
        sample_count += batch_size
  total_loss /= sample_count
  accuracy = 100.0 * num_correct / sample_count
  if log_fn is not None:
    samples_per_sec = sample_count / (time.time() - eval_start)
    log_fn(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
        'Samples/sec: {:.1f}\n'.format(total_loss, num_correct, sample_count,
                                       accuracy, samples_per_sec))
  return accuracy