def _on_epoch_end(
    self,
    epoch,
    save_epochs,
    save_path,
    logging,
    loss,
    inp,
    tar,
    pred,
    v_loss,
    v_inp,
    v_tar,
    v_pred,
    val_data,
):
    """
    Record the metrics of a finished epoch and periodically checkpoint.

    One row holding the training/validation loss and the relative L2 errors
    is appended to ``logging``, and the new row is printed. Every
    ``save_epochs`` epochs the model weights, the log table and (if one is
    produced) a figure are written below ``save_path``.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index; file names use ``epoch + 1``.
    save_epochs : int
        Checkpoint interval in epochs.
    save_path : str
        Output directory, created on demand.
    logging : pandas.DataFrame
        Metric history (expected to be a DataFrame: it is extended by one
        row, printed with ``tail`` and stored with ``to_pickle``).
    loss, v_loss
        Training / validation loss objects exposing ``.item()``.
    inp, tar, pred, v_inp, v_tar, v_pred
        Training / validation inputs, targets and predictions; forwarded to
        ``l2_error`` and ``self._create_figure``.
    val_data
        Not used by this method; kept for interface compatibility.

    Returns
    -------
    pandas.DataFrame
        The metric history including the row of this epoch.
    """
    # Local import keeps this block self-contained regardless of how the
    # surrounding module imports pandas.
    import pandas as pd

    self._print_info()

    row = {
        "loss": loss.item(),
        "val_loss": v_loss.item(),
        "rel_l2_error": l2_error(pred, tar, relative=True, squared=False)[0].item(),
        "val_rel_l2_error": l2_error(v_pred, v_tar, relative=True, squared=False)[0].item(),
    }
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported
    # replacement and also works on older pandas versions.
    logging = pd.concat(
        [logging, pd.DataFrame([row])],
        ignore_index=True,
        sort=False,
    )
    print(logging.tail(1))

    epoch_no = epoch + 1
    if epoch_no % save_epochs == 0:
        fig = self._create_figure(logging, loss, inp, tar, pred, v_loss, v_inp, v_tar, v_pred)

        os.makedirs(save_path, exist_ok=True)
        torch.save(
            self.state_dict(),
            os.path.join(save_path, "model_weights_epoch{}.pt".format(epoch_no)),
        )
        logging.to_pickle(
            os.path.join(save_path, "losses_epoch{}.pkl".format(epoch_no)),
        )

        # _create_figure may return None, in which case no plot is written.
        if fig is not None:
            fig.savefig(
                os.path.join(save_path, "plot_epoch{}.png".format(epoch_no)),
                bbox_inches="tight",
            )

    return logging
def grid_search(x, y, rec_func, grid):
    """
    Exhaustive grid search for tuning hyper-parameters.

    Every combination of the values in ``grid`` is passed as keyword
    arguments to ``rec_func(y, **params)``; the result is compared with
    ``x`` using the relative, non-squared ``l2_error``.

    Parameters
    ----------
    x
        Reference passed to ``l2_error`` as the comparison target.
    y
        First positional argument handed to ``rec_func``.
    rec_func : callable
        Called as ``rec_func(y, **params)`` for each grid point.
    grid : dict
        Maps parameter names to the sequence of values to try.

    Returns
    -------
    tuple
        ``(best_params, best_err, err)``: the parameter dict with the
        smallest error (``None`` if no grid point improved on ``inf``),
        that error, and a tensor of all errors with one axis per grid key.
    """
    names = grid.keys()
    grid_shape = [len(candidates) for candidates in grid.values()]
    err = torch.zeros(grid_shape)

    best_params = None
    best_err = np.inf

    combos = itertools.product(*grid.values())
    positions = np.ndindex(*grid_shape)
    for values, nidx in zip(combos, positions):
        params = {name: value for name, value in zip(names, values)}

        # Progress line reports the grid position with one-based indices.
        one_based = [axis_idx + 1 for axis_idx in nidx]
        print(
            "Current grid parameters ("
            + str(one_based)
            + " / "
            + str(grid_shape)
            + "): "
            + str(params)
        )

        rel_err, _ = l2_error(rec_func(y, **params), x, relative=True, squared=False)
        err[nidx] = rel_err
        print("Rel. recovery error: {:1.2e}".format(err[nidx]), flush=True)

        # Only a strictly smaller error replaces the current best.
        if not err[nidx] < best_err:
            continue
        best_params = params
        best_err = err[nidx]

    return best_params, best_err, err
def grid_search(x, y, rec_func, grid):
    """
    Grid search utility for tuning hyper-parameters.

    Every combination of the values in ``grid`` is passed as keyword
    arguments to ``rec_func(y, **params)``. Each result is scored against
    ``x`` with the relative, non-squared ``l2_error`` as well as with
    ``psnr`` and ``ssim``; the latter two are evaluated on the slice
    ``rotate_real(...)[:, 0:1, ...]`` of both tensors. Selection of the
    best parameters is based on the L2 error only.

    Parameters
    ----------
    x
        Reference that reconstructions are compared with.
    y
        First positional argument handed to ``rec_func``.
    rec_func : callable
        Called as ``rec_func(y, **params)`` for each grid point.
    grid : dict
        Maps parameter names to the sequence of values to try.

    Returns
    -------
    tuple
        ``(grid_param, err_min, err, err_psnr, err_ssim)``: the parameter
        dict with the smallest L2 error (``None`` if no grid point improved
        on ``inf``), that error, and three tensors holding the L2 error,
        PSNR and SSIM of every grid point (one axis per grid key).
    """
    err_min = np.inf
    grid_param = None
    grid_shape = [len(val) for val in grid.values()]
    err = torch.zeros(grid_shape)
    err_psnr = torch.zeros(grid_shape)
    err_ssim = torch.zeros(grid_shape)

    # The reference slice and its data range do not depend on the grid
    # point, so they are computed once instead of on every iteration.
    x_ref = rotate_real(x)[:, 0:1, ...]
    data_range = x_ref.max()

    for grid_val, nidx in zip(itertools.product(*grid.values()), np.ndindex(*grid_shape)):
        grid_param_cur = dict(zip(grid.keys(), grid_val))
        print(
            "Current grid parameters ("
            + str(list(nidx))
            + " / "
            + str(grid_shape)
            + "): "
            + str(grid_param_cur),
            flush=True,
        )

        x_rec = rec_func(y, **grid_param_cur)
        err[nidx], _ = l2_error(x_rec, x, relative=True, squared=False)

        # Shared by both image-quality metrics below.
        x_rec_ref = rotate_real(x_rec)[:, 0:1, ...]
        err_psnr[nidx] = psnr(
            x_rec_ref,
            x_ref,
            data_range=data_range,
            reduction="mean",
        )
        err_ssim[nidx] = ssim(
            x_rec_ref,
            x_ref,
            data_range=data_range,
            size_average=True,
        )

        print("Rel. recovery error: {:1.2e}".format(err[nidx]), flush=True)
        print("PSNR: {:.2f}".format(err_psnr[nidx]), flush=True)
        print("SSIM: {:.2f}".format(err_ssim[nidx]), flush=True)

        # Only a strictly smaller L2 error replaces the current best.
        if err[nidx] < err_min:
            grid_param = grid_param_cur
            err_min = err[nidx]

    return grid_param, err_min, err, err_psnr, err_ssim
def err_measure_l2(x1, x2):
    """
    L2 error wrapper function.

    Calls ``l2_error(x1, x2, relative=True, squared=False)`` and returns the
    second element of its result with singleton dimensions squeezed out
    (presumably the per-sample errors; confirm against ``l2_error``).
    """
    outputs = l2_error(x1, x2, relative=True, squared=False)
    second = outputs[1]
    return second.squeeze()