Example #1
    def train_log_str(cls,
                      summaries: dict,
                      step: int,
                      epoch: Optional[int] = None) -> str:
        """Returns the log string for training metrics.

        The epoch argument is currently unused in the message.
        """
        if get_option('show_progress_bar'):
            s = "       "  # blank prefix, same width as "train: "
        else:
            s = "train: "
        s += "[step {}]".format(step)
        for k in summaries:
            s += "  {key}={value:.5g}".format(key=k, value=summaries[k])
        return s
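
A quick sanity check of the format produced above (hypothetical metric values; with the show_progress_bar option off, the "train: " prefix is used):

summaries = {'loss': 0.1234567, 'elbo': -88.25}
s = "train: " + "[step {}]".format(100)
for k in summaries:
    s += "  {key}={value:.5g}".format(key=k, value=summaries[k])
print(s)  # train: [step 100]  loss=0.12346  elbo=-88.25
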
Example #2
def print_num_params(model: nn.Module, max_depth: Optional[int] = __sentinel):
    """Prints an overview of the model architecture with parameter counts.

    Optionally, parameters below a certain depth in the module tree are
    grouped together. The depth defaults to the package-wide
    'model_print_depth' option.

    Args:
        model (torch.nn.Module)
        max_depth (int, optional): group parameters whose dot-separated
            names share the first max_depth components; None disables
            grouping.
    """

    if max_depth is __sentinel:
        max_depth = get_option('model_print_depth')
    assert max_depth is None or isinstance(max_depth, int)

    sep = '.'  # separator in parameter names
    print("\n--- Trainable parameters:")
    num_params_tot = 0
    num_params_dict = OrderedDict()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        num_params = param.numel()

        if max_depth is not None:
            split = name.split(sep)
            prefix = sep.join(split[:max_depth])
        else:
            prefix = name
        if prefix not in num_params_dict:
            num_params_dict[prefix] = 0
        num_params_dict[prefix] += num_params
        num_params_tot += num_params
    for n, n_par in num_params_dict.items():
        print("{:8d}  {}".format(n_par, n))
    print("  - Total trainable parameters:", num_params_tot)
    print("---------\n")
Example #3
    def run(self):
        """Runs the trainer."""

        # True until the first iteration completes; needed when resuming
        # training from a checkpoint
        first_step = True

        # Setup
        e = self.experiment
        train_loader = e.dataloaders.train
        progress = None
        show_progress = get_option('show_progress_bar')
        train_summarizers = SummarizerCollection(
            mode='moving_average',
            ma_length=get_option('train_summarizer_ma_length'))
        # Additional summarizers are independent of the train/test regime:
        # they are not printed, they are written to tensorboard only once
        # (during training, not testing), and for now they are not saved to
        # the history.
        additional_summarizers = SummarizerCollection(
            mode='moving_average',
            ma_length=get_option('train_summarizer_ma_length'))

        # Training mode
        e.model.train()

        # Main loop
        for epoch in range(1, e.args.max_epochs + 1):
            for batch_idx, (x, y) in enumerate(train_loader):

                step = e.model.global_step

                if step >= e.args.max_steps:
                    break

                if step % e.args.test_log_every == 0:

                    # Test model (runs at step 0 on a fresh start; skipped
                    # only right after resuming from a checkpoint)
                    if not first_step or step == 0:
                        with torch.no_grad():
                            self._test(epoch)

                    # Save a model checkpoint (unless we just started or
                    # resumed training)
                    if not first_step and step % e.args.checkpoint_every == 0:
                        print("* saving model checkpoint at "
                              "step {}".format(step))
                        e.model.checkpoint(self.checkpoint_folder,
                                           e.args.keep_checkpoint_max)

                    # (Re)start progress bar
                    if show_progress:
                        progress = tqdm(total=e.args.test_log_every,
                                        desc='train')

                    # Restart timer to measure training speed
                    timer_start = timeit.default_timer()
                    steps_start = e.model.global_step

                    # The speed measurement only makes sense if the test
                    # interval is a multiple of the train logging interval.
                    # That currently holds, but assert it in case it changes.
                    assert e.args.test_log_every % e.args.train_log_every == 0

                # Reset gradients
                e.optimizer.zero_grad()

                # Forward pass: get loss and other info
                outputs = e.forward_pass(x, y)

                # Compute gradients (backward pass)
                outputs['loss'].backward()

                e.post_backward_callback()

                if e.args.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        e.model.parameters(), max_norm=e.args.max_grad_norm)

                # Add batch metrics to summarizers
                metrics_dict = e.get_metrics_dict(outputs)
                train_summarizers.add(metrics_dict)

                # Compute gradient stats and add to summarizers
                # - grad norm of each parameter
                # - grad norm of given group (or default groups)
                # - total grad norm
                # TODO groups
                grad_stats = {
                    'grad_norm_total/grad_norm_total':
                        grad_global_norm(e.model.parameters())
                }
                for n, p in e.model.named_parameters():
                    k = 'grad_norm_per_parameter/' + n
                    grad_stats[k] = grad_global_norm(p)
                additional_summarizers.add(grad_stats)

                # Compute L2 norm of parameters and add it to summarizers
                additional_summarizers.add(
                    {'L2_norm': global_norm(e.model.parameters())})

                # Update progress bar
                if show_progress:
                    progress.update()

                # Close progress bar if test occurs at next loop iteration
                if (step + 1) % e.args.test_log_every == 0:
                    # If show_progress is False, progress is None
                    if progress is not None:
                        progress.close()

                if (step + 1) % e.args.train_log_every == 0:
                    # step+1 because we already did a forward/backward step

                    # Get moving average of metrics and reset summarizers
                    train_summaries = train_summarizers.get_all(reset=True)
                    additional_summaries = additional_summarizers.get_all(
                        reset=True)

                    # Get training speed and add it to summaries
                    elapsed = timeit.default_timer() - timer_start
                    iterations = e.model.global_step - steps_start
                    steps_per_sec = iterations / elapsed
                    ex_per_sec = steps_per_sec * e.args.batch_size
                    additional_summaries['speed/steps_per_sec'] = steps_per_sec
                    additional_summaries['speed/examples_per_sec'] = ex_per_sec
                    timer_start = timeit.default_timer()
                    steps_start = e.model.global_step

                    # Print summaries
                    print(e.train_log_str(train_summaries, step + 1, epoch))

                    # Add train summaries (smoothed) to history and dump it to
                    # file and to tensorboard if available
                    self.train_history.add(train_summaries, step + 1)
                    if not e.args.dry_run:
                        with open(self.log_path, 'wb') as fd:
                            pickle.dump(self._history_dict(), fd)
                        if self.tb_writer is not None:
                            for k, v in train_summaries.items():
                                self.tb_writer.add_scalar(
                                    'train_' + k, v, step + 1)
                            for k, v in additional_summaries.items():
                                self.tb_writer.add_scalar(k, v, step + 1)

                # Optimization step
                e.optimizer.step()

                # Increment model's global step variable
                e.model.increment_global_step()

                first_step = False

            if step >= e.args.max_steps:
                break

        if progress is not None:  # if show_progress is False, progress is None
            progress.close()
        if step < e.args.max_steps:
            print("Reached epochs limit ({} epochs)".format(e.args.max_epochs))
        else:
            print("Reached steps limit ({} steps)".format(e.args.max_steps))
Example #4
    def test_procedure(self, iw_samples: Optional[int] = None):
        """Executes the experiment's test procedure and returns results.

        Collects variational inference metrics on the test set using
        forward_pass(), repeating the pass to derive the importance-weighted
        ELBO.

        Args:
            iw_samples (int, optional): number of samples for the importance-
                weighted ELBO. The other metrics are also averaged over all
                these samples, yielding a more accurate estimate.

        Returns:
            summaries (dict)
        """

        # Shorthand
        test_loader = self.dataloaders.test
        step = self.model.global_step
        args = self.args
        n_test = len(test_loader.dataset)

        # Default to one sample, but use many samples when it is time to
        # estimate the log likelihood. An explicit iw_samples overrides this.
        if iw_samples is None:
            iw_samples = 1
            if step % args.loglikelihood_every == 0 and step > 0:
                iw_samples = args.loglikelihood_samples

        # Setup
        summarizers = SummarizerCollection(mode='sum')
        show_progress = get_option('show_progress_bar')
        if show_progress:
            progress = tqdm(total=len(test_loader) * iw_samples, desc='test ')
        all_elbo_sep = torch.zeros(n_test, iw_samples)

        # Do test
        for batch_idx, (x, y) in enumerate(test_loader):
            for i in range(iw_samples):
                outputs = self.forward_pass(x, y)

                # elbo_sep shape (batch size,)
                i_start = batch_idx * args.test_batch_size
                i_end = (batch_idx + 1) * args.test_batch_size
                all_elbo_sep[i_start:i_end, i] = outputs['elbo_sep'].detach()

                metrics_dict = self.get_metrics_dict(outputs)
                multiplier = (x.size(0) / n_test) / iw_samples
                for k in metrics_dict:
                    metrics_dict[k] *= multiplier
                summarizers.add(metrics_dict)

                if show_progress:
                    progress.update()

        if show_progress:
            progress.close()

        if iw_samples > 1:
            # Shape (test set size,)
            elbo_iw = torch.logsumexp(all_elbo_sep, dim=1)
            elbo_iw = elbo_iw - np.log(iw_samples)

            # Mean over test set (scalar)
            elbo_iw = elbo_iw.mean().item()
            key = 'elbo/elbo_IW_{}'.format(iw_samples)
            summarizers.add({key: elbo_iw})

        summaries = summarizers.get_all(reset=True)
        return summaries
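
logsumexp minus log(K) computes log((1/K) * sum_k exp(elbo_k)) in a numerically stable way; by Jensen's inequality this importance-weighted estimate is at least as large as the plain average of the K per-sample ELBOs, giving a tighter lower bound on the log likelihood. A toy check with made-up numbers:

import numpy as np
import torch

elbo_sep = torch.tensor([[-90.0, -88.0, -89.0]])  # one example, K=3 samples
elbo_iw = torch.logsumexp(elbo_sep, dim=1) - np.log(3)
print(elbo_iw.item(), elbo_sep.mean().item())  # about -88.69 vs -89.0
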