def setup_session(self, overwrite=False, timestamp=False):
    """Create the on-disk directory tree for this training session.

    The session directory is created as ``<ConfigFile.bin_dir>/<name>``.
    ``name`` is ``self.session_name``, optionally suffixed with a
    ``_%Y%m%d-%H%M%S`` timestamp. Inside it, the checkpoints, tensorboard,
    observations and scores subdirectories are created. Finally,
    ``ConfigFile._write_gitignore`` is called on the session directory.

    Args:
        overwrite (bool): passed to ``io.mkdir`` for the session directory
            only; per the original docstring, True overwrites an existing
            directory (default: False)
        timestamp (bool): if True, append an underscore and the current
            local time (``%Y%m%d-%H%M%S``) to the directory name
            (default: False)
    """
    name = (self.session_name + "_" + time.strftime("%Y%m%d-%H%M%S")
            if timestamp else self.session_name)

    io.mkdir(name, ConfigFile.bin_dir, overwrite)
    session_dir = os.path.join(ConfigFile.bin_dir, name)

    # The subdirectories are created without ``overwrite``, matching the
    # original per-directory calls.
    for subdir in (ConfigFile.checkpoints_dirname,
                   ConfigFile.tensorboard_dirname,
                   ConfigFile.observations_dirname,
                   ConfigFile.scores_dirname):
        io.mkdir(subdir, session_dir)

    ConfigFile._write_gitignore(session_dir)
def on_epoch_end(self, epoch, logs=None):
    """Save slice visualisations of the model's output for a fixed scan pair.

    Run at the end of each epoch:

    1. Format the per-epoch subdirectory name and file names from the
       templates on ``ConfigFile`` and ``self``.
    2. Create that subdirectory under ``<self.session_dir>/observations``.
    3. Run ``self.model.predict`` on the preprocessed ``self.src_id`` /
       ``self.tgt_id`` scans, plus the target segmentation when
       ``self.use_segmentation`` is set.
    4. Save a 4-slice figure of the source, target, prediction, both
       gradient channels and (optionally) the predicted segmentation.

    Args:
        epoch (int): epoch index; templates are formatted with
            ``epoch=epoch + 1``.
        logs (dict | None): extra keyword values for the filename templates
            (default: None, treated as an empty dict).
    """
    # Avoid a shared mutable default; the templates expand ``**logs``.
    logs = {} if logs is None else logs

    def _render(template):
        # Fill one filename/subdir template for this epoch.
        return template.format(epoch=epoch + 1, **logs)

    subdir_name = _render(ConfigFile.observations_subdir_format)
    src_filepath = _render(self.src_filepath)
    tgt_filepath = _render(self.tgt_filepath)
    pred_filepath = _render(self.pred_filepath)
    grad_x_filepath = _render(self.grad_x_filepath)
    grad_y_filepath = _render(self.grad_y_filepath)
    pred_seg_filepath = _render(self.pred_seg_filepath)

    observations_dir = os.path.join(self.session_dir,
                                    ConfigFile.observations_dirname)
    io.mkdir(subdir_name, observations_dir)
    epoch_dir = os.path.join(observations_dir, subdir_name)

    def _save_slices(volume, filename):
        # Render 4 slices of ``volume`` and write the figure into epoch_dir.
        fig, _ = handler.display_n_slices(volume, n=4, return_fig=True)
        try:
            fig.savefig(os.path.join(epoch_dir, filename))
        finally:
            # Close this specific figure (plt.close() with no argument only
            # closes the "current" one), even if saving fails, so figures
            # do not accumulate across epochs.
            plt.close(fig)

    # Add a leading batch axis and a trailing channel axis to the first
    # item yielded by each generator (assumed to be a 3-D volume).
    src_gen = loader.preprocess_scans([self.src_id], *self.input_shape)
    tgt_gen = loader.preprocess_scans([self.tgt_id], *self.input_shape)
    src = next(src_gen)[0][np.newaxis, :, :, :, np.newaxis]
    tgt = next(tgt_gen)[0][np.newaxis, :, :, :, np.newaxis]

    if self.use_segmentation:
        tgt_seg_gen = loader.preprocess_segmentations([self.tgt_id],
                                                      *self.input_shape)
        tgt_seg = next(tgt_seg_gen)[0][np.newaxis, :, :, :, np.newaxis]
        output = self.model.predict([src, tgt, tgt_seg])
    else:
        output = self.model.predict([src, tgt])

    _save_slices(src.squeeze(), src_filepath)
    _save_slices(tgt.squeeze(), tgt_filepath)
    _save_slices(output[0].squeeze(), pred_filepath)

    # Channels 0 and 1 of output[1]'s last axis are saved as the x- and
    # y-gradient images respectively.
    grads = output[1].squeeze()
    _save_slices(grads[:, :, :, 0], grad_x_filepath)
    _save_slices(grads[:, :, :, 1], grad_y_filepath)

    if self.use_segmentation:
        _save_slices(output[2].squeeze(), pred_seg_filepath)