def valid_epoch(self, data_loader, model=None, gen_output=False, out_dir=None):
    r"""Evaluates the model once on all batches in the data loader given.

    Performs analysis of model output if requested.

    Parameters
    ----------
    data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
        An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
    model : morgana.base_models.BaseModel
        Model instance. If `self.ema_decay` is non-zero this will be the EMA model.
    gen_output : bool
        Whether to generate output for this validation epoch.
        Output is defined by :func:`morgana.base_models.BaseModel.analysis_for_valid_batch`.
    out_dir : str
        Directory used to save output (changes for each epoch).

    Returns
    -------
    loss : float
        Average loss per batch across the validation data.
    """
    if model is None:
        model = self.model

    model.mode = 'valid'
    model.metrics.reset_state('valid')

    loss = 0.0
    pbar = _logging.ProgressBar(len(data_loader))
    for i, (features, names) in zip(pbar, data_loader):
        self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

        batch_loss, output_features = model(features)

        loss += batch_loss.item()

        # Log metrics.
        pbar.print('valid', self.epoch,
                   batch_loss=tqdm.format_num(batch_loss),
                   **model.metrics.results_as_str_dict('valid'))

        if gen_output:
            model.analysis_for_valid_batch(features, output_features, names,
                                           out_dir=out_dir, sample_rate=self.sample_rate)

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        file_io.save_json(model.metrics.results_as_json_dict('valid'),
                          os.path.join(out_dir, 'metrics.json'))

    model.mode = ''

    return loss / (i + 1)
def test_epoch(self, data_loader, model=None, out_dir=None):
    r"""Evaluates the model once on all batches in the data loader given.

    Performs analysis of model predictions.

    Parameters
    ----------
    data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
        An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
    model : morgana.base_models.BaseModel
        Model instance. If `self.ema_decay` is non-zero this will be the EMA model.
    out_dir : str
        Directory used to save output (changes for each epoch).
    """
    if model is None:
        model = self.model

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model.mode = 'test'
    model.metrics.reset_state('test')

    pbar = _logging.ProgressBar(len(data_loader))
    for i, features in zip(pbar, data_loader):
        self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

        output_features = model.predict(features)

        model.analysis_for_test_batch(features, output_features,
                                      out_dir=out_dir, **self.analysis_kwargs)

        # Log metrics.
        pbar.print('test', self.epoch,
                   **model.metrics.results_as_str_dict('test'))

    model.analysis_for_test_epoch(out_dir=out_dir, **self.analysis_kwargs)

    if out_dir:
        file_io.save_json(model.metrics.results_as_json_dict('test'),
                          os.path.join(out_dir, 'metrics.json'))

    model.mode = ''
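# Illustrative sketch (not part of morgana): the `metrics.json` files written by `valid_epoch` and
# `test_epoch` above can be read back for later comparison across epochs, assuming `file_io.save_json`
# writes standard JSON. The helper name and the example path below are hypothetical.
def _example_load_metrics(out_dir):
    import json

    with open(os.path.join(out_dir, 'metrics.json')) as f:
        return json.load(f)

# e.g. valid_metrics = _example_load_metrics('experiments/demo/valid/epoch_010')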
def train_epoch(self, data_loader, optimizer, lr_schedule=None, gen_output=False, out_dir=None):
    r"""Trains the model once on all batches in the data loader given.

    * Gradient updates, and EMA gradient updates.
    * Batch-level learning rate schedule updates.
    * Logging metrics to tqdm and to a `metrics.json` file.

    Parameters
    ----------
    data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
        An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
    optimizer : torch.optim.Optimizer
    lr_schedule : torch.optim.lr_scheduler._LRScheduler
        Learning rate schedule, only used if it is a member of `morgana.lr_schedules.BATCH_LR_SCHEDULES`.
    gen_output : bool
        Whether to generate output for this training epoch.
        Output is defined by :func:`morgana.base_models.BaseModel.analysis_for_train_batch`.
    out_dir : str
        Directory used to save output (changes for each epoch).

    Returns
    -------
    loss : float
        Average loss per batch across the training data.
    """
    self.model.mode = 'train'
    self.model.metrics.reset_state('train')

    loss = 0.0
    pbar = _logging.ProgressBar(len(data_loader))
    for i, (features, names) in zip(pbar, data_loader):
        self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward + backward + optimise.
        batch_loss, output_features = self.model(features)
        batch_loss.backward()
        optimizer.step()

        # Update the learning rate.
        if lr_schedule is not None and self.lr_schedule_name in lr_schedules.BATCH_LR_SCHEDULES:
            lr_schedule.step()

        loss += batch_loss.item()

        # Update the exponential moving average model if it exists.
        if self.ema_decay:
            self.ema.update_params(self.model)

        # Log metrics.
        pbar.print('train', self.epoch,
                   batch_loss=tqdm.format_num(batch_loss),
                   **self.model.metrics.results_as_str_dict('train'))

        if gen_output:
            self.model.analysis_for_train_batch(features, output_features, names,
                                                out_dir=out_dir, sample_rate=self.sample_rate)

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        file_io.save_json(self.model.metrics.results_as_json_dict('train'),
                          os.path.join(out_dir, 'metrics.json'))

    self.model.mode = ''

    return loss / (i + 1)
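# Usage sketch (illustrative only, not part of morgana): a minimal outer loop driving the epoch
# methods above. `builder`, `train_loader`, `valid_loader`, `optimizer`, `lr_schedule`,
# `num_epochs`, and `experiment_dir` are hypothetical placeholders; only the `train_epoch` and
# `valid_epoch` signatures are taken from the code above.
def _example_training_loop(builder, train_loader, valid_loader, optimizer, lr_schedule,
                           num_epochs, experiment_dir):
    for epoch in range(1, num_epochs + 1):
        builder.epoch = epoch

        # Train for one epoch, saving per-epoch metrics under a dedicated directory.
        train_out_dir = os.path.join(experiment_dir, 'train', 'epoch_{}'.format(epoch))
        train_loss = builder.train_epoch(train_loader, optimizer, lr_schedule=lr_schedule,
                                         gen_output=False, out_dir=train_out_dir)

        # Evaluate on the validation set, generating analysis output for this epoch.
        valid_out_dir = os.path.join(experiment_dir, 'valid', 'epoch_{}'.format(epoch))
        valid_loss = builder.valid_epoch(valid_loader, gen_output=True, out_dir=valid_out_dir)

        print('epoch {}: train loss {:.4f}, valid loss {:.4f}'.format(epoch, train_loss, valid_loss))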
def train_epoch(self, data_generator, optimizer, lr_schedule=None, gen_output=False, out_dir=None):
    r"""Trains the model once on all batches in the data generator given.

    In addition to the gradient updates, EMA updates, and logging of the base training loop, this
    anneals the KL divergence weight: the weight is left unchanged for `kld_wait_epochs` epochs, then
    either set directly to `max_kld_weight` (if `kld_warmup_epochs` is 0) or increased linearly each
    batch until it reaches `max_kld_weight` after `kld_warmup_epochs` further epochs.
    """
    self.model.mode = 'train'
    self.model.metrics.reset_state('train')

    loss = 0.0
    pbar = _logging.ProgressBar(len(data_generator))
    for i, features in zip(pbar, data_generator):
        self.model.step = (self.epoch - 1) * len(data_generator) + i + 1

        # Anneal the KL divergence, linearly increasing from 0.0 to the initial KLD weight set in the model.
        if self.kld_wait_epochs != 0 and self.epoch == self.kld_wait_epochs + 1 and self.kld_warmup_epochs == 0:
            self.model.kld_weight = self.model.max_kld_weight
        if self.kld_warmup_epochs != 0 and self.epoch > self.kld_wait_epochs:
            if self.model.kld_weight < self.model.max_kld_weight:
                self.model.kld_weight += self.model.max_kld_weight / (
                        self.kld_warmup_epochs * len(data_generator))
                self.model.kld_weight = min(self.model.max_kld_weight, self.model.kld_weight)

        self.model.tensorboard.add_scalar('kl_weight', self.model.kld_weight, global_step=self.model.step)

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward + backward + optimise.
        batch_loss, output_features = self.model(features)
        batch_loss.backward()
        optimizer.step()

        # Update the learning rate.
        if lr_schedule is not None and self.lr_schedule_name in lr_schedules.BATCH_LR_SCHEDULES:
            lr_schedule.step()

        loss += batch_loss.item()

        # Update the exponential moving average model if it exists.
        if self.ema_decay:
            self.ema.update_params(self.model)

        # Log metrics.
        pbar.print('train', self.epoch,
                   kld_weight=tqdm.format_num(self.model.kld_weight),
                   batch_loss=tqdm.format_num(batch_loss),
                   **self.model.metrics.results_as_str_dict('train'))

        if gen_output:
            self.model.analysis_for_train_batch(features, output_features,
                                                out_dir=out_dir, **self.analysis_kwargs)

    if gen_output:
        self.model.analysis_for_train_epoch(out_dir=out_dir, **self.analysis_kwargs)

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        file_io.save_json(self.model.metrics.results_as_json_dict('train'),
                          os.path.join(out_dir, 'metrics.json'))

    self.model.mode = ''

    return loss / (i + 1)
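# Illustrative helper (not part of morgana): a closed-form version of the KLD weight schedule that the
# annealing branch above implements incrementally. The weight stays at 0.0 for `kld_wait_epochs`
# epochs, then increases linearly each batch until it reaches `max_kld_weight` after a further
# `kld_warmup_epochs` epochs (or jumps straight to `max_kld_weight` when `kld_warmup_epochs` is 0).
# The function name and argument names are hypothetical; only the arithmetic mirrors the loop above.
def _example_kld_weight(step, batches_per_epoch, kld_wait_epochs, kld_warmup_epochs, max_kld_weight):
    wait_steps = kld_wait_epochs * batches_per_epoch

    if step <= wait_steps:
        return 0.0
    if kld_warmup_epochs == 0:
        return max_kld_weight

    # Per-batch increment used during warm-up, clipped once the maximum weight is reached.
    increment = max_kld_weight / (kld_warmup_epochs * batches_per_epoch)
    return min(max_kld_weight, (step - wait_steps) * increment)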