import pickle
import timeit
from collections import OrderedDict
from typing import Optional

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

# get_option, SummarizerCollection, grad_global_norm, and global_norm are
# helpers defined elsewhere in this package. The functions below are
# excerpted methods: self/cls refer to the classes they belong to.


@classmethod
def train_log_str(cls,
                  summaries: dict,
                  step: int,
                  epoch: Optional[int] = None) -> str:
    """Returns the log string for training metrics."""
    if get_option('show_progress_bar'):
        s = "       "  # aligns with the "train: " prefix below
    else:
        s = "train: "
    s += "[step {}]".format(step)
    for k in summaries:
        s += " {key}={value:.5g}".format(key=k, value=summaries[k])
    return s
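# Usage sketch (hypothetical values and class name): assuming the method
# above belongs to an experiment class `Experiment`, formatting a metrics
# dict at step 1000 would print something like
#   train: [step 1000] loss=1.2345 elbo=-87.654
summaries = {'loss': 1.23451, 'elbo': -87.654}
print(Experiment.train_log_str(summaries, step=1000, epoch=3))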
# Module-level sentinel default: lets callers pass max_depth=None explicitly
# to disable grouping, while omitting the argument falls back to the
# package-wide option.
__sentinel = object()


def print_num_params(model: nn.Module, max_depth: Optional[int] = __sentinel):
    """Prints an overview of the model architecture with parameter counts.

    Optionally groups together parameters below a certain depth in the
    module tree. The depth defaults to the package-wide
    'model_print_depth' option.

    Args:
        model (torch.nn.Module)
        max_depth (int, optional)
    """
    if max_depth is __sentinel:
        max_depth = get_option('model_print_depth')
    assert max_depth is None or isinstance(max_depth, int)

    sep = '.'  # string separator in parameter names
    print("\n--- Trainable parameters:")
    num_params_tot = 0
    num_params_dict = OrderedDict()

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        num_params = param.numel()
        if max_depth is not None:
            split = name.split(sep)
            prefix = sep.join(split[:max_depth])
        else:
            prefix = name
        if prefix not in num_params_dict:
            num_params_dict[prefix] = 0
        num_params_dict[prefix] += num_params
        num_params_tot += num_params

    for n, n_par in num_params_dict.items():
        print("{:8d} {}".format(n_par, n))
    print(" - Total trainable parameters:", num_params_tot)
    print("---------\n")
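# Usage sketch: grouping parameter counts at depth 1 of the module tree.
# For an nn.Sequential, parameter names are '0.weight', '0.bias', ..., so
# max_depth=1 groups them by the top-level submodule index:
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
print_num_params(model, max_depth=1)
# --- Trainable parameters:
#   200960 0
#     2570 2
#  - Total trainable parameters: 203530
# ---------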
def run(self):
    """Runs the trainer."""

    # This flag is needed when loading a model to resume training
    first_step = True

    # Setup
    e = self.experiment
    train_loader = e.dataloaders.train
    progress = None
    show_progress = get_option('show_progress_bar')
    train_summarizers = SummarizerCollection(
        mode='moving_average',
        ma_length=get_option('train_summarizer_ma_length'))

    # Additional summarizers are independent of the train/test regime:
    # they are not printed, they are saved to tensorboard only once
    # (during training, not testing), and for now they are not saved to
    # the history.
    additional_summarizers = SummarizerCollection(
        mode='moving_average',
        ma_length=get_option('train_summarizer_ma_length'))

    # Training mode
    e.model.train()

    # Main loop
    for epoch in range(1, e.args.max_epochs + 1):
        for batch_idx, (x, y) in enumerate(train_loader):

            step = e.model.global_step
            if step >= e.args.max_steps:
                break

            if step % e.args.test_log_every == 0:

                # Test model (unless we just resumed training)
                if not first_step or step == 0:
                    with torch.no_grad():
                        self._test(epoch)

                # Save model checkpoint (unless we just started or
                # resumed training)
                if not first_step and step % e.args.checkpoint_every == 0:
                    print("* saving model checkpoint at "
                          "step {}".format(step))
                    e.model.checkpoint(self.checkpoint_folder,
                                       e.args.keep_checkpoint_max)

                # (Re)start progress bar
                if show_progress:
                    progress = tqdm(total=e.args.test_log_every,
                                    desc='train')

                # Restart timer to measure training speed
                timer_start = timeit.default_timer()
                steps_start = e.model.global_step

                # The timing above only makes sense if 'test_log_every'
                # is a multiple of 'train_log_every'. That currently
                # holds, but check it in case things change.
                assert e.args.test_log_every % e.args.train_log_every == 0

            # Reset gradients
            e.optimizer.zero_grad()

            # Forward pass: get loss and other info
            outputs = e.forward_pass(x, y)

            # Compute gradients (backward pass)
            outputs['loss'].backward()

            e.post_backward_callback()

            if e.args.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(
                    e.model.parameters(), max_norm=e.args.max_grad_norm)

            # Add batch metrics to summarizers
            metrics_dict = e.get_metrics_dict(outputs)
            train_summarizers.add(metrics_dict)

            # Compute gradient stats and add them to summarizers:
            # - grad norm of each parameter
            # - grad norm of given group (or default groups)
            # - total grad norm
            # TODO groups
            grad_stats = {
                'grad_norm_total/grad_norm_total':
                    grad_global_norm(e.model.parameters())
            }
            for n, p in e.model.named_parameters():
                k = 'grad_norm_per_parameter/' + n
                grad_stats[k] = grad_global_norm(p)
            additional_summarizers.add(grad_stats)

            # Compute L2 norm of parameters and add it to summarizers
            additional_summarizers.add(
                {'L2_norm': global_norm(e.model.parameters())})

            # Update progress bar
            if show_progress:
                progress.update()

            # Close progress bar if a test occurs at the next iteration
            if (step + 1) % e.args.test_log_every == 0:
                # If show_progress is False, progress is None
                if progress is not None:
                    progress.close()

            if (step + 1) % e.args.train_log_every == 0:
                # step+1 because we already did a forward/backward step

                # Get moving averages of metrics and reset summarizers
                train_summaries = train_summarizers.get_all(reset=True)
                additional_summaries = additional_summarizers.get_all(
                    reset=True)

                # Get training speed and add it to the summaries
                elapsed = timeit.default_timer() - timer_start
                iterations = e.model.global_step - steps_start
                steps_per_sec = iterations / elapsed
                ex_per_sec = steps_per_sec * e.args.batch_size
                additional_summaries['speed/steps_per_sec'] = steps_per_sec
                additional_summaries['speed/examples_per_sec'] = ex_per_sec
                timer_start = timeit.default_timer()
                steps_start = e.model.global_step

                # Print summaries
                print(e.train_log_str(train_summaries, step + 1, epoch))

                # Add train summaries (smoothed) to the history, dump it
                # to file, and log to tensorboard if available
                self.train_history.add(train_summaries, step + 1)
                if not e.args.dry_run:
                    with open(self.log_path, 'wb') as fd:
                        pickle.dump(self._history_dict(), fd)
                    if self.tb_writer is not None:
                        for k, v in train_summaries.items():
                            self.tb_writer.add_scalar(
                                'train_' + k, v, step + 1)
                        for k, v in additional_summaries.items():
                            self.tb_writer.add_scalar(k, v, step + 1)

            # Optimization step
            e.optimizer.step()

            # Increment the model's global step counter
            e.model.increment_global_step()

            first_step = False

        if step >= e.args.max_steps:
            break

    # If show_progress is False, progress is None
    if progress is not None:
        progress.close()

    if step < e.args.max_steps:
        print("Reached epochs limit ({} epochs)".format(e.args.max_epochs))
    else:
        print("Reached steps limit ({} steps)".format(e.args.max_steps))
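# For reference: grad_global_norm() and global_norm() used in run() are
# package helpers. A minimal sketch of the semantics assumed here (the
# global L2 norm over one tensor or an iterable of tensors, and over
# their gradients); an illustration, not the package's implementation:
def _global_norm_sketch(tensors) -> float:
    """Global L2 norm of a tensor or an iterable of tensors."""
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    sq_sum = sum(float(t.detach().pow(2).sum()) for t in tensors)
    return sq_sum ** 0.5


def _grad_global_norm_sketch(parameters) -> float:
    """Global L2 norm of the gradients of the given parameters."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    return _global_norm_sketch(grads)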
def test_procedure(self, iw_samples: Optional[int] = None):
    """Executes the experiment's test procedure and returns the results.

    Collects variational inference metrics on the test set using
    forward_pass(), repeating the forward pass with multiple samples to
    derive the importance-weighted ELBO.

    Args:
        iw_samples (int, optional): number of samples for the
            importance-weighted ELBO. The other metrics are also
            averaged over all these samples, yielding a more accurate
            estimate.

    Returns:
        summaries (dict)
    """

    # Shorthand
    test_loader = self.dataloaders.test
    step = self.model.global_step
    args = self.args
    n_test = len(test_loader.dataset)

    # If it's time to estimate the log likelihood, use many samples.
    # If a number is given, use that.
    if iw_samples is None:
        iw_samples = 1
        if step % args.loglikelihood_every == 0 and step > 0:
            iw_samples = args.loglikelihood_samples

    # Setup
    summarizers = SummarizerCollection(mode='sum')
    show_progress = get_option('show_progress_bar')
    if show_progress:
        progress = tqdm(total=len(test_loader) * iw_samples, desc='test ')
    all_elbo_sep = torch.zeros(n_test, iw_samples)

    # Run the test
    for batch_idx, (x, y) in enumerate(test_loader):
        for i in range(iw_samples):
            outputs = self.forward_pass(x, y)

            # elbo_sep has shape (batch size,)
            i_start = batch_idx * args.test_batch_size
            i_end = (batch_idx + 1) * args.test_batch_size
            all_elbo_sep[i_start:i_end, i] = outputs['elbo_sep'].detach()

            metrics_dict = self.get_metrics_dict(outputs)
            multiplier = (x.size(0) / n_test) / iw_samples
            for k in metrics_dict:
                metrics_dict[k] *= multiplier
            summarizers.add(metrics_dict)

            if show_progress:
                progress.update()

    if show_progress:
        progress.close()

    if iw_samples > 1:
        # Shape (test set size,)
        elbo_iw = torch.logsumexp(all_elbo_sep, dim=1)
        elbo_iw = elbo_iw - np.log(iw_samples)

        # Mean over the test set (scalar)
        elbo_iw = elbo_iw.mean().item()

        key = 'elbo/elbo_IW_{}'.format(iw_samples)
        summarizers.add({key: elbo_iw})

    summaries = summarizers.get_all(reset=True)
    return summaries
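# Note on the importance-weighted ELBO computed above: per test example,
#   L_K = log( (1/K) * sum_k exp(elbo_k) ),
# evaluated in log-space via logsumexp for numerical stability. By
# Jensen's inequality, L_K is at least the average single-sample ELBO.
# Toy check with random values:
_elbo_sep = torch.randn(5, 16)  # 5 examples, K=16 single-sample ELBOs
_elbo_iw = torch.logsumexp(_elbo_sep, dim=1) - np.log(16)
assert (_elbo_iw >= _elbo_sep.mean(dim=1)).all()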