Example #1
    def _validation(self) -> Dict[str, float]:
        logger.info("Validating")
        val_data = self._validation_dataset
        if isinstance(val_data, dataset.ActuatedTrajectoryDataset):
            val_data = transform.odepred_transform(val_data, self._pred_horizon)

        val_generator = torch.utils.data.DataLoader(val_data, shuffle=False,
                                                    batch_size=self._batch_size)
        loss_ls = []
        loss_info_ls = []
        for qs, vs, us in tqdm.tqdm(val_generator, total=len(val_generator)):
            with torch.no_grad():
                loss, loss_info = compute_qvloss(
                    ActuatedODEWrapper(self.model),
                    torch.stack(qs).to(self._device),
                    torch.stack(vs).to(self._device),
                    torch.stack(us).to(self._device), dt=self._dt, vlambda=self._vlambda,
                    method=self._integration_method)

            loss_ls.append(loss.cpu().detach().item())
            loss_info_ls.append(loss_info)

        metrics = {}
        loss_info = nested.zip(*loss_info_ls)

        metrics['loss/mean'] = np.mean(loss_ls)
        metrics['loss/std'] = np.std(loss_ls)
        metrics['log10loss/mean'] = np.mean(np.log10(loss_ls))
        metrics['log10loss/std'] = np.std(np.log10(loss_ls))
        for k, val in loss_info.items():
            metrics['loss/{}/mean'.format(k)] = np.mean(val)
            metrics['loss/{}/std'.format(k)] = np.std(val)

        return metrics
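
The per-batch losses collected above are reduced to summary statistics in the final metrics dict. A minimal, self-contained sketch of that aggregation pattern, using placeholder loss values rather than anything from the example, looks like this:

    import numpy as np

    # Hypothetical per-batch validation losses (placeholder values).
    loss_ls = [0.12, 0.08, 0.15, 0.11]

    metrics = {
        'loss/mean': np.mean(loss_ls),
        'loss/std': np.std(loss_ls),
        'log10loss/mean': np.mean(np.log10(loss_ls)),
        'log10loss/std': np.std(np.log10(loss_ls)),
    }
    print(metrics)
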
Example #2
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        logger.info("Epoch {}/{}".format(epoch, self._num_epochs - 1))
        peak_mem_usage = utils.peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_mem_usage}")

        loss_ls = []
        loss_info_ls = []
        losstimer_ls = []
        gradtimer_ls = []

        train_data = self._train_dataset
        if isinstance(train_data, dataset.ActuatedTrajectoryDataset):
            train_data = transform.odepred_transform(train_data, self._pred_horizon)

        train_generator = torch.utils.data.DataLoader(train_data, shuffle=self._shuffle,
                                                      batch_size=self._batch_size)
        for qs, vs, us in tqdm.tqdm(train_generator, total=len(train_generator)):
            self.optimizer.zero_grad()
            with utils.Timer() as losstime:
                loss, loss_info = compute_qvloss(
                    ActuatedODEWrapper(self.model),
                    torch.stack(qs).to(self._device),
                    torch.stack(vs).to(self._device),
                    torch.stack(us).to(self._device), dt=self._dt, vlambda=self._vlambda,
                    method=self._integration_method)

            with utils.Timer() as gradtime:
                loss.backward()

            self.optimizer.step()

            loss_ls.append(loss.cpu().detach().numpy())
            loss_info_ls.append(loss_info)
            losstimer_ls.append(losstime.dt)
            gradtimer_ls.append(gradtime.dt)

        metrics = {}
        loss_info = nested.zip(*loss_info_ls)
        metrics['cpu_memory_MB'] = peak_mem_usage
        metrics['loss/mean'] = np.mean(loss_ls)
        metrics['loss/std'] = np.std(loss_ls)
        metrics['log10loss/mean'] = np.mean(np.log10(loss_ls))
        metrics['log10loss/std'] = np.std(np.log10(loss_ls))
        for k, val in loss_info.items():
            metrics['loss/{}/mean'.format(k)] = np.mean(val)
            metrics['loss/{}/std'.format(k)] = np.std(val)

        metrics['time/loss/mean'] = np.mean(losstimer_ls)
        metrics['time/loss/std'] = np.std(losstimer_ls)
        metrics['time/loss/max'] = np.max(losstimer_ls)
        metrics['time/loss/min'] = np.min(losstimer_ls)
        metrics['time/grad/mean'] = np.mean(gradtimer_ls)
        metrics['time/grad/std'] = np.std(gradtimer_ls)
        metrics['time/grad/max'] = np.max(gradtimer_ls)
        metrics['time/grad/min'] = np.min(gradtimer_ls)

        if self._learning_rate_scheduler:
            metrics['lr'] = self._learning_rate_scheduler.get_lr()[0]

        return metrics
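
The loss and gradient steps above are timed with `utils.Timer`, used as a context manager that exposes the elapsed time as `.dt`. The following is a minimal sketch of such a timer; it is an assumption about that interface, not the project's actual implementation:

    import time

    class Timer:
        # Context manager recording elapsed wall-clock time in `.dt`.
        def __enter__(self):
            self._start = time.perf_counter()
            return self

        def __exit__(self, *exc):
            self.dt = time.perf_counter() - self._start
            return False

    with Timer() as t:
        sum(range(1_000_000))
    print(f"elapsed: {t.dt:.4f} s")
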
Example #3
    def _validation(self) -> Dict[str, float]:
        logger.info("Validating")
        q_b, v_b, qddot_b, F_b = map(lambda b: b.to(self._device), self._valid_batch)
        with torch.no_grad():
            loss, loss_vec = compute_DeLaNloss(self.model, q_b, v_b, qddot_b, F_b)

        metrics = {}
        metrics['loss_F'] = loss.cpu().detach().item()
        for i in range(self.model._qdim):
            metrics[f'loss_F_{i}'] = loss_vec[i].cpu().detach().item()

        return metrics
Example #4
    def _train_batch(self, batch_i: int) -> Dict[str, float]:
        logger.info("Training batch {}/{}".format(batch_i, self._num_train_batches - 1))

        self._last_train_batch = next(self._train_datagen)
        self.optimizer.zero_grad()
        q_b, v_b, qddot_b, F_b = map(lambda b: b.to(self._device), self._last_train_batch)
        loss, loss_vec = compute_DeLaNloss(self.model, q_b, v_b, qddot_b, F_b)
        loss.backward()
        self.optimizer.step()

        metrics = {}
        metrics['loss_F_tot'] = loss.cpu().detach().item()
        for i in range(self.model._qdim):
            metrics[f'loss_F_{i}'] = loss_vec[i].cpu().detach().item()

        return metrics
Example #5
    def _parameter_and_gradient_statistics(self) -> None:
        for name, param in self.model.named_parameters():
            logger.logkv("parameter_mean/" + name, param.data.mean().item())
            logger.logkv("parameter_std/" + name, param.data.std().item())
            logger.logkv("parameter_norm/" + name, param.data.norm().item())

            if param.grad is not None:
                grad_data = param.grad.data

                # skip empty gradients
                if torch.prod(torch.tensor(grad_data.shape)).item() > 0:
                    logger.logkv("gradient_mean/" + name, grad_data.mean().item())
                    logger.logkv("gradient_std/" + name, grad_data.std().item())
                    logger.logkv("gradient_norm/" + name, grad_data.norm().item())
                else:
                    logger.info("No gradient for {}, skipping.".format(name))
Example #6
    def _save_checkpoint(self, epoch: int) -> None:
        model_path = Path(self._logdir) / "model_state_epoch_{}.th".format(epoch)
        model_state = self.model.state_dict()
        torch.save(model_state, model_path)

        training_state = {
            'epoch': epoch,
            'metric_tracker': self._metric_tracker.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }

        training_path = Path(self._logdir) / "training_state_epoch_{}.th".format(epoch)
        torch.save(training_state, training_path)

        if self._metric_tracker.is_best_so_far():
            logger.info("Best validation performance so far. Copying weights to {}/best.th".format(
                self._logdir))
            shutil.copyfile(model_path, Path(self._logdir) / "best.th")
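
Checkpoints written in this layout can later be restored with `torch.load` and `load_state_dict`. Below is a self-contained sketch of both directions, using a throwaway model, optimizer, and a hypothetical log directory rather than the trainer's own state:

    import torch
    import torch.nn as nn
    from pathlib import Path

    logdir = Path("/tmp/ckpt_demo")  # hypothetical log directory
    logdir.mkdir(parents=True, exist_ok=True)

    model = nn.Linear(3, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    epoch = 3

    # Save in the same layout as _save_checkpoint above.
    torch.save(model.state_dict(), logdir / "model_state_epoch_{}.th".format(epoch))
    torch.save({'epoch': epoch, 'optimizer': optimizer.state_dict()},
               logdir / "training_state_epoch_{}.th".format(epoch))

    # Restore.
    model.load_state_dict(torch.load(logdir / "model_state_epoch_{}.th".format(epoch)))
    training_state = torch.load(logdir / "training_state_epoch_{}.th".format(epoch))
    optimizer.load_state_dict(training_state['optimizer'])
    start_epoch = training_state['epoch'] + 1
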
Example #7
    def train(self):
        logger.info("Begin training...")
        
        metrics = {}
        
        for batch_i in tqdm.tqdm(range(self._num_train_batches)):
            train_metrics = self._train_batch(batch_i)
            valid_metrics = self._validation()
            
            for k, v in train_metrics.items():
                logger.logkv("training/{}".format(k), v)

            for k, v in valid_metrics.items():
                logger.logkv("validation/{}".format(k), v)
                
            # checkpoint
            if self._logdir:
                if (batch_i % self._ckpt_interval == 0) or (
                        batch_i + 1) == self._num_train_batches:
                    self._save_checkpoint(batch_i)
                    
            # write statistics
            if (batch_i % self._summary_interval == 0) or (
                    batch_i + 1) == self._num_train_batches:
                if self._should_log_parameter_statistics:
                    self._parameter_and_gradient_statistics()

                train_metrics = self._metrics(self._last_train_batch)
                for k, v in train_metrics.items():
                    logger.logkv("training/{}".format(k), v)

                val_metrics = self._metrics(self._valid_batch, log_sample_MM_spectrum=3)
                for k, v in val_metrics.items():
                    logger.logkv("validation/{}".format(k), v)
                
            logger.dumpkvs()
            
        return metrics
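
The checkpoint and summary writes in this loop are gated on an interval condition that also fires on the final batch. Isolated with made-up interval values, the gate behaves like this:

    num_train_batches = 10
    ckpt_interval = 4

    for batch_i in range(num_train_batches):
        # fires at batches 0, 4, 8 and always at the last batch (9)
        if (batch_i % ckpt_interval == 0) or (batch_i + 1) == num_train_batches:
            print("checkpoint at batch", batch_i)
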
Example #8
    def train(self):
        logger.info("Begin training...")
        try:
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise Exception(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing log "
                "directory?")

        train_metrics = {}
        valid_metrics = {}
        metrics = {}
        this_epoch_valid_metric: float = None
        epochs_trained = 0
        training_start_time = time.time()

        metrics['best_epoch'] = self._metric_tracker.best_epoch

        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation/" + key] = value

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            with utils.Timer() as tr_dt:
                train_metrics = self._train_epoch(epoch)

            train_metrics['epoch_time'] = tr_dt.dt

            # get peak of memory usage
            if 'cpu_memory_MB' in train_metrics:
                metrics['peak_cpu_memory_MB'] = max(
                    metrics.get('peak_cpu_memory_MB', 0),
                    train_metrics['cpu_memory_MB'])

            if self._validation_dataset is not None:
                with utils.Timer() as val_dt:
                    valid_metrics = self._validation()
                    this_epoch_valid_metric = valid_metrics['loss/mean']

                    self._metric_tracker.add_metric(this_epoch_valid_metric)

                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

                valid_metrics['epoch_time'] = val_dt.dt

            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = time.strftime(
                "%H:%M:%S", time.gmtime(training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for k, v in train_metrics.items():
                logger.logkv("training/{}".format(k), v)

            for k, v in valid_metrics.items():
                logger.logkv("validation/{}".format(k), v)

            if self._logdir:
                if (epochs_trained % self._ckpt_interval
                        == 0) or (epochs_trained + 1) == self._num_epochs:
                    self._save_checkpoint(epoch)

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in valid_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = valid_metrics

            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step()

            if (epochs_trained
                    == 0) or epochs_trained % self._summary_interval == 0:
                if self._should_log_parameter_statistics:
                    self._parameter_and_gradient_statistics()

                train_metrics_ = self._metrics(self._train_dataset)
                for k, v in train_metrics_.items():
                    logger.logkv("training/{}".format(k), v)

                val_metrics_ = self._metrics(self._validation_dataset)
                for k, v in val_metrics_.items():
                    logger.logkv("validation/{}".format(k), v)

                if self._log_viz:
                    try:
                        fig_map = self._viz_func(self.model)
                        for k, fig in fig_map.items():
                            logger.add_figure(k, fig)
                    except Exception as e:
                        logger.info("Couldn't log viz: {}".format(e))

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info(
                "Epoch duration: %s",
                time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(
                    datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s",
                            formatted_time)

            logger.dumpkvs()
            epochs_trained += 1

        return metrics
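
The early-stopping logic above hinges on the `_metric_tracker` object and its `add_metric`, `should_stop_early`, and `is_best_so_far` methods. The following is a minimal sketch of a patience-based tracker consistent with that usage, assuming a lower validation loss is better; the project's actual class may differ:

    class MetricTracker:
        def __init__(self, patience: int = 5):
            self._patience = patience
            self._best = float('inf')
            self._bad_epochs = 0
            self._is_best = False
            self.best_epoch = None
            self.best_epoch_metrics = {}

        def add_metric(self, value: float) -> None:
            # Record whether the latest metric improves on the best seen so far.
            if value < self._best:
                self._best = value
                self._bad_epochs = 0
                self._is_best = True
            else:
                self._bad_epochs += 1
                self._is_best = False

        def is_best_so_far(self) -> bool:
            return self._is_best

        def should_stop_early(self) -> bool:
            return self._bad_epochs >= self._patience
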