def _validation(self) -> Dict[str, float]:
    logger.info("Validating")

    val_data = self._validation_dataset
    if isinstance(val_data, dataset.ActuatedTrajectoryDataset):
        val_data = transform.odepred_transform(val_data, self._pred_horizon)
    val_generator = torch.utils.data.DataLoader(
        val_data, shuffle=False, batch_size=self._batch_size)

    loss_ls = []
    loss_info_ls = []
    for qs, vs, us in tqdm.tqdm(val_generator, total=len(val_generator)):
        # evaluate the multi-step (q, v) prediction loss without tracking gradients
        with torch.no_grad():
            loss, loss_info = compute_qvloss(
                ActuatedODEWrapper(self.model),
                torch.stack(qs).to(self._device),
                torch.stack(vs).to(self._device),
                torch.stack(us).to(self._device),
                dt=self._dt,
                vlambda=self._vlambda,
                method=self._integration_method)
        loss_ls.append(loss.cpu().detach().item())
        loss_info_ls.append(loss_info)

    # aggregate per-batch losses into summary statistics
    metrics = {}
    loss_info = nested.zip(*loss_info_ls)
    metrics['loss/mean'] = np.mean(loss_ls)
    metrics['loss/std'] = np.std(loss_ls)
    metrics['log10loss/mean'] = np.mean(np.log10(loss_ls))
    metrics['log10loss/std'] = np.std(np.log10(loss_ls))
    for k, val in loss_info.items():
        metrics['loss/{}/mean'.format(k)] = np.mean(val)
        metrics['loss/{}/std'.format(k)] = np.std(val)
    return metrics
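# `nested.zip` above is assumed to transpose the list of per-batch loss-info
# dicts into a dict of lists keyed by metric name, so each entry can be
# reduced with np.mean / np.std. A minimal sketch of that assumed behaviour
# (illustrative only; the project's `nested` utility may differ, e.g. by
# recursing into nested structures):
def _zip_loss_infos_sketch(loss_info_ls):
    """[{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] -> {'a': [1, 3], 'b': [2, 4]}"""
    return {k: [info[k] for info in loss_info_ls] for k in loss_info_ls[0]}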
def _train_epoch(self, epoch: int) -> Dict[str, float]:
    logger.info("Epoch {}/{}".format(epoch, self._num_epochs - 1))
    peak_mem_usage = utils.peak_memory_mb()
    logger.info(f"Peak CPU memory usage MB: {peak_mem_usage}")

    loss_ls = []
    loss_info_ls = []
    losstimer_ls = []
    gradtimer_ls = []

    train_data = self._train_dataset
    if isinstance(train_data, dataset.ActuatedTrajectoryDataset):
        train_data = transform.odepred_transform(train_data, self._pred_horizon)
    train_generator = torch.utils.data.DataLoader(
        train_data, shuffle=self._shuffle, batch_size=self._batch_size)

    for qs, vs, us in tqdm.tqdm(train_generator, total=len(train_generator)):
        self.optimizer.zero_grad()
        # time the forward/loss computation and the backward/update step separately
        with utils.Timer() as losstime:
            loss, loss_info = compute_qvloss(
                ActuatedODEWrapper(self.model),
                torch.stack(qs).to(self._device),
                torch.stack(vs).to(self._device),
                torch.stack(us).to(self._device),
                dt=self._dt,
                vlambda=self._vlambda,
                method=self._integration_method)
        with utils.Timer() as gradtime:
            loss.backward()
            self.optimizer.step()

        loss_ls.append(loss.cpu().detach().numpy())
        loss_info_ls.append(loss_info)
        losstimer_ls.append(losstime.dt)
        gradtimer_ls.append(gradtime.dt)

    # aggregate per-batch losses and timings into summary statistics
    metrics = {}
    loss_info = nested.zip(*loss_info_ls)
    metrics['cpu_memory_MB'] = peak_mem_usage
    metrics['loss/mean'] = np.mean(loss_ls)
    metrics['loss/std'] = np.std(loss_ls)
    metrics['log10loss/mean'] = np.mean(np.log10(loss_ls))
    metrics['log10loss/std'] = np.std(np.log10(loss_ls))
    for k, val in loss_info.items():
        metrics['loss/{}/mean'.format(k)] = np.mean(val)
        metrics['loss/{}/std'.format(k)] = np.std(val)
    metrics['time/loss/mean'] = np.mean(losstimer_ls)
    metrics['time/loss/std'] = np.std(losstimer_ls)
    metrics['time/loss/max'] = np.max(losstimer_ls)
    metrics['time/loss/min'] = np.min(losstimer_ls)
    metrics['time/grad/mean'] = np.mean(gradtimer_ls)
    metrics['time/grad/std'] = np.std(gradtimer_ls)
    metrics['time/grad/max'] = np.max(gradtimer_ls)
    metrics['time/grad/min'] = np.min(gradtimer_ls)
    if self._learning_rate_scheduler:
        metrics['lr'] = self._learning_rate_scheduler.get_lr()[0]
    return metrics
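# The loops above time the forward/loss pass and the backward/update step with
# `utils.Timer`, reading the elapsed seconds from `.dt` once the `with` block
# has exited. A minimal sketch of a context manager with that interface
# (illustrative; `_TimerSketch` is a hypothetical stand-in and the project's
# `utils.Timer` may differ):
import time  # stdlib; already used by the epoch-level train loop below


class _TimerSketch:
    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # elapsed wall-clock time in seconds, read afterwards as `timer.dt`
        self.dt = time.perf_counter() - self._start
        return False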
def _validation(self) -> Dict[str, float]:
    logger.info("Validating")
    q_b, v_b, qddot_b, F_b = map(lambda b: b.to(self._device), self._valid_batch)
    with torch.no_grad():
        loss, loss_vec = compute_DeLaNloss(self.model, q_b, v_b, qddot_b, F_b)

    metrics = {}
    metrics['loss_F'] = loss.cpu().detach().item()
    for i in range(self.model._qdim):
        metrics[f'loss_F_{i}'] = loss_vec[i].cpu().detach().item()
    return metrics
def _train_batch(self, batch_i: int) -> Dict[str, float]:
    logger.info("Training batch {}/{}".format(batch_i, self._num_train_batches - 1))
    self._last_train_batch = next(self._train_datagen)

    self.optimizer.zero_grad()
    q_b, v_b, qddot_b, F_b = map(lambda b: b.to(self._device), self._last_train_batch)
    loss, loss_vec = compute_DeLaNloss(self.model, q_b, v_b, qddot_b, F_b)
    loss.backward()
    self.optimizer.step()

    metrics = {}
    metrics['loss_F_tot'] = loss.cpu().detach().item()
    for i in range(self.model._qdim):
        metrics[f'loss_F_{i}'] = loss_vec[i].cpu().detach().item()
    return metrics
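# `compute_DeLaNloss` is used above as returning a scalar total loss together
# with a per-coordinate loss vector of length `model._qdim`. A minimal sketch
# of that return contract for inverse-dynamics style training, assuming
# hypothetical `F_pred` / `F_target` torch tensors of shape (batch, qdim)
# (the project's actual loss may be weighted or structured differently):
def _delan_loss_sketch(F_pred, F_target):
    # per-coordinate mean squared error over the batch: shape (qdim,)
    loss_vec = ((F_pred - F_target) ** 2).mean(dim=0)
    # total loss reduces the per-coordinate terms to a scalar
    loss = loss_vec.sum()
    return loss, loss_vec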
def _parameter_and_gradient_statistics(self) -> None:
    for name, param in self.model.named_parameters():
        logger.logkv("parameter_mean/" + name, param.data.mean().item())
        logger.logkv("parameter_std/" + name, param.data.std().item())
        logger.logkv("parameter_norm/" + name, param.data.norm().item())
        if param.grad is not None:
            grad_data = param.grad.data
            # skip empty gradients
            if torch.prod(torch.tensor(grad_data.shape)).item() > 0:
                logger.logkv("gradient_mean/" + name, grad_data.mean().item())
                logger.logkv("gradient_std/" + name, grad_data.std().item())
                logger.logkv("gradient_norm/" + name, grad_data.norm().item())
            else:
                logger.info("No gradient for {}, skipping.".format(name))
def _save_checkpoint(self, epoch: int) -> None:
    model_path = Path(self._logdir) / "model_state_epoch_{}.th".format(epoch)
    model_state = self.model.state_dict()
    torch.save(model_state, model_path)

    training_state = {
        'epoch': epoch,
        'metric_tracker': self._metric_tracker.state_dict(),
        'optimizer': self.optimizer.state_dict(),
    }
    training_path = Path(self._logdir) / "training_state_epoch_{}.th".format(epoch)
    torch.save(training_state, training_path)

    if self._metric_tracker.is_best_so_far():
        logger.info("Best validation performance so far. Copying weights to {}/best.th".format(
            self._logdir))
        shutil.copyfile(model_path, Path(self._logdir) / "best.th")
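# The epoch-level `train()` below calls a `_restore_checkpoint()` that is not
# shown in this section. A minimal sketch of what it might look like,
# mirroring the file layout written by `_save_checkpoint` above
# (`_restore_checkpoint_sketch` is a hypothetical illustration, not the
# project's implementation; e.g. the metric tracker's `load_state_dict` is
# assumed to exist as the counterpart of its `state_dict()`):
def _restore_checkpoint_sketch(self) -> int:
    if not self._logdir:
        return 0
    ckpts = list(Path(self._logdir).glob("model_state_epoch_*.th"))
    if not ckpts:
        return 0
    # pick the most recent saved epoch from the checkpoint file names
    last_epoch = max(int(p.stem.split("_")[-1]) for p in ckpts)
    model_path = Path(self._logdir) / "model_state_epoch_{}.th".format(last_epoch)
    training_path = Path(self._logdir) / "training_state_epoch_{}.th".format(last_epoch)
    self.model.load_state_dict(torch.load(model_path, map_location=self._device))
    training_state = torch.load(training_path, map_location=self._device)
    self.optimizer.load_state_dict(training_state['optimizer'])
    self._metric_tracker.load_state_dict(training_state['metric_tracker'])
    # resume from the epoch after the last completed one
    return training_state['epoch'] + 1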
def train(self): logger.info("Begin training...") metrics = {} for batch_i in tqdm.tqdm(range(self._num_train_batches)): train_metrics = self._train_batch(batch_i) valid_metrics = self._validation() for k, v in train_metrics.items(): logger.logkv("training/{}".format(k), v) for k, v in valid_metrics.items(): logger.logkv("validation/{}".format(k), v) # checkpoint if self._logdir: if (batch_i % self._ckpt_interval == 0) or ( batch_i + 1) == self._num_train_batches: self._save_checkpoint(batch_i) # write statistics if (batch_i % self._summary_interval == 0) or ( batch_i + 1) == self._num_train_batches: if self._should_log_parameter_statistics: self._parameter_and_gradient_statistics() train_metrics = self._metrics(self._last_train_batch) for k, v in train_metrics.items(): logger.logkv("training/{}".format(k), v) val_metrics = self._metrics(self._valid_batch, log_sample_MM_spectrum=3) for k, v in val_metrics.items(): logger.logkv("validation/{}".format(k), v) logger.dumpkvs() return metrics
def train(self): logger.info("Begin training...") try: epoch_counter = self._restore_checkpoint() except RuntimeError: traceback.print_exc() raise Exception( "Could not recover training from the checkpoint. Did you mean to output to " "a different serialization directory or delete the existing log " "directory?") train_metrics = {} valid_metrics = {} metrics = {} this_epoch_valid_metric: float = None epochs_trained = 0 training_start_time = time.time() metrics['best_epoch'] = self._metric_tracker.best_epoch for key, value in self._metric_tracker.best_epoch_metrics.items(): metrics["best_validation/" + key] = value for epoch in range(epoch_counter, self._num_epochs): epoch_start_time = time.time() with utils.Timer() as tr_dt: train_metrics = self._train_epoch(epoch) train_metrics['epoch_time'] = tr_dt.dt # get peak of memory usage if 'cpu_memory_MB' in train_metrics: metrics['peak_cpu_memory_MB'] = max( metrics.get('peak_cpu_memory_MB', 0), train_metrics['cpu_memory_MB']) if self._validation_dataset is not None: with utils.Timer() as val_dt: valid_metrics = self._validation() this_epoch_valid_metric = valid_metrics['loss/mean'] self._metric_tracker.add_metric(this_epoch_valid_metric) if self._metric_tracker.should_stop_early(): logger.info("Ran out of patience. Stopping training.") break valid_metrics['epoch_time'] = val_dt.dt training_elapsed_time = time.time() - training_start_time metrics["training_duration"] = time.strftime( "%H:%M:%S", time.gmtime(training_elapsed_time)) metrics["training_start_epoch"] = epoch_counter metrics["training_epochs"] = epochs_trained metrics["epoch"] = epoch for k, v in train_metrics.items(): logger.logkv("training/{}".format(k), v) for k, v in valid_metrics.items(): logger.logkv("validation/{}".format(k), v) if self._logdir: if (epochs_trained % self._ckpt_interval == 0) or (epochs_trained + 1) == self._num_epochs: self._save_checkpoint(epoch) if self._metric_tracker.is_best_so_far(): # Update all the best_ metrics. # (Otherwise they just stay the same as they were.) metrics['best_epoch'] = epoch for key, value in valid_metrics.items(): metrics["best_validation_" + key] = value self._metric_tracker.best_epoch_metrics = valid_metrics if self._learning_rate_scheduler: self._learning_rate_scheduler.step() if (epochs_trained == 0) or epochs_trained % self._summary_interval == 0: if self._should_log_parameter_statistics: self._parameter_and_gradient_statistics() train_metrics_ = self._metrics(self._train_dataset) for k, v in train_metrics_.items(): logger.logkv("training/{}".format(k), v) val_metrics_ = self._metrics(self._validation_dataset) for k, v in val_metrics_.items(): logger.logkv("validation/{}".format(k), v) if self._log_viz: try: fig_map = self._viz_func(self.model) for k, fig in fig_map.items(): logger.add_figure(k, fig) except Exception as e: logger.info("Couldn't log viz: {}".format(e)) epoch_elapsed_time = time.time() - epoch_start_time logger.info( "Epoch duration: %s", time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time))) if epoch < self._num_epochs - 1: training_elapsed_time = time.time() - training_start_time estimated_time_remaining = training_elapsed_time * \ ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1) formatted_time = str( datetime.timedelta(seconds=int(estimated_time_remaining))) logger.info("Estimated training time remaining: %s", formatted_time) logger.dumpkvs() epochs_trained += 1 return metrics