    def valid_epoch(self,
                    data_loader,
                    model=None,
                    gen_output=False,
                    out_dir=None):
        r"""Evaluates model once on all batches in data loader given. Performs analysis of model output if requested.

        Parameters
        ----------
        data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
            An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
        model : morgana.base_models.BaseModel
            Model instance. If `self.ema_decay` is non-zero this will be the ema model.
        gen_output : bool
            Whether to generate output for this validation epoch. Output is defined by
            :func:`morgana.base_models.BaseModel.analysis_for_valid_batch`.
        out_dir : str
            Directory used to save output (changes for each epoch).

        Returns
        -------
        loss : float
            Average loss per batch over the epoch.
        """
        if model is None:
            model = self.model

        model.mode = 'valid'
        model.metrics.reset_state('valid')

        loss = 0.0
        pbar = _logging.ProgressBar(len(data_loader))
        for i, (features, names) in zip(pbar, data_loader):
            self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

            batch_loss, output_features = model(features)

            loss += batch_loss.item()

            # Log metrics.
            pbar.print('valid',
                       self.epoch,
                       batch_loss=tqdm.format_num(batch_loss),
                       **model.metrics.results_as_str_dict('valid'))

            if gen_output:
                model.analysis_for_valid_batch(features,
                                               output_features,
                                               names,
                                               out_dir=out_dir,
                                               sample_rate=self.sample_rate)

        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            file_io.save_json(model.metrics.results_as_json_dict('valid'),
                              os.path.join(out_dir, 'metrics.json'))

        model.mode = ''

        return loss / (i + 1)
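
    # Hypothetical usage sketch (not part of the original class): a thin
    # wrapper showing how `valid_epoch` might be called with the EMA-averaged
    # model and with gradient tracking disabled. It assumes `torch` is
    # imported and that the EMA helper exposes its averaged model as
    # `self.ema.model`; both of these names are assumptions.
    def valid_epoch_no_grad(self, data_loader, gen_output=False, out_dir=None):
        model = self.ema.model if self.ema_decay else self.model
        with torch.no_grad():
            return self.valid_epoch(data_loader,
                                    model=model,
                                    gen_output=gen_output,
                                    out_dir=out_dir)
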
Example #2
    def test_epoch(self, data_loader, model=None, out_dir=None):
        r"""Evaluates the model once on all batches in the data loader given. Performs analysis of model predictions.

        Parameters
        ----------
        data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
            An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
        model : morgana.base_models.BaseModel
            Model instance. If `self.ema_decay` is non-zero this will be the ema model.
        out_dir : str
            Directory used to save output (changes for each epoch).
        """
        if model is None:
            model = self.model

        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        model.mode = 'test'
        model.metrics.reset_state('test')

        pbar = _logging.ProgressBar(len(data_loader))
        for i, features in zip(pbar, data_loader):
            self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

            output_features = model.predict(features)

            model.analysis_for_test_batch(features,
                                          output_features,
                                          out_dir=out_dir,
                                          **self.analysis_kwargs)

            # Log metrics.
            pbar.print('test', self.epoch,
                       **model.metrics.results_as_str_dict('test'))

        model.analysis_for_test_epoch(out_dir=out_dir, **self.analysis_kwargs)

        if out_dir:
            file_io.save_json(model.metrics.results_as_json_dict('test'),
                              os.path.join(out_dir, 'metrics.json'))

        model.mode = ''
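
    # Hypothetical usage sketch (not part of the original class): building the
    # per-epoch output directory mentioned in the docstring and running a test
    # pass with the EMA model when one is kept. `output_base` is a
    # hypothetical attribute and `torch` is assumed to be imported.
    def run_test(self, test_loader):
        out_dir = os.path.join(self.output_base, 'test', 'epoch_{}'.format(self.epoch))
        model = self.ema.model if self.ema_decay else self.model
        with torch.no_grad():
            self.test_epoch(test_loader, model=model, out_dir=out_dir)
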
    def train_epoch(self,
                    data_loader,
                    optimizer,
                    lr_schedule=None,
                    gen_output=False,
                    out_dir=None):
        r"""Trains the model once on all batches in the data loader given.

        * Gradient updates, and EMA parameter updates.
        * Batch level learning rate schedule updates.
        * Logging metrics to tqdm and to a `metrics.json` file.

        Parameters
        ----------
        data_loader : :class:`torch.utils.data.DataLoader` (in a :class:`morgana.data._DataLoaderWrapper` container)
            An instance with the `__iter__` method, allowing for iteration over batches of the dataset.
        optimizer : torch.optim.Optimizer
        lr_schedule : torch.optim.lr_scheduler._LRScheduler
            Learning rate schedule, only used if it is a member of `morgana.lr_schedules.BATCH_LR_SCHEDULES`.
        gen_output : bool
            Whether to generate output for this training epoch. Output is defined by
            :func:`morgana.base_models.BaseModel.analysis_for_train_batch`.
        out_dir : str
            Directory used to save output (changes for each epoch).

        Returns
        -------
        loss : float
            Average loss per batch over the epoch.
        """
        self.model.mode = 'train'
        self.model.metrics.reset_state('train')

        loss = 0.0
        pbar = _logging.ProgressBar(len(data_loader))
        for i, (features, names) in zip(pbar, data_loader):
            self.model.step = (self.epoch - 1) * len(data_loader) + i + 1

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            batch_loss, output_features = self.model(features)

            batch_loss.backward()
            optimizer.step()

            # Update the learning rate.
            if lr_schedule is not None and self.lr_schedule_name in lr_schedules.BATCH_LR_SCHEDULES:
                lr_schedule.step()

            loss += batch_loss.item()

            # Update the exponential moving average model if it exists.
            if self.ema_decay:
                self.ema.update_params(self.model)

            # Log metrics.
            pbar.print('train',
                       self.epoch,
                       batch_loss=tqdm.format_num(batch_loss),
                       **self.model.metrics.results_as_str_dict('train'))

            if gen_output:
                self.model.analysis_for_train_batch(
                    features,
                    output_features,
                    names,
                    out_dir=out_dir,
                    sample_rate=self.sample_rate)

        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            file_io.save_json(self.model.metrics.results_as_json_dict('train'),
                              os.path.join(out_dir, 'metrics.json'))

        self.model.mode = ''

        return loss / (i + 1)
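
    # Hypothetical driver sketch (not part of the original class): one way the
    # epoch loop around `train_epoch` and `valid_epoch` might look. Since
    # `train_epoch` only steps schedules listed in
    # `lr_schedules.BATCH_LR_SCHEDULES`, an epoch-level schedule is stepped
    # out here instead. `num_epochs`, `train_loader`, and `valid_loader` are
    # hypothetical names, and `torch` is assumed to be imported.
    def run_training(self, num_epochs, train_loader, valid_loader, optimizer, lr_schedule=None):
        for epoch in range(1, num_epochs + 1):
            self.epoch = epoch

            train_loss = self.train_epoch(train_loader, optimizer, lr_schedule=lr_schedule)

            with torch.no_grad():
                valid_loss = self.valid_epoch(valid_loader)

            # Batch-level schedules were already stepped inside `train_epoch`.
            if lr_schedule is not None and self.lr_schedule_name not in lr_schedules.BATCH_LR_SCHEDULES:
                lr_schedule.step()

            print('epoch {}: train loss {:.6f}, valid loss {:.6f}'.format(epoch, train_loss, valid_loss))
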
Example #4
    def train_epoch(self,
                    data_generator,
                    optimizer,
                    lr_schedule=None,
                    gen_output=False,
                    out_dir=None):
        r"""Trains the model once on all batches in the data loader given, annealing the KLD weight per batch.

        The KLD weight is left unchanged for the first `self.kld_wait_epochs` epochs, then increased linearly
        to `self.model.max_kld_weight` over `self.kld_warmup_epochs` epochs.
        """
        self.model.mode = 'train'
        self.model.metrics.reset_state('train')

        loss = 0.0
        pbar = _logging.ProgressBar(len(data_generator))
        for i, features in zip(pbar, data_generator):
            self.model.step = (self.epoch - 1) * len(data_generator) + i + 1

            # Anneal the KL divergence weight, linearly increasing it from 0.0 to `self.model.max_kld_weight`.
            if self.kld_wait_epochs != 0 and self.epoch == self.kld_wait_epochs + 1 and self.kld_warmup_epochs == 0:
                self.model.kld_weight = self.model.max_kld_weight
            if self.kld_warmup_epochs != 0 and self.epoch > self.kld_wait_epochs:
                if self.model.kld_weight < self.model.max_kld_weight:
                    self.model.kld_weight += self.model.max_kld_weight / (
                        self.kld_warmup_epochs * len(data_generator))
                    self.model.kld_weight = min(self.model.max_kld_weight,
                                                self.model.kld_weight)

            self.model.tensorboard.add_scalar('kl_weight',
                                              self.model.kld_weight,
                                              global_step=self.model.step)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            batch_loss, output_features = self.model(features)

            batch_loss.backward()
            optimizer.step()

            # Update the learning rate.
            if lr_schedule is not None and self.lr_schedule_name in lr_schedules.BATCH_LR_SCHEDULES:
                lr_schedule.step()

            loss += batch_loss.item()

            # Update the exponential moving average model if it exists.
            if self.ema_decay:
                self.ema.update_params(self.model)

            # Log metrics.
            pbar.print('train',
                       self.epoch,
                       kld_weight=tqdm.format_num(self.model.kld_weight),
                       batch_loss=tqdm.format_num(batch_loss),
                       **self.model.metrics.results_as_str_dict('train'))

            if gen_output:
                self.model.analysis_for_train_batch(features,
                                                    output_features,
                                                    out_dir=out_dir,
                                                    **self.analysis_kwargs)

        if gen_output:
            self.model.analysis_for_train_epoch(out_dir=out_dir,
                                                **self.analysis_kwargs)

        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            file_io.save_json(self.model.metrics.results_as_json_dict('train'),
                              os.path.join(out_dir, 'metrics.json'))

        self.model.mode = ''

        return loss / (i + 1)
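
# Standalone sketch (an assumption, mirroring the annealing logic above and
# assuming the model's KLD weight starts at 0.0 and that a wait or warm-up is
# configured): the KLD weight as a closed-form function of epoch and batch
# index. It stays at 0.0 for `kld_wait_epochs` epochs, then ramps linearly to
# `max_kld_weight` over `kld_warmup_epochs` epochs of `batches_per_epoch`
# batches each.
def kld_weight_at(epoch, batch_idx, batches_per_epoch,
                  kld_wait_epochs, kld_warmup_epochs, max_kld_weight):
    if epoch <= kld_wait_epochs:
        return 0.0
    if kld_warmup_epochs == 0:
        return max_kld_weight

    # Batches completed since the warm-up began (`batch_idx` is 0-based).
    batches_into_warmup = (epoch - kld_wait_epochs - 1) * batches_per_epoch + batch_idx + 1
    ramp = batches_into_warmup / (kld_warmup_epochs * batches_per_epoch)
    return min(max_kld_weight, max_kld_weight * ramp)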