def train(self):
    """Run one pass over the 'training' data loader and update the model.

    For every batch the optimizer gradients are cleared, the model is
    applied, the loss returned by ``_update_losses`` is backpropagated and
    the optimizer takes a step. After the loop the training summary is
    updated, the state is saved when a saving period is configured, and
    the scheduler is advanced.

    Raises:
        AssertionError: if no 'training' entry exists in ``self.data_loaders``.
    """
    assert "training" in self.data_loaders

    process = self.training_process

    # Move the training losses to the configured device and start every
    # accumulator at zero.
    self.losses = process.losses['training'].to(process.config.device)
    self.cumulative_losses = dict.fromkeys(self.losses.names, 0)
    if process.interpolator is not None:
        # With an interpolator present, each loss gets an additional
        # 'itp_'-prefixed accumulator.
        self.cumulative_losses.update(
            dict.fromkeys(
                ('itp_' + name for name in self.losses.names), 0))

    # Put the network into training mode.
    process.model.train()

    progress = ProgressBar(self.num_loadings['training'],
                           displaySumCount=True)

    for step, batch in enumerate(self.data_loaders['training'], start=1):
        # _prepare_data yields six items; the low-resolution mask is not
        # needed in this method.
        (targets_device, inputs_device, _mask_lr_device, mask_hr_device,
         offset_lr, offset_hr) = self._prepare_data(batch)

        # Clear the gradients left over from the previous step.
        process.optimizer.zero_grad()

        predictions_device, interpolate = self._apply_model(
            inputs_device, offset_lr, offset_hr)

        # _update_losses returns the loss that is backpropagated below.
        batch_loss = self._update_losses(predictions_device, targets_device,
                                         interpolate, mask_hr_device,
                                         offset_hr)
        batch_loss.backward()
        process.optimizer.step()

        progress.proceed(step)

    # Bookkeeping once all batches are done.
    self._update_summary('training')

    # NOTE(review): only the presence of a saving period is checked here;
    # presumably _save_state decides whether this epoch is due -- confirm.
    if process.config.saving_period is not None:
        self._save_state()

    self._update_scheduler()
def validate(self):
    """Run one pass over the 'validation' data loader without gradients.

    Uses the 'validation' losses when the training process defines them
    and falls back to the 'training' losses otherwise. The model is put
    into eval mode, every batch is evaluated under ``torch.no_grad()``,
    and the validation summary is updated afterwards.

    Raises:
        AssertionError: if no 'validation' entry exists in
            ``self.data_loaders``.
    """
    assert "validation" in self.data_loaders

    process = self.training_process

    # Prefer dedicated validation losses; otherwise reuse the training ones.
    loss_key = 'validation' if "validation" in process.losses else 'training'
    self.losses = process.losses[loss_key].to(process.config.device)

    # Start every accumulator at zero.
    self.cumulative_losses = dict.fromkeys(self.losses.names, 0)
    if process.interpolator is not None:
        # With an interpolator present, each loss gets an additional
        # 'itp_'-prefixed accumulator.
        self.cumulative_losses.update(
            dict.fromkeys(
                ('itp_' + name for name in self.losses.names), 0))

    # Put the network into evaluation mode.
    process.model.eval()

    progress = ProgressBar(self.num_loadings['validation'],
                           displaySumCount=True)

    # No gradients are recorded while iterating over the validation data.
    with torch.no_grad():
        for step, batch in enumerate(self.data_loaders['validation'],
                                     start=1):
            # _prepare_data yields six items; the low-resolution mask is
            # not needed in this method.
            (targets_device, inputs_device, _mask_lr_device,
             mask_hr_device, offset_lr, offset_hr) = self._prepare_data(batch)

            predictions_device, interpolate = self._apply_model(
                inputs_device, offset_lr, offset_hr)

            # Called for its accumulation side effect; the returned loss
            # is not used here.
            self._update_losses(predictions_device, targets_device,
                                interpolate, mask_hr_device, offset_hr)

            progress.proceed(step)

    # Bookkeeping once all batches are done.
    self._update_summary('validation')