def _log_basic(self, stats, misc):
    """Log to stdout and text log file.

    Emits one summary line with the current step, the mean training and
    validation losses, the learning rate, throughput (iterations/s and
    megavoxels/s) and the elapsed training time.

    Args:
        stats: Mapping with ``'tr_loss'`` and ``'val_loss'`` entries; each is
            reduced with ``np.mean`` before formatting.
        misc: Mapping with ``'learning_rate'``, ``'tr_speed'`` and
            ``'tr_speed_vx'`` entries.
    """
    mean_tr_loss = np.mean(stats['tr_loss'])
    mean_val_loss = np.mean(stats['val_loss'])
    learning_rate = misc['learning_rate']
    iters_per_sec = misc['tr_speed']
    mvx_per_sec = misc['tr_speed_vx']
    elapsed = pretty_string_time(self._timer.t_passed)
    pieces = (
        f'step={self.step:06d}, tr_loss={mean_tr_loss:.3f}, val_loss={mean_val_loss:.3f}, ',
        f'lr={learning_rate:.2e}, {iters_per_sec:.2f} it/s, {mvx_per_sec:.2f} MVx/s, {elapsed}',
    )
    logger.info(''.join(pieces))
def train(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured and
    visualizations are computed and logged to tensorboard.

    Training stops (``self.terminate`` is set) as soon as ``self.step``
    reaches ``max_steps`` or ``max_runtime`` seconds have passed since the
    start of this call. The model is saved after every epoch, additionally
    with suffix ``'_best'`` when the validation loss improves, and with
    suffix ``'_final'`` when the loop ends normally.

    Args:
        max_steps: Total number of training steps (``self.step``) after which
            training terminates.
        max_runtime: Wall-clock budget in seconds, counted from the start of
            this call.

    Raises:
        Exception: Errors raised during an epoch (e.g. ``NaNException`` on a
            NaN loss, assuming it derives from ``Exception``) are re-raised
            unless ``self.ignore_errors`` (print traceback and carry on) or
            ``self.ipython_on_error`` (open an IPython shell) is set.
    """
    self.start_time = datetime.datetime.now()
    self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
    while not self.terminate:
        try:
            # --> self.train()
            self.model.train()

            # Scalar training stats that should be logged and written to tensorboard later
            stats: Dict[str, float] = {'tr_loss': 0.0}
            # Other scalars to be logged
            misc: Dict[str, float] = {}
            # Hold image tensors for real-time training sample visualization in tensorboard
            images: Dict[str, torch.Tensor] = {}

            running_acc = 0
            running_mean_target = 0
            running_vx_size = 0
            # Number of batches actually processed in this epoch. The loop can
            # stop early (max_steps / max_runtime), so len(self.train_loader)
            # is not a reliable denominator for the epoch averages below.
            n_batches = 0
            timer = Timer()
            for inp, target in self.train_loader:
                inp, target = inp.to(self.device), target.to(self.device)

                # forward pass
                out = self.model(inp)
                loss = self.criterion(out, target)
                if torch.isnan(loss):
                    logger.error('NaN loss detected! Aborting training.')
                    raise NaNException

                # update step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Prevent accidental autograd overheads after optimizer step
                inp.detach_()
                target.detach_()
                out.detach_()
                loss.detach_()

                # get training performance
                stats['tr_loss'] += float(loss)
                acc = metrics.bin_accuracy(target, out)  # TODO
                mean_target = target.to(torch.float32).mean()
                print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                # Preserve training batch and network output for later visualization
                images['inp'] = inp
                images['target'] = target
                images['out'] = out
                # this was changed to support ReduceLROnPlateau which does not implement get_lr
                misc['learning_rate'] = self.optimizer.param_groups[0]["lr"]  # .get_lr()[-1]
                # update schedules
                for sched in self.schedulers.values():
                    # support ReduceLROnPlateau; doc. uses validation loss instead
                    # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                    if "metrics" in inspect.signature(sched.step).parameters:
                        sched.step(metrics=float(loss))
                    else:
                        sched.step()

                running_acc += acc
                running_mean_target += mean_target
                running_vx_size += inp.numel()
                n_batches += 1

                self.step += 1
                if self.step >= max_steps:
                    logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                    self.terminate = True
                    break
                if datetime.datetime.now() >= self.end_time:
                    logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                    self.terminate = True
                    break

            # Average over the batches that were actually seen; clamp to 1 so
            # an empty loader does not raise ZeroDivisionError.
            denom = max(n_batches, 1)
            stats['tr_accuracy'] = running_acc / denom
            stats['tr_loss'] /= denom
            misc['tr_speed'] = n_batches / timer.t_passed
            misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
            mean_target = running_mean_target / denom

            if self.valid_dataset is None:
                stats['val_loss'], stats['val_accuracy'] = float('nan'), float('nan')
            else:
                valid_stats = self.validate()
                stats.update(valid_stats)
                # validate() may not report an accuracy; keep the keys used below defined.
                stats.setdefault('val_accuracy', float('nan'))

            # Update history tracker (kind of made obsolete by tensorboard)
            # TODO: Decide what to do with this, now that most things are already in tensorboard.
            if self.step // len(self.train_dataset) > 1:
                tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss']
            else:
                tr_loss_gain = 0
            self._tracker.update_history([
                self.step, self._timer.t_passed, stats['tr_loss'],
                stats['val_loss'], tr_loss_gain,
                stats['tr_accuracy'], stats['val_accuracy'], misc['learning_rate'], 0, 0
            ])  # 0's correspond to mom and gradnet (?)
            t = pretty_string_time(self._timer.t_passed)
            loss_smooth = self._tracker.loss._ema

            # Logging to stdout, text log file
            text = "%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, " % (
                self.step, loss_smooth, stats['tr_loss'], stats['tr_accuracy'])
            text += "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                stats['val_accuracy'], "%", mean_target * 100, tr_loss_gain)
            text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], t)
            logger.info(text)

            # Plot tracker stats to pngs in save_path
            self._tracker.plot(self.save_path)

            # Reporting to tensorboard logger
            if self.tb:
                self.tb_log_scalars(stats, 'stats')
                self.tb_log_scalars(misc, 'misc')
                if self.previews_enabled:
                    self.tb_log_preview()
                self.tb_log_sample_images(images, group='tr_samples')
                self.tb.writer.flush()

            # Save trained model state
            self.save_model()
            if stats['val_loss'] < self.best_val_loss:
                self.best_val_loss = stats['val_loss']
                self.save_model(suffix='_best')
        except KeyboardInterrupt:
            IPython.embed(header=self._shell_info)
            if self.terminate:
                return
        except Exception as e:
            traceback.print_exc()
            if self.ignore_errors:
                # Just print the traceback and try to carry on with training.
                # This can go wrong in unexpected ways, so don't leave the training unattended.
                pass
            elif self.ipython_on_error:
                print("\nEntering Command line such that Exception can be "
                      "further inspected by user.\n\n")
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            else:
                raise e
    self.save_model(suffix='_final')
def run(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured and
    visualizations are computed and logged to tensorboard.

    Every iteration of the outer loop trains one epoch via
    ``self._train(max_steps, max_runtime)``, then validates (when a
    validation set exists), updates the history tracker, logs a summary
    line, reports to tensorboard and saves the model. A ``'_best'``
    checkpoint is written when the validation loss improves and a
    ``'_final'`` checkpoint when the loop is left.

    Args:
        max_steps: Step budget forwarded to ``self._train``.
        max_runtime: Wall-clock budget in seconds; ``self.end_time`` is set
            to the current time plus this value.
    """
    self.start_time = datetime.datetime.now()
    self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
    while not self.terminate:
        try:
            stats, misc, images = self._train(max_steps, max_runtime)
            self.epoch += 1

            if self.valid_dataset is not None:
                stats.update(self._validate())
            else:
                stats['val_loss'] = float('nan')
            stats.setdefault('val_accuracy', float('nan'))

            # History tracker (largely superseded by tensorboard).
            # TODO: Decide what to do with this, now that most things are already in tensorboard.
            tr_loss_gain = (
                self._tracker.history[-1][2] - stats['tr_loss']
                if self.step // len(self.train_dataset) > 1
                else 0
            )
            self._tracker.update_history([
                self.step,
                self._timer.t_passed,
                stats['tr_loss'],
                stats['val_loss'],
                tr_loss_gain,
                stats['tr_accuracy'],
                stats['val_accuracy'],
                misc['learning_rate'],
                0,  # placeholder (mom?)
                0,  # placeholder (gradnet?)
            ])

            elapsed = pretty_string_time(self._timer.t_passed)
            smoothed_loss = self._tracker.loss._ema
            # One summary line for stdout and the text log file
            summary_parts = [
                "%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, " % (
                    self.step, smoothed_loss, stats['tr_loss'], stats['tr_accuracy']),
                "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                    stats['val_accuracy'], "%", misc['mean_target'] * 100, tr_loss_gain),
                "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                    misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], elapsed),
            ]
            logger.info(''.join(summary_parts))

            # Tracker plots are written as pngs into save_path
            self._tracker.plot(self.save_path)

            if self.tb:
                # Tensorboard failures are logged but never stop training
                try:
                    self._tb_log_scalars(stats, 'stats')
                    self._tb_log_scalars(misc, 'misc')
                    if self.preview_batch is not None and (
                            self.epoch % self.preview_interval == 0 or self.epoch == 1):
                        # TODO: Also save preview inference results in a (3D) HDF5 file
                        self.preview_plotting_handler(self)
                    self.sample_plotting_handler(self, images, group='tr_samples')
                except Exception:
                    logger.exception('Error occured while logging to tensorboard:')

            self._save_model()
            # TODO: Support other metrics for determining what's the "best" model?
            if stats['val_loss'] < self.best_val_loss:
                self.best_val_loss = stats['val_loss']
                self._save_model(suffix='_best')
        except KeyboardInterrupt:
            if not self.ipython_shell:
                break
            IPython.embed(header=self._shell_info)
            if self.terminate:
                break
        except Exception as err:
            logger.exception('Unhandled exception during training:')
            if self.ignore_errors:
                # Carry on with the next epoch. This can go wrong in
                # unexpected ways, so don't leave the training unattended.
                continue
            if not self.ipython_shell:
                raise err
            print("\nEntering Command line such that Exception can be "
                  "further inspected by user.\n\n")
            IPython.embed(header=self._shell_info)
            if self.terminate:
                break
    self._save_model(suffix='_final')
def predict_proba(self, inp: np.ndarray, bs: int = 10, verbose: bool = False):
    """Run batched inference on ``inp`` and collect the outputs as numpy arrays.

    Args:
        inp: Input data, e.g. of shape [N, C, H, W], or a tuple of such
            arrays (all with the same N) that are passed to the model as
            separate positional arguments. ``self.normalize_func`` is
            applied first if it is set.
        bs: batch size
        verbose: report inference speed

    Returns:
        A float32 numpy array of shape [N, ...] holding the model output, or
        a tuple of such arrays if the model returns a tuple. An empty input
        (N == 0) yields empty arrays.
    """
    if verbose:
        start = time.time()
    if self.normalize_func is not None:
        inp = self.normalize_func(inp)

    def _to_device(arr):
        # Same conversion for probe and batches: float32 tensor on self.device.
        return torch.Tensor(arr).to(torch.float32).to(self.device)

    multi_input = isinstance(inp, tuple)
    with torch.no_grad():
        # Run a tiny probe batch (at most 2 samples) to learn the output shape(s).
        if multi_input:
            probe = self.model(*(_to_device(el[:2]) for el in inp))
            n_samples = len(inp[0])
        else:
            probe = self.model(_to_device(inp[:2]))
            n_samples = len(inp)
        # Pre-allocate output buffers with the sample axis resized to n_samples.
        if isinstance(probe, tuple):
            out = tuple(np.zeros([n_samples] + list(el.shape)[1:], dtype=np.float32)
                        for el in probe)
        else:
            out = np.zeros([n_samples] + list(probe.shape)[1:], dtype=np.float32)
        del probe

        # Initialised so the sanity check below also works for empty input
        # (the original raised UnboundLocalError in that case).
        high = 0
        for low in range(0, n_samples, bs):
            high = low + bs  # slicing clips the last batch to n_samples
            if multi_input:
                inp_stride = tuple(_to_device(el[low:high]) for el in inp)
                res = self.model(*inp_stride)
            else:
                inp_stride = _to_device(inp[low:high])
                res = self.model(inp_stride)
            if isinstance(res, tuple):
                for dst, part in zip(out, res):
                    dst[low:high] = part.detach().cpu()
            else:
                out[low:high] = res.detach().cpu()
            # Release batch tensors eagerly to keep (GPU) memory usage low.
            if isinstance(inp_stride, tuple):
                for el in inp_stride:
                    el.detach_()
            else:
                inp_stride.detach_()
            del inp_stride
            del res
            torch.cuda.empty_cache()
    assert high >= n_samples, "Prediction less samples then given" \
                              " in input."
    if verbose:
        dtime = time.time() - start
        if multi_input:
            inp_el = np.sum([float(np.prod(el.shape)) for el in inp])
        else:
            inp_el = float(np.prod(inp.shape))
        speed = inp_el / dtime / 1e6
        dtime = pretty_string_time(dtime)
        print(f'Inference speed: {speed:.2f} MB or MPix /s, time: {dtime}.')
    return out
def run(self, max_steps: int = 1) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured (only on
    roughly 10% of the epochs, chosen at random) and visualizations are
    computed and logged to tensorboard.

    Each training sample holds three images (reference, similar, distant).
    The generator ``self.model`` maps them to ``(dA, dB, z0, z1, z2)``.
    ``self.criterion(dA, dB, target)`` with an all ``-1`` target is combined
    with ``self.alpha`` times the mean L1 norm of the latent codes (logged as
    ``G_loss_l2``). If ``self.alpha2 > 0``, a discriminator
    ``self.model_discr`` is trained to tell latent codes apart from samples of
    ``self.latent_distr``, and the generator loss becomes
    ``(1 - alpha2) * loss + alpha2 * adversarial_loss``.

    ``model-checkpoint.pth`` in ``self.save_path`` is overwritten after every
    epoch; ``model-final-<step>.pth`` is written when the loop finishes
    without an interactive ``terminate``.

    Args:
        max_steps: Total number of training steps (``self.step``) to run.

    Raises:
        ValueError: If a batch does not contain exactly 3 images per sample.
        NaNException: If the generator loss becomes NaN.
        Both are only propagated when neither ``self.ignore_errors`` nor
        ``self.ipython_shell`` is set (assuming NaNException derives from
        ``Exception``).
    """
    while self.step < max_steps:
        try:
            # --> self.train()
            self.model.train()

            # Scalar training stats that should be logged and written to tensorboard later
            stats: Dict[str, float] = {'tr_loss_G': .0, 'tr_loss_D': .0}
            # Other scalars to be logged
            misc: Dict[str, float] = {
                'G_loss_advreg': .0,
                'G_loss_tnet': .0,
                'G_loss_l2': .0,
                'D_loss_fake': .0,
                'D_loss_real': .0,
            }
            # Hold image tensors for real-time training sample visualization in tensorboard
            images: Dict[str, torch.Tensor] = {}

            running_error = 0
            running_mean_target = 0
            running_vx_size = 0
            # Number of batches actually processed in this epoch. The loop may
            # break early at max_steps, so len(self.train_loader) is not a
            # reliable denominator for the epoch averages below.
            n_batches = 0
            timer = Timer()
            latent_points_fake = []
            latent_points_real = []
            for inp in self.train_loader:  # ref., pos., neg. samples
                if inp.size()[1] != 3:
                    raise ValueError(
                        "Data must not contain targets. "
                        "Input data shape is assumed to be "
                        "(N, 3, ch, x, y), where the first two"
                        " images in each sample is the similar"
                        " pair, while the third one is the "
                        "distant one.")
                inp0 = Variable(inp[:, 0].to(self.device))
                inp1 = Variable(inp[:, 1].to(self.device))
                inp2 = Variable(inp[:, 2].to(self.device))
                self.optimizer.zero_grad()
                # forward pass
                dA, dB, z0, z1, z2 = self.model(inp0, inp1, inp2)
                z_fake_gauss = torch.squeeze(torch.cat((z0, z1, z2), dim=1))
                target = torch.FloatTensor(dA.size()).fill_(-1).to(self.device)
                target = Variable(target)
                loss = self.criterion(dA, dB, target)
                L_l2 = torch.mean(torch.cat(
                    (z0.norm(1, dim=1), z1.norm(1, dim=1), z2.norm(1, dim=1)), dim=0))
                misc['G_loss_l2'] += self.alpha * float(L_l2)
                loss = loss + self.alpha * L_l2
                misc['G_loss_tnet'] += (1 - self.alpha2) * float(loss)  # log actual loss
                if torch.isnan(loss):
                    # f-string so the step count is actually interpolated
                    # (the original logged a literal "{self.step}").
                    logger.error(f'NaN loss detected after {self.step} '
                                 'steps! Aborting training.')
                    raise NaNException

                # Adversarial part to enforce latent variable distribution
                # to be Normal / whatever prior is used
                if self.alpha2 > 0:
                    self.optimizer_discr.zero_grad()
                    # adversarial labels
                    valid = Variable(torch.Tensor(inp0.size()[0], 1).fill_(1.0),
                                     requires_grad=False).to(self.device)
                    fake = Variable(torch.Tensor(inp0.shape[0], 1).fill_(0.0),
                                    requires_grad=False).to(self.device)

                    # --- Generator / TripletNet
                    self.model_discr.eval()
                    # TripletNet latent space should be classified as valid
                    L_advreg = self.criterion_discr(self.model_discr(z_fake_gauss), valid)
                    # average adversarial reg. and triplet-loss
                    loss = (1 - self.alpha2) * loss + self.alpha2 * L_advreg
                    # perform generator step
                    loss.backward()
                    self.optimizer.step()

                    # --- Discriminator
                    self.model.eval()
                    self.model_discr.train()
                    # rebuild graph (model output) to get clean backprop.
                    z_real_gauss = Variable(
                        self.latent_distr(inp0.size()[0], z0.size()[-1] * 3)).to(self.device)
                    _, _, z_fake_gauss0, z_fake_gauss1, z_fake_gauss2 = self.model(
                        inp0, inp1, inp2)
                    z_fake_gauss = torch.squeeze(torch.cat(
                        (z_fake_gauss0, z_fake_gauss1, z_fake_gauss2), dim=1))
                    # Compute discriminator outputs and loss
                    L_real_gauss = self.criterion_discr(self.model_discr(z_real_gauss), valid)
                    L_fake_gauss = self.criterion_discr(self.model_discr(z_fake_gauss), fake)
                    L_discr = 0.5 * (L_real_gauss + L_fake_gauss)
                    L_discr.backward()  # Backprop loss
                    self.optimizer_discr.step()  # Apply optimization step
                    self.model.train()  # set back to training mode

                    # clean and report
                    # NOTE: Tensor.detach() is out-of-place and the result is
                    # discarded, so these calls have no effect.
                    L_discr.detach()
                    L_advreg.detach()
                    L_real_gauss.detach()
                    L_fake_gauss.detach()
                    stats['tr_loss_D'] += float(L_discr)
                    misc['G_loss_advreg'] += self.alpha2 * float(L_advreg)  # log actual part of advreg
                    misc['D_loss_real'] += float(L_real_gauss)
                    misc['D_loss_fake'] += float(L_fake_gauss)
                    latent_points_real.append(z_real_gauss.detach().cpu().numpy())
                else:
                    loss.backward()
                    self.optimizer.step()
                latent_points_fake.append(z_fake_gauss.detach().cpu().numpy())

                # Prevent accidental autograd overheads after optimizer step
                # NOTE: detach() (unlike the in-place detach_()) returns a new
                # tensor that is discarded here, so these lines are no-ops.
                inp.detach()
                target.detach()
                dA.detach()
                dB.detach()
                z0.detach()
                z1.detach()
                z2.detach()
                loss.detach()
                L_l2.detach()

                # get training performance
                stats['tr_loss_G'] += float(loss)
                error = calculate_error(dA, dB)
                mean_target = target.to(torch.float32).mean()
                print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                # Preserve training batch and network output for later visualization
                images['inp_ref'] = inp0.cpu().numpy()
                images['inp_+'] = inp1.cpu().numpy()
                images['inp_-'] = inp2.cpu().numpy()
                # this was changed to support ReduceLROnPlateau which does not implement get_lr
                misc['learning_rate_G'] = self.optimizer.param_groups[0]["lr"]  # .get_lr()[-1]
                misc['learning_rate_D'] = self.optimizer_discr.param_groups[0]["lr"]  # .get_lr()[-1]
                # update schedules
                for sched in self.schedulers.values():
                    # support ReduceLROnPlateau; doc. uses validation loss instead
                    # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                    if "metrics" in inspect.signature(sched.step).parameters:
                        sched.step(metrics=float(loss))
                    else:
                        sched.step()

                running_error += error
                running_mean_target += mean_target
                running_vx_size += inp.numel()
                n_batches += 1

                self.step += 1
                if self.step >= max_steps:
                    break

            # Average over the batches that were actually seen; clamp to 1 so
            # an empty loader does not raise ZeroDivisionError.
            denom = max(n_batches, 1)
            stats['tr_err_G'] = float(running_error) / denom
            stats['tr_loss_G'] /= denom
            stats['tr_loss_D'] /= denom
            misc['G_loss_advreg'] /= denom
            misc['G_loss_tnet'] /= denom
            misc['G_loss_l2'] /= denom
            misc['D_loss_fake'] /= denom
            misc['D_loss_real'] /= denom
            misc['tr_speed'] = n_batches / timer.t_passed
            misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
            mean_target = running_mean_target / denom

            if (self.valid_dataset is None) or (1 != np.random.randint(0, 10)):
                # only validate 10% of the times
                stats['val_loss_G'], stats['val_err_G'] = float('nan'), float('nan')
            else:
                stats['val_loss_G'], stats['val_err_G'] = self._validate()
            # TODO: Report more metrics, e.g. dice error

            # Update history tracker (kind of made obsolete by tensorboard)
            # TODO: Decide what to do with this, now that most things are already in tensorboard.
            if self.step // len(self.train_dataset) > 1:
                tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss_G']
            else:
                tr_loss_gain = 0
            self._tracker.update_history([
                self.step, self._timer.t_passed, stats['tr_loss_G'],
                stats['val_loss_G'], tr_loss_gain,
                stats['tr_err_G'], stats['val_err_G'], misc['learning_rate_G'], 0, 0
            ])  # 0's correspond to mom and gradnet (?)
            t = pretty_string_time(self._timer.t_passed)
            loss_smooth = self._tracker.loss._ema

            # Logging to stdout, text log file
            text = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%%, " % (
                self.step, loss_smooth, stats['tr_loss_G'], stats['tr_err_G'])
            text += "vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                stats['val_err_G'], "%", mean_target * 100, tr_loss_gain)
            text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                misc['learning_rate_G'], misc['tr_speed'], misc['tr_speed_vx'], t)
            logger.info(text)

            # Plot tracker stats to pngs in save_path
            self._tracker.plot(self.save_path)

            # Reporting to tensorboard logger
            if self.tb:
                self._tb_log_scalars(stats, 'stats')
                self._tb_log_scalars(misc, 'misc')
                self.tb_log_sample_images(images, group='tr_samples')

                # Save histograms of the latent codes. They are written via
                # self.tb, so they are only produced when tensorboard is enabled.
                if len(latent_points_fake) > 0:
                    fig, _ = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_fake).flatten())
                    # grab the pixel buffer and dump it into a numpy array
                    fig.canvas.draw()
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure('latent_distr/latent_fake', plot_image(img_data),
                                       global_step=self.step)
                    plt.close()
                if len(latent_points_real) > 0:
                    fig, _ = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_real).flatten())
                    # grab the pixel buffer and dump it into a numpy array
                    fig.canvas.draw()
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure('latent_distr/latent_real', plot_image(img_data),
                                       global_step=self.step)
                    plt.close()

            # Save trained model state. A single checkpoint file is overwritten,
            # since saving with per-step file names leads to heaps of large files.
            torch.save(self.model.state_dict(),
                       os.path.join(self.save_path, 'model-checkpoint.pth'))
            # TODO: Also save "best" model, not only the latest one, which is often overfitted.
            # -> "best" in which regard? Lowest validation loss, validation error?
            # We can't blindly trust these metrics and may have to calculate
            # additional metrics (with focus on object boundary correctness).
        except KeyboardInterrupt:
            IPython.embed(header=self._shell_info)
            if self.terminate:
                return
        except Exception as e:
            traceback.print_exc()
            if self.ignore_errors:
                # Just print the traceback and try to carry on with training.
                # This can go wrong in unexpected ways, so don't leave the training unattended.
                pass
            elif self.ipython_shell:
                print("\nEntering Command line such that Exception can be "
                      "further inspected by user.\n\n")
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            else:
                raise e
    torch.save(self.model.state_dict(),
               os.path.join(self.save_path, f'model-final-{self.step:06d}.pth'))