# Excerpted script fragments; they rely on the enclosing project's imports
# (torch, torch.nn as nn, torch.nn.functional as F, numpy as np, h5py, json,
# os, collections, imageio, progressbar, tqdm's trange, and the project
# modules importance, models, inference).

# Closure used by the binary search on the target sample mean (see run()
# below); `s`, `importance_map`, `pattern`, `importancePostUpscale` and
# LOSS_BORDER are captured from the enclosing scope.
def f(x):
    postprocess = importance.PostProcess(
        s.heatmin, x, importancePostUpscale,
        LOSS_BORDER // importancePostUpscale,
        'basic')
    importance_map2 = postprocess(importance_map)[0].unsqueeze(1)
    sampling_mask = (importance_map2 >= pattern).to(dtype=importance_map.dtype)
    samples = torch.mean(sampling_mask).item()
    return samples
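# `f` is the objective evaluated by `binarySearch` further below
# (binarySearch(f, s.heatmean, s.heatmean, 10, 0, 1)). That helper lives
# elsewhere in the project; the sketch below is an assumption matching the
# observed call signature: find x in [lo, hi] with f(x) close to `target`,
# starting the bisection from `start`, assuming f is monotonically increasing.
def binarySearch(f, target, start, iterations, lo, hi):
    x = start
    for _ in range(iterations):
        if f(x) < target:
            lo = x
        else:
            hi = x
        x = 0.5 * (lo + hi)
    return x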
def benchmark(scene):
    DEBUG = False
    IMAGE_EXPORT = [(512, 512)]  # [(2**9, 1024)]  # screen, volume resolution

    # SETTINGS
    SCREEN_RESOLUTIONS = [2**i for i in range(6, 12)]
    print("Screen resolutions:", SCREEN_RESOLUTIONS)
    MIN_VOLUME_RESOLUTION = 32
    NUM_SAMPLES = 50
    NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-dvr-modeldir/"
    NETWORK = ("network", NETWORK_DIR + "adapDvr5-rgb-temp001-perc01-epoch500")
    VOLUME_FOLDER = "../../isosurface-super-resolution-data/volumes/cvol-filtered/"
    SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
    SETTINGS_FOLDER = "../network/video/"

    device_cpu = torch.device("cpu")
    device_gpu = torch.device("cuda")

    # load sampling pattern
    print("load sampling patterns")
    SAMPLING_PATTERNS = ['halton', 'plastic', 'random', 'regular']
    with h5py.File(SAMPLING_FILE, 'r') as f:
        SAMPLING_PATTERN = torch.from_numpy(f['plastic'][...]).to(device_gpu).unsqueeze(0)

    # load networks
    print("Load networks")
    importanceNetwork = ImportanceModel(NETWORK, device_gpu)
    reconNetwork = ReconstructionModel(NETWORK, device_gpu)

    # scenes
    if scene == 1:  # EJECTA 512
        VOLUME = "snapshot_272_512.cvol"
        STEPSIZE = 0.125
        POSITION_SAMPLER = lambda: (1.1**0.3) * randomPointOnSphere()
        SETTINGS = "Dvr-Ejecta-settings.json"
        UPSCALING = 8
        POSTPROCESS = importance.PostProcess(0.001, 0.05, 1, 0, 'basic')
        IMAGE_PATH = "exportDvrEjecta512_%d_%d_%d.png"
        OUTPUT_FILE = "../result-stats/DvrBenchmarkEjecta512.tsv"
    elif scene == 2:  # RM 1024
        VOLUME = "ppmt273_1024.cvol"
        STEPSIZE = 0.25
        def rmPositionSampler():
            pos = (1.1**0.3) * randomPointOnSphere()
            pos[2] = -abs(pos[2])
            return pos
        POSITION_SAMPLER = rmPositionSampler
        SETTINGS = "Dvr-RM-settings.json"
        UPSCALING = 8
        POSTPROCESS = importance.PostProcess(0.001, 0.05, 1, 0, 'basic')
        IMAGE_PATH = "exportDvrRM1024_%d_%d_%d.png"
        OUTPUT_FILE = "../result-stats/DvrBenchmarkRM1024.tsv"
    elif scene == 3:  # THORAX 512
        VOLUME = "cleveland70.cvol"
        STEPSIZE = 0.25
        POSITION_SAMPLER = lambda: (1.1**2.4) * randomPointOnSphere()
        SETTINGS = "Dvr-Thorax-settings.json"
        UPSCALING = 8
        POSTPROCESS = importance.PostProcess(0.001, 0.05, 1, 0, 'basic')
        IMAGE_PATH = "exportDvrThorax512_%d_%d_%d.png"
        OUTPUT_FILE = "../result-stats/DvrBenchmarkThorax512.tsv"

    ###################################################
    ######## RUN BENCHMARK ############################
    ###################################################

    # no gradients anywhere
    torch.set_grad_enabled(False)

    # load volume
    err = torch.ops.renderer.load_volume_from_binary(VOLUME_FOLDER + VOLUME)
    assert err == 1
    resX, resY, resZ = torch.ops.renderer.get_volume_resolution()
    print("volume resolution:", resX, resY, resZ)
    minRes = max(resX, resY, resZ)
    numMipmapLevels = 0
    while minRes >= MIN_VOLUME_RESOLUTION:
        numMipmapLevels += 1
        minRes = minRes // 2
    print("Num mipmap levels:", numMipmapLevels)

    # load settings
    settings = inference.RenderSettings()
    camera = inference.Camera(512, 512, [0, 0, -1])
    with open(SETTINGS_FOLDER + SETTINGS, "r") as f:
        o = json.load(f)
        settings.from_dict(o)
        camera.from_dict(o['Camera'])
    settings.update_camera(camera)
    settings.RENDER_MODE = 2

    # run scenes
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with open(OUTPUT_FILE, "w") as stat_file:
        stat_file.write(
            "VolumeResolution\tScreenResolution\tRenderingLowMillis\tImportanceMillis\t"
            "RenderingSamplesMillis\tReconstructionMillis\tRenderingHighMillis\tSamplePercentage\n")
        with std_out_err_redirect_tqdm() as orig_stdout:
            for mipmapLevel in range(numMipmapLevels):
                # load mipmap level
                if mipmapLevel > 0:
                    torch.ops.renderer.create_mipmap_level(mipmapLevel, "average")
                settings.MIPMAP_LEVEL = mipmapLevel
                volumeResolution = max(resX, resY, resZ) / (2**mipmapLevel)
                for resolution in SCREEN_RESOLUTIONS:
                    settings.RESOLUTION = [resolution, resolution]
                    settings.VIEWPORT = [0, 0, resolution, resolution]
                    print("volume resolution: %d, screen resolution %d" % (volumeResolution, resolution))

                    # loop over sample positions
                    renderingLowMillis = 0
                    importanceMillis = 0
                    renderingSamplesMillis = 0
                    reconstructionMillis = 0
                    renderingHighMillis = 0
                    samplePercentage = 0
                    for i in trange(NUM_SAMPLES, desc='Samples', file=orig_stdout,
                                    dynamic_ncols=True, leave=True):
                        pos = list(POSITION_SAMPLER())
                        settings.CAM_ORIGIN_START = pos
                        settings.CAM_ORIGIN_END = pos
                        settings.send()
                        render_settings = settings.clone()
                        render_settings.send()

                        # high
                        torch.cuda.synchronize()
                        start.record()
                        high_res = torch.ops.renderer.render()
                        end.record()
                        torch.cuda.synchronize()
                        renderingHighMillis += start.elapsed_time(end)

                        if (resolution, volumeResolution) in IMAGE_EXPORT:
                            filename = IMAGE_PATH % (resolution, volumeResolution, i)
                            image = high_res[0:3, :, :]
                            image = image.detach().cpu().numpy().transpose((1, 2, 0))
                            imageio.imwrite(filename, image)
                            print("Image saved to %s" % filename)

                        # low
                        settingsTmp = settings.clone()
                        settingsTmp.downsampling = UPSCALING
                        settingsTmp.send()
                        torch.cuda.synchronize()
                        start.record()
                        low_res = torch.ops.renderer.render()
                        end.record()
                        torch.cuda.synchronize()
                        renderingLowMillis += start.elapsed_time(end)

                        # prepare for importance map
                        low_res = low_res.unsqueeze(0)
                        low_res_input = low_res[:, :-2, :, :]
                        previous_input = torch.zeros(
                            1, 1,
                            low_res_input.shape[2] * importanceNetwork.networkUpscaling(),
                            low_res_input.shape[3] * importanceNetwork.networkUpscaling(),
                            dtype=low_res_input.dtype,
                            device=low_res_input.device)

                        # compute importance map
                        torch.cuda.synchronize()
                        start.record()
                        importance_map = importanceNetwork.call(
                            low_res_input[:, 0:5, :, :], previous_input)
                        end.record()
                        torch.cuda.synchronize()
                        importanceMillis += start.elapsed_time(end)
                        if len(importance_map.shape) == 3:
                            importance_map = importance_map.unsqueeze(1)
                        if DEBUG:
                            print("importance map min=%f, max=%f" % (
                                torch.min(importance_map), torch.max(importance_map)))

                        # prepare sampling
                        settings.send()
                        pattern = SAMPLING_PATTERN[:, :importance_map.shape[-2], :importance_map.shape[-1]]
                        normalized_importance_map = POSTPROCESS(importance_map)[0]
                        if DEBUG:
                            print("normalized importance map min=%f, max=%f" % (
                                torch.min(normalized_importance_map),
                                torch.max(normalized_importance_map)))
                            print("pattern min=%f, max=%f" % (
                                torch.min(pattern), torch.max(pattern)))
                        sampling_mask = normalized_importance_map > pattern
                        sample_positions = torch.nonzero(sampling_mask[0].t_())
                        sample_positions = sample_positions.to(torch.float32).transpose(0, 1).contiguous()
                        samplePercentage += sample_positions.size(1) / (
                            importance_map.shape[-2] * importance_map.shape[-1])
                        if DEBUG:
                            print("sample count: %d" % sample_positions.size(1))

                        # do the sampling
                        torch.cuda.synchronize()
                        start.record()
                        sample_data = torch.ops.renderer.render_samples(sample_positions)
                        reconstruction_input = torch.ops.renderer.scatter_samples(
                            sample_positions, sample_data, resolution, resolution, [0] * 10)
                        end.record()
                        torch.cuda.synchronize()
                        renderingSamplesMillis += start.elapsed_time(end)

                        # reconstruction
                        reconstruction_input = reconstruction_input[:9, :, :].unsqueeze(0)
                        previous_input = torch.zeros(
                            1, 8,
                            reconstruction_input.shape[2],
                            reconstruction_input.shape[3],
                            dtype=reconstruction_input.dtype,
                            device=reconstruction_input.device)
                        torch.cuda.synchronize()
                        start.record()
                        reconNetwork.call(reconstruction_input, sampling_mask, previous_input)
                        end.record()
                        torch.cuda.synchronize()
                        reconstructionMillis += start.elapsed_time(end)

                    # write stats
                    renderingLowMillis /= NUM_SAMPLES
                    importanceMillis /= NUM_SAMPLES
                    renderingSamplesMillis /= NUM_SAMPLES
                    reconstructionMillis /= NUM_SAMPLES
                    renderingHighMillis /= NUM_SAMPLES
                    samplePercentage /= NUM_SAMPLES
                    stat_file.write("%d\t%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\n" % (
                        volumeResolution, resolution,
                        renderingLowMillis, importanceMillis, renderingSamplesMillis,
                        reconstructionMillis, renderingHighMillis, samplePercentage))
                    stat_file.flush()
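# Assumed entry point for this benchmark script (not shown in the excerpt);
# scene ids follow the configuration branches above
# (1 = Ejecta 512, 2 = RM 1024, 3 = Thorax 512).
if __name__ == '__main__':
    benchmark(1)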
def convert(
        dataset_high: str,        # path to the high-resolution dataset
        dataset_low: str,         # path to the low-resolution dataset
        pattern: str,             # path to the dataset with the sample pattern
        network: str,             # path to the importance network
        output_prefix: str,       # prefix of the output datasets
        settings: List[Settings], # configurations to compute
        mode: str                 # 'ISO' or 'DVR'
        ):
    device = torch.device('cuda')
    print("Load Renderer.dll")
    torch.ops.load_library("./Renderer.dll")
    print()

    # load datasets
    dataset_high_file = h5py.File(dataset_high, 'r')
    dataset_low_file = h5py.File(dataset_low, 'r')
    dset_high = dataset_high_file['gt']
    dset_low = dataset_low_file['gt']
    B, T, C, Hhigh, Whigh = dset_high.shape
    _, _, _, Hlow, Wlow = dset_low.shape
    upscaling_factor = Hhigh // Hlow
    print("number of samples:", B)
    print("with timesteps: ", T)
    print("a resolution of: C=%d, H=%d, W=%d" % (C, Hhigh, Whigh))
    print("and an upscaling factor of: ", upscaling_factor)
    print()

    # load pattern
    pattern_file = h5py.File(pattern, 'r')
    print("available sampling patterns: ", pattern_file.keys())
    patterns = dict()
    pattern_min_size = 1 << 20
    for key in pattern_file.keys():
        patterns[key] = torch.from_numpy(pattern_file[key][...]).unsqueeze(0).to(device)
        print("Pattern", key, "loaded, shape:", patterns[key].shape)
        # patterns[key] carries a leading batch dimension after unsqueeze(0),
        # so take the spatial extent from the last two dimensions
        pattern_min_size = min(pattern_min_size,
                               patterns[key].shape[-2], patterns[key].shape[-1])
    print()

    # load network
    print("Load network")
    checkpoint = torch.load(network)
    importanceModel = checkpoint['importanceModel']
    importanceModel.to(device)
    print("Network loaded")
    parameters = checkpoint['parameters']
    importanceNetUpscale = parameters['importanceNetUpscale']
    print("network upscaling:", importanceNetUpscale)
    if importanceNetUpscale > upscaling_factor:
        print("Network upscaling is larger than the dataset upscaling. Terminating.")
        return
    importancePostUpscale = upscaling_factor // importanceNetUpscale
    print("post upscaling:", importancePostUpscale)

    if mode == 'ISO':
        channels_low = 5  # mask, normal x, normal y, normal z, depth
    elif mode == 'DVR':
        channels_low = 8  # rgba, normal xyz, depth
    else:
        raise ValueError("Unknown mode '%s', 'ISO' or 'DVR' expected" % mode)
    print()

    # create postprocess for each setting
    importancePostprocess = [
        importance.PostProcess(
            s.importanceMin, s.importanceMean,
            importancePostUpscale,
            parameters['lossBorderPadding'] // importancePostUpscale,
            'basic')
        for s in settings]

    # open outputs
    with torch.no_grad():
        with ExitStack() as stack:
            output_files = [
                stack.enter_context(h5py.File(output_prefix + s.output_name, 'w'))
                for s in settings]
            dset_high_cropped_shape = tuple(
                list(dset_high.shape)[:-2] +
                [dset_high.shape[-2] - 2 * BORDER_CROP,
                 dset_high.shape[-1] - 2 * BORDER_CROP])
            output_dsets = [
                f.create_dataset("gt", dset_high_cropped_shape, dset_high.dtype)
                for f in output_files]
            for ds in output_dsets:
                ds.attrs["Mode"] = mode
            print("Output files created")
            print("Process datasets now...")

            # loop over the dataset
            widgets = [
                ' [', progressbar.Timer(), '] ',
                progressbar.Bar(),
                progressbar.Counter(), ' (', progressbar.ETA(), ') ',
            ]
            for b in progressbar.progressbar(range(B), widgets=widgets):
                # get the time slice
                data_high = torch.from_numpy(dset_high[b, ...]).to(device)
                data_low = torch.from_numpy(dset_low[b, ...]).to(device)

                # evaluate the network (time becomes batch)
                importance_input = data_low[:, :channels_low, :, :]
                previous_input = torch.zeros(
                    T, 1,
                    importance_input.shape[2] * importanceNetUpscale,
                    importance_input.shape[3] * importanceNetUpscale,
                    dtype=data_low.dtype, device=device)
                importance_input = torch.cat([
                    importance_input,
                    models.VideoTools.flatten_high(previous_input, importanceNetUpscale)
                ], dim=1)
                importance_map = importanceModel(importance_input)

                # get pattern crop
                sampling_pattern_x = np.random.randint(0, pattern_min_size - Hhigh) if Hhigh < pattern_min_size else 0
                sampling_pattern_y = np.random.randint(0, pattern_min_size - Whigh) if Whigh < pattern_min_size else 0

                # loop over all settings
                for sIdx in range(len(settings)):
                    s = settings[sIdx]
                    # get the pattern
                    sampling_pattern = patterns[s.pattern][
                        :,
                        sampling_pattern_x:sampling_pattern_x + Hhigh,
                        sampling_pattern_y:sampling_pattern_y + Whigh]
                    # normalize and perform the sampling
                    importance_map_post, _ = importancePostprocess[sIdx](importance_map)
                    sample_mask = (importance_map_post >= sampling_pattern).to(
                        dtype=importance_map_post.dtype).unsqueeze(1)
                    # run the inpainting
                    crop = sample_mask * data_high
                    if s.inpainting == Inpainting.FAST:
                        inpainted = importance.fractionalInpaint(crop, sample_mask[:, 0, :, :])
                    elif s.inpainting == Inpainting.PDE:
                        inpainted = importance.pdeInpaint(crop, sample_mask[:, 0:1, :, :], cpu=False)
                    else:
                        assert False, "unknown inpainting algorithm"
                    # save the output
                    output_dsets[sIdx][b, ...] = inpainted.cpu().numpy()[
                        ..., BORDER_CROP:-BORDER_CROP, BORDER_CROP:-BORDER_CROP]
def run():
    torch.ops.load_library("./Renderer.dll")

    #########################
    # CONFIGURATION
    #########################

    if 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIso2/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            #("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            ("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"),  # title, file prefix
            ("adaptive019", "adaptive019_epoch470"),
            ("adaptive023", "adaptive023_epoch300")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random']

        HEATMAP_MIN = [0.01, 0.05, 0.2]
        HEATMAP_MEAN = [0.05, 0.1, 0.2, 0.5]

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 8

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance6/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            # title, file prefix
            #("U-Net (5-4)", "sizes/size5-4_epoch500"),
            #("Enhance-Net (epoch 50)", "enhance2_imp050_epoch050"),
            #("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
            #("Enhance-Net (Thorax)", "enhance_imp050_Thorax_epoch200"),
            #("Enhance-Net (RM)", "enhance_imp050_RM_epoch200"),
            #("Imp100", "enhance4_imp100_epoch300"),
            #("Imp100res", "enhance4_imp100res_epoch230"),
            ("Imp100res+N", "enhance4_imp100res+N_epoch300"),
            ("Imp100+N", "enhance4_imp100+N_epoch300"),
            #("Imp100+N-res", "enhance4_imp100+N-res_epoch300"),
            #("Imp100+N-resInterp", "enhance4_imp100+N-resInterp_epoch300"),
            #("U-Net (5-4)", "size5-4_epoch500"),
            #("U-Net (5-3)", "size5-3_epoch500"),
            #("U-Net (4-4)", "size4-4_epoch500"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.05]  # [0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance5Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            # title, file prefix
            ("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random', 'regular']
        #SAMPLING_PATTERNS = ['regular']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.05]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 1:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance8Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            # title, file prefix
            ("regular", "enhance7_regular_epoch190"),
            ("random", "enhance7_random_epoch190"),
            ("halton", "enhance7_halton_epoch190"),
            ("plastic", "enhance7_plastic_epoch190"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['regular', 'random', 'halton', 'plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.05]  # [0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoImp/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/imp/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"),  # title, file prefix
            ("imp005", "imp005_epoch500"),
            ("imp010", "imp010_epoch500"),
            ("imp020", "imp020_epoch500"),
            ("imp050", "imp050_epoch500"),
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.005, 0.01, 0.02, 0.05, 0.1]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 16

    #########################
    # LOADING
    #########################

    device = torch.device("cuda")

    # Load Networks
    IMPORTANCE_BASELINE1 = "ibase1"
    IMPORTANCE_BASELINE2 = "ibase2"
    IMPORTANCE_BASELINE3 = "ibase3"
    RECON_BASELINE = "rbase"

    # load importance model
    print("load importance networks")

    class ImportanceModel:
        def __init__(self, file):
            if file == IMPORTANCE_BASELINE1:
                self._net = importance.UniformImportanceMap(1, 0.5)
                self._upscaling = 1
                self._name = "constant"
                self.disableTemporal = True
                self._requiresPrevious = False
            elif file == IMPORTANCE_BASELINE2:
                self._net = importance.GradientImportanceMap(1, (1, 1), (2, 1), (3, 1))
                self._upscaling = 1
                self._name = "curvature"
                self.disableTemporal = True
                self._requiresPrevious = False
            else:
                self._name = file[0]
                file = os.path.join(NETWORK_DIR, file[1] + "_importance.pt")
                extra_files = torch._C.ExtraFilesMap()
                extra_files['settings.json'] = ""
                self._net = torch.jit.load(file, map_location=device,
                                           _extra_files=extra_files)
                settings = json.loads(extra_files['settings.json'])
                self._upscaling = settings['networkUpscale']
                self._requiresPrevious = settings.get("requiresPrevious", False)
                self.disableTemporal = settings.get("disableTemporal", True)

        def networkUpscaling(self):
            return self._upscaling

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            if self._requiresPrevious:
                input = torch.cat([
                    input,
                    models.VideoTools.flatten_high(prev_warped_out, self._upscaling)
                ], dim=1)
            input = F.pad(input, [IMPORTANCE_BORDER] * 4, 'constant', 0)
            output = self._net(input)  # the network call
            output = F.pad(output, [-IMPORTANCE_BORDER * self._upscaling] * 4, 'constant', 0)
            return output

    class LuminanceImportanceModel:
        def __init__(self):
            self.disableTemporal = True

        def setTestFile(self, filename):
            importance_file = filename[:-5] + "-luminanceImportance.hdf5"
            if os.path.exists(importance_file):
                self._exist = True
                self._file = h5py.File(importance_file, 'r')
                self._dset = self._file['importance']
            else:
                self._exist = False
                self._file = None
                self._dset = None

        def isAvailable(self):
            return self._exist

        def setIndices(self, indices: torch.Tensor):
            assert len(indices.shape) == 1
            self._indices = list(indices.cpu().numpy())

        def setTime(self, time):
            self._time = time

        def networkUpscaling(self):
            return UPSCALING

        def name(self):
            return "luminance-contrast"

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            B, C, H, W = input.shape
            if not self._exist:
                return torch.ones(B, 1, H, W, dtype=input.dtype, device=input.device)
            outputs = []
            for idx in self._indices:
                outputs.append(torch.from_numpy(
                    self._dset[idx, self._time, ...]).to(device=input.device))
            return torch.stack(outputs, dim=0)

    importanceBaseline1 = ImportanceModel(IMPORTANCE_BASELINE1)
    importanceBaseline2 = ImportanceModel(IMPORTANCE_BASELINE2)
    importanceBaselineLuminance = LuminanceImportanceModel()
    importanceModels = [ImportanceModel(f) for f in NETWORKS]

    # load reconstruction networks
    print("load reconstruction networks")

    class ReconstructionModel:
        def __init__(self, file):
            if file == RECON_BASELINE:
                class Inpainting(nn.Module):
                    def forward(self, x, mask):
                        input = x[:, 0:6, :, :].contiguous()  # mask, normal xyz, depth, ao
                        mask = x[:, 6, :, :].contiguous()
                        return torch.ops.renderer.fast_inpaint(mask, input)
                self._net = Inpainting()
                self._upscaling = 1
                self._name = "inpainting"
                self.disableTemporal = True
            else:
                self._name = file[0]
                file = os.path.join(NETWORK_DIR, file[1] + "_recon.pt")
                extra_files = torch._C.ExtraFilesMap()
                extra_files['settings.json'] = ""
                self._net = torch.jit.load(file, map_location=device,
                                           _extra_files=extra_files)
                self._settings = json.loads(extra_files['settings.json'])
                self.disableTemporal = False
                requiresMask = self._settings.get('expectMask', False)
                if self._settings.get("interpolateInput", False):
                    self._originalNet = self._net

                    class Inpainting2(nn.Module):
                        def __init__(self, originalNet, requiresMask):
                            super().__init__()
                            self._n = originalNet
                            self._requiresMask = requiresMask

                        def forward(self, x, mask):
                            input = x[:, 0:6, :, :].contiguous()  # mask, normal xyz, depth, ao
                            mask = x[:, 6, :, :].contiguous()
                            inpainted = torch.ops.renderer.fast_inpaint(mask, input)
                            x[:, 0:6, :, :] = inpainted
                            if self._requiresMask:
                                return self._n(x, mask)
                            else:
                                return self._n(x)
                    self._net = Inpainting2(self._originalNet, requiresMask)

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask, prev_warped_out):
            input = torch.cat([input, prev_warped_out], dim=1)
            output = self._net(input, mask)
            return output

    class ReconstructionModelPostTrain:
        """
        Reconstruction models that are trained as dense reconstruction networks
        after the adaptive training.
        They don't receive the sampling mask as input, but can start with
        PDE-based inpainting.
        """

        def __init__(self, name: str, model_path: str, inpainting: str):
            assert inpainting == 'fast' or inpainting == 'pde', \
                "inpainting must be either 'fast' or 'pde', but got %s" % inpainting
            self._inpainting = inpainting
            self._name = name
            file = os.path.join(POSTTRAIN_NETWORK_DIR, model_path)
            extra_files = torch._C.ExtraFilesMap()
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(file, map_location=device,
                                       _extra_files=extra_files)
            self._settings = json.loads(extra_files['settings.json'])
            assert self._settings.get('upscale_factor', None) == 1, \
                "selected file is not a 1x SRNet"
            self.disableTemporal = False

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            # no sampling and no AO
            input_no_sampling = input[:, 0:5, :, :].contiguous()  # mask, normal xyz, depth
            sampling_mask = input[:, 6, :, :].contiguous()
            # perform inpainting
            if self._inpainting == 'pde':
                inpainted = torch.ops.renderer.pde_inpaint(
                    sampling_mask, input_no_sampling,
                    200, 1e-4, 5, 2,  # m0, epsilon, m1, m2
                    0,  # mc -> multigrid recursion count. =0 disables the multigrid hierarchy
                    9, 0)  # ms, m3
            else:
                inpainted = torch.ops.renderer.fast_inpaint(
                    sampling_mask, input_no_sampling)
            # run network
            input = torch.cat([inpainted, prev_warped_out], dim=1)
            output = self._net(input)
            if isinstance(output, tuple):
                output = output[0]
            return output

    reconBaseline = ReconstructionModel(RECON_BASELINE)
    reconModels = [ReconstructionModel(f) for f in NETWORKS]
    reconPostModels = [
        ReconstructionModelPostTrain(name, file, inpainting)
        for (name, file, inpainting) in POSTTRAIN_NETWORKS
    ]
    allReconModels = reconModels + reconPostModels

    NETWORK_COMBINATIONS = \
        [(importanceBaseline1, reconBaseline), (importanceBaseline2, reconBaseline)] + \
        [(importanceBaselineLuminance, reconBaseline)] + \
        [(importanceBaseline1, reconNet) for reconNet in allReconModels] + \
        [(importanceBaseline2, reconNet) for reconNet in allReconModels] + \
        [(importanceBaselineLuminance, reconNet) for reconNet in allReconModels] + \
        [(importanceNet, reconBaseline) for importanceNet in importanceModels] + \
        list(zip(importanceModels, reconModels)) + \
        [(importanceNet, reconPostModel)
         for importanceNet in importanceModels
         for reconPostModel in reconPostModels]
    #NETWORK_COMBINATIONS = list(zip(importanceModels, reconModels))
    print("Network combinations:")
    for (i, r) in NETWORK_COMBINATIONS:
        print(" %s - %s" % (i.name(), r.name()))

    # load sampling patterns
    print("load sampling patterns")
    with h5py.File(SAMPLING_FILE, 'r') as f:
        sampling_pattern = dict([
            (name, torch.from_numpy(f[name][...]).to(device))
            for name in SAMPLING_PATTERNS])

    # create shading
    shading = ScreenSpaceShading(device)
    shading.fov(30)
    shading.ambient_light_color(np.array([0.1, 0.1, 0.1]))
    shading.diffuse_light_color(np.array([1.0, 1.0, 1.0]))
    shading.specular_light_color(np.array([0.0, 0.0, 0.0]))
    shading.specular_exponent(16)
    shading.light_direction(np.array([0.1, 0.1, 1.0]))
    shading.material_color(np.array([1.0, 0.3, 0.0]))
    AMBIENT_OCCLUSION_STRENGTH = 1.0
    shading.ambient_occlusion(1.0)
    shading.inverse_ao = False

    # heatmap
    HEATMAP_CFG = [(min, mean)
                   for min in HEATMAP_MIN
                   for mean in HEATMAP_MEAN
                   if min < mean]
    print("heatmap configs:", HEATMAP_CFG)

    #########################
    # DEFINE STATISTICS
    #########################
    ssimLoss = SSIM(size_average=False)
    ssimLoss.to(device)
    psnrLoss = PSNR()
    psnrLoss.to(device)
    lpipsColor = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    MIN_FILLING = 0.05
    NUM_BINS = 200

    class Statistics:
        def __init__(self):
            self.histogram_color_withAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_color_noAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_depth = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_normal = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_mask = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_ao = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_counter = 0

        def create_datasets(self, hdf5_file: h5py.File,
                            stats_name: str, histo_name: str,
                            num_samples: int, extra_info: dict):
            self.expected_num_samples = num_samples
            stats_shape = (num_samples, len(list(StatField)))
            self.stats_file = hdf5_file.require_dataset(
                stats_name, stats_shape, dtype='f', exact=True)
            self.stats_file.attrs['NumFields'] = len(list(StatField))
            for field in list(StatField):
                self.stats_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.stats_file.attrs[key] = value
            self.stats_index = 0

            histo_shape = (NUM_BINS, len(list(HistoField)))
            self.histo_file = hdf5_file.require_dataset(
                histo_name, histo_shape, dtype='f', exact=True)
            self.histo_file.attrs['NumFields'] = len(list(HistoField))
            for field in list(HistoField):
                self.histo_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.histo_file.attrs[key] = value

        def add_timestep_sample(self, pred_mnda, gt_mnda, sampling_mask):
            """
            adds a timestep sample:
            pred_mnda: prediction: mask, normal, depth, AO
            gt_mnda: ground truth: mask, normal, depth, AO
            """
            B = pred_mnda.shape[0]

            # shading
            shading.ambient_occlusion(AMBIENT_OCCLUSION_STRENGTH)
            pred_color_withAO = shading(pred_mnda)
            gt_color_withAO = shading(gt_mnda)
            shading.ambient_occlusion(0.0)
            pred_color_noAO = shading(pred_mnda)
            gt_color_noAO = shading(gt_mnda)

            # apply border
            pred_mnda = pred_mnda[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            pred_color_withAO = pred_color_withAO[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            pred_color_noAO = pred_color_noAO[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            gt_mnda = gt_mnda[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            gt_color_withAO = gt_color_withAO[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            gt_color_noAO = gt_color_noAO[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]

            mask = gt_mnda[:, 0:1, :, :] * 0.5 + 0.5

            # PSNR
            psnr_mask = psnrLoss(pred_mnda[:, 0:1, :, :], gt_mnda[:, 0:1, :, :]).cpu().numpy()
            psnr_normal = psnrLoss(pred_mnda[:, 1:4, :, :], gt_mnda[:, 1:4, :, :], mask=mask).cpu().numpy()
            psnr_depth = psnrLoss(pred_mnda[:, 4:5, :, :], gt_mnda[:, 4:5, :, :], mask=mask).cpu().numpy()
            psnr_ao = psnrLoss(pred_mnda[:, 5:6, :, :], gt_mnda[:, 5:6, :, :], mask=mask).cpu().numpy()
            psnr_color_withAO = psnrLoss(pred_color_withAO, gt_color_withAO, mask=mask).cpu().numpy()
            psnr_color_noAO = psnrLoss(pred_color_noAO, gt_color_noAO, mask=mask).cpu().numpy()

            # SSIM
            ssim_mask = ssimLoss(pred_mnda[:, 0:1, :, :], gt_mnda[:, 0:1, :, :]).cpu().numpy()
            pred_mnda = gt_mnda + mask * (pred_mnda - gt_mnda)
            ssim_normal = ssimLoss(pred_mnda[:, 1:4, :, :], gt_mnda[:, 1:4, :, :]).cpu().numpy()
            ssim_depth = ssimLoss(pred_mnda[:, 4:5, :, :], gt_mnda[:, 4:5, :, :]).cpu().numpy()
            ssim_ao = ssimLoss(pred_mnda[:, 5:6, :, :], gt_mnda[:, 5:6, :, :]).cpu().numpy()
            ssim_color_withAO = ssimLoss(pred_color_withAO, gt_color_withAO).cpu().numpy()
            ssim_color_noAO = ssimLoss(pred_color_noAO, gt_color_noAO).cpu().numpy()

            # Perceptual
            lpips_color_withAO = torch.cat([
                lpipsColor(pred_color_withAO[b], gt_color_withAO[b], normalize=True)
                for b in range(B)
            ], dim=0).cpu().numpy()
            lpips_color_noAO = torch.cat([
                lpipsColor(pred_color_noAO[b], gt_color_noAO[b], normalize=True)
                for b in range(B)
            ], dim=0).cpu().numpy()

            # Samples
            samples = torch.mean(sampling_mask, dim=(1, 2, 3)).cpu().numpy()

            # Write samples to file
            for b in range(B):
                assert self.stats_index < self.expected_num_samples, \
                    "Adding more samples than specified"
                self.stats_file[self.stats_index, :] = np.array([
                    psnr_mask[b], psnr_normal[b], psnr_depth[b], psnr_ao[b],
                    psnr_color_noAO[b], psnr_color_withAO[b],
                    ssim_mask[b], ssim_normal[b], ssim_depth[b], ssim_ao[b],
                    ssim_color_noAO[b], ssim_color_withAO[b],
                    lpips_color_noAO[b], lpips_color_withAO[b],
                    samples[b]
                ], dtype='f')
                self.stats_index += 1

            # Histogram
            self.histogram_counter += 1

            mask_diff = F.l1_loss(gt_mnda[:, 0, :, :], pred_mnda[:, 0, :, :], reduction='none')
            histogram, _ = np.histogram(mask_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_mask += (histogram / (NUM_BINS * B) - self.histogram_mask) / self.histogram_counter

            #normal_diff = (-F.cosine_similarity(gt_mnda[0,1:4,:,:], pred_mnda[0,1:4,:,:], dim=0)+1)/2
            normal_diff = F.l1_loss(gt_mnda[:, 1:4, :, :], pred_mnda[:, 1:4, :, :],
                                    reduction='none').sum(dim=0) / 6
            histogram, _ = np.histogram(normal_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_normal += (histogram / (NUM_BINS * B) - self.histogram_normal) / self.histogram_counter

            depth_diff = F.l1_loss(gt_mnda[:, 4, :, :], pred_mnda[:, 4, :, :], reduction='none')
            histogram, _ = np.histogram(depth_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_depth += (histogram / (NUM_BINS * B) - self.histogram_depth) / self.histogram_counter

            ao_diff = F.l1_loss(gt_mnda[:, 5, :, :], pred_mnda[:, 5, :, :], reduction='none')
            histogram, _ = np.histogram(ao_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_ao += (histogram / (NUM_BINS * B) - self.histogram_ao) / self.histogram_counter

            color_diff = F.l1_loss(gt_color_withAO[:, 0, :, :], pred_color_withAO[:, 0, :, :], reduction='none')
            histogram, _ = np.histogram(color_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_color_withAO += (histogram / (NUM_BINS * B) - self.histogram_color_withAO) / self.histogram_counter

            color_diff = F.l1_loss(gt_color_noAO[:, 0, :, :], pred_color_noAO[:, 0, :, :], reduction='none')
            histogram, _ = np.histogram(color_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_color_noAO += (histogram / (NUM_BINS * B) - self.histogram_color_noAO) / self.histogram_counter

        def close_stats_file(self):
            self.stats_file.attrs['NumEntries'] = self.stats_index

        def write_histogram(self):
            """
            After every sample for the current dataset was processed,
            write a histogram of the errors into the file.
            """
            for i in range(NUM_BINS):
                self.histo_file[i, :] = np.array([
                    i / NUM_BINS, (i + 1) / NUM_BINS,
                    self.histogram_mask[i],
                    self.histogram_normal[i],
                    self.histogram_depth[i],
                    self.histogram_ao[i],
                    self.histogram_color_withAO[i],
                    self.histogram_color_noAO[i]
                ])

    #########################
    # DATASET
    #########################
    class FullResDataset(torch.utils.data.Dataset):
        def __init__(self, file):
            self.hdf5_file = h5py.File(file, 'r')
            self.dset = self.hdf5_file['gt']
            print("Dataset shape:", self.dset.shape)

        def __len__(self):
            return self.dset.shape[0]

        def num_timesteps(self):
            return self.dset.shape[1]

        def __getitem__(self, idx):
            return (self.dset[idx, ...], np.array(idx))

    #########################
    # COMPUTE STATS for each dataset
    #########################
    for dataset_name, dataset_file in DATASETS:
        dataset_file = os.path.join(DATASET_PREFIX, dataset_file)
        print("Compute statistics for", dataset_name)

        # init luminance importance map
        importanceBaselineLuminance.setTestFile(dataset_file)
        if importanceBaselineLuminance.isAvailable():
            print("Luminance-contrast importance map is available")

        # create output file
        os.makedirs(OUTPUT_FOLDER, exist_ok=True)
        output_file = os.path.join(OUTPUT_FOLDER, dataset_name + '.hdf5')
        print("Save to", output_file)
        with h5py.File(output_file, 'a') as output_hdf5_file:

            # load dataset
            set = FullResDataset(dataset_file)
            data_loader = torch.utils.data.DataLoader(set, batch_size=BATCH_SIZE, shuffle=False)

            # define statistics
            StatsCfg = collections.namedtuple(
                "StatsCfg",
                "stats importance recon heatmin heatmean pattern")
            statistics = []
            for (inet, rnet) in NETWORK_COMBINATIONS:
                for (heatmin, heatmean) in HEATMAP_CFG:
                    for pattern in SAMPLING_PATTERNS:
                        stats_info = {
                            'importance': inet.name(),
                            'reconstruction': rnet.name(),
                            'heatmin': heatmin,
                            'heatmean': heatmean,
                            'pattern': pattern
                        }
                        stats_filename = "Stats_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 100, heatmean * 100, pattern)
                        histo_filename = "Histogram_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 100, heatmean * 100, pattern)
                        s = Statistics()
                        s.create_datasets(
                            output_hdf5_file,
                            stats_filename, histo_filename,
                            len(set) * set.num_timesteps(),
                            stats_info)
                        statistics.append(StatsCfg(
                            stats=s,
                            importance=inet,
                            recon=rnet,
                            heatmin=heatmin,
                            heatmean=heatmean,
                            pattern=pattern))
            print(len(statistics), "different combinations are performed per sample")

            # compute statistics
            try:
                with torch.no_grad():
                    num_minibatch = len(data_loader)
                    pg = ProgressBar(num_minibatch, 'Evaluation', length=50)
                    for iteration, (batch, batch_indices) in enumerate(data_loader, 0):
                        pg.print_progress_bar(iteration)
                        batch = batch.to(device)
                        importanceBaselineLuminance.setIndices(batch_indices)
                        B, T, C, H, W = batch.shape

                        # try out each combination
                        for s in statistics:
                            #print(s)
                            # get input to evaluation
                            importanceNetUpscale = s.importance.networkUpscaling()
                            importancePostUpscale = UPSCALING // importanceNetUpscale
                            crop_low = torch.nn.functional.interpolate(
                                batch.reshape(B * T, C, H, W),
                                scale_factor=1 / UPSCALING,
                                mode='area').reshape(B, T, C, H // UPSCALING, W // UPSCALING)
                            pattern = sampling_pattern[s.pattern][:H, :W]
                            crop_high = batch

                            # loop over timesteps
                            pattern = pattern.unsqueeze(0).unsqueeze(0)
                            previous_importance = None
                            previous_output = None
                            reconstructions = []
                            for j in range(T):
                                importanceBaselineLuminance.setTime(j)
                                # extract flow (always the last two channels of crop_high)
                                flow = crop_high[:, j, C - 2:, :, :]

                                # compute importance map
                                importance_input = crop_low[:, j, :5, :, :]
                                if j == 0 or s.importance.disableTemporal:
                                    previous_input = torch.zeros(
                                        B, 1,
                                        importance_input.shape[2] * importanceNetUpscale,
                                        importance_input.shape[3] * importanceNetUpscale,
                                        dtype=crop_high.dtype, device=crop_high.device)
                                else:
                                    flow_low = F.interpolate(flow, scale_factor=1 / importancePostUpscale)
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_importance, flow_low, 1, False)
                                importance_map = s.importance.call(importance_input, previous_input)
                                if len(importance_map.shape) == 3:
                                    importance_map = importance_map.unsqueeze(1)
                                previous_importance = importance_map

                                target_mean = s.heatmean
                                if USE_BINARY_SEARCH_ON_MEAN:
                                    # For regular sampling, the normalization does not work properly,
                                    # use binary search on the heatmean instead
                                    def f(x):
                                        postprocess = importance.PostProcess(
                                            s.heatmin, x, importancePostUpscale,
                                            LOSS_BORDER // importancePostUpscale,
                                            'basic')
                                        importance_map2 = postprocess(importance_map)[0].unsqueeze(1)
                                        sampling_mask = (importance_map2 >= pattern).to(dtype=importance_map.dtype)
                                        samples = torch.mean(sampling_mask).item()
                                        return samples
                                    target_mean = binarySearch(f, s.heatmean, s.heatmean, 10, 0, 1)
                                    #print("Binary search for #samples, mean start={}, result={} with samples={}, original={}".
                                    #      format(s.heatmean, s.heatmean, f(target_mean), f(s.heatmean)))

                                # normalize and upscale importance map
                                postprocess = importance.PostProcess(
                                    s.heatmin, target_mean, importancePostUpscale,
                                    LOSS_BORDER // importancePostUpscale,
                                    'basic')
                                importance_map = postprocess(importance_map)[0].unsqueeze(1)
                                #print("mean:", torch.mean(importance_map).item())

                                # create samples
                                sample_mask = (importance_map >= pattern).to(dtype=importance_map.dtype)
                                reconstruction_input = torch.cat((
                                    sample_mask * crop_high[:, j, 0:5, :, :],  # mask, normal x, normal y, normal z, depth
                                    sample_mask * torch.ones(B, 1, H, W,
                                                             dtype=crop_high.dtype,
                                                             device=crop_high.device),  # ao
                                    sample_mask),  # sample mask
                                    dim=1)

                                # warp previous output
                                if j == 0 or s.recon.disableTemporal:
                                    previous_input = torch.zeros(
                                        B, 6, H, W,
                                        dtype=crop_high.dtype, device=crop_high.device)
                                else:
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_output, flow, 1, False)

                                # run reconstruction network
                                reconstruction = s.recon.call(reconstruction_input, sample_mask, previous_input)

                                # clamp
                                reconstruction_clamped = torch.cat([
                                    torch.clamp(reconstruction[:, 0:1, :, :], -1, +1),  # mask
                                    ScreenSpaceShading.normalize(reconstruction[:, 1:4, :, :], dim=1),
                                    torch.clamp(reconstruction[:, 4:5, :, :], 0, +1),  # depth
                                    torch.clamp(reconstruction[:, 5:6, :, :], 0, +1)  # ao
                                ], dim=1)
                                reconstructions.append(reconstruction_clamped)

                                # save for next frame
                                previous_output = reconstruction_clamped
                            #endfor: timesteps

                            # compute statistics
                            reconstructions = torch.cat(reconstructions, dim=0)
                            crops_high = torch.cat([crop_high[:, j, :6, :, :] for j in range(T)], dim=0)
                            sample_masks = torch.cat([sample_mask] * T, dim=0)
                            s.stats.add_timestep_sample(reconstructions, crops_high, sample_masks)
                        # endfor: statistic
                    # endfor: batch
                    pg.print_progress_bar(num_minibatch)
                # end no_grad()
            finally:
                # close files
                for s in statistics:
                    s.stats.write_histogram()
                    s.stats.close_stats_file()
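# Side note on the running-mean histogram update inside Statistics above: each
# minibatch contributes a density histogram that is folded into an incremental
# mean. A minimal standalone illustration (numpy only; the original code
# additionally divides each per-batch histogram by the batch size B):
import numpy as np

def running_mean_histogram(error_batches, num_bins=200):
    """Incrementally average per-batch error histograms, as Statistics does."""
    histo = np.zeros(num_bins, dtype=np.float64)
    counter = 0
    for errors in error_batches:  # each entry: 1D array of per-pixel errors in [0, 1]
        counter += 1
        h, _ = np.histogram(errors, bins=num_bins, range=(0, 1), density=True)
        histo += (h / num_bins - histo) / counter  # running mean over minibatches
    return histo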
#############################
# MODEL
#############################
# TODO: temporal component
print('===> Building model')
model = importance.NetworkImportanceMap(
    opt.networkUpscale,
    input_channels,
    model=opt.model,
    use_bn=opt.useBN,
    border_padding=opt.border,
    output_layer=opt.outputLayer)
postprocess = importance.PostProcess(
    opt.minImportance, opt.meanImportance,
    opt.postUpscale,
    opt.lossBorderPadding // opt.postUpscale)
model.to(device)
print('Model:')
print(model)
if not no_summary:
    summary(model,
            input_size=train_set.get_low_res_shape(input_channels),
            device=device.type)

#############################
# SHADING
# for now, only used for testing
#############################
shading = ScreenSpaceShading(device)
shading.fov(30)
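# The `opt` namespace above comes from the training script's argument parser,
# which is not part of this excerpt. A minimal sketch of the options this
# snippet consumes (names taken from the attribute accesses above; the types
# and defaults are illustrative assumptions):
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--networkUpscale', type=int, default=4)
parser.add_argument('--model', type=str, default='EnhanceNet')
parser.add_argument('--useBN', action='store_true')
parser.add_argument('--border', type=int, default=8)
parser.add_argument('--outputLayer', type=str, default='softplus')
parser.add_argument('--minImportance', type=float, default=0.01)
parser.add_argument('--meanImportance', type=float, default=0.2)
parser.add_argument('--postUpscale', type=int, default=2)
parser.add_argument('--lossBorderPadding', type=int, default=16)
opt = parser.parse_args()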
def run():
    torch.ops.load_library("./Renderer.dll")

    #########################
    # CONFIGURATION
    #########################

    if 1:
        OUTPUT_FOLDER = "../result-stats/adaptiveDvr3/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-dvr-ejecta6-test.hdf5", "gt-dvr-ejecta6-test-screen8x.hdf5"),
            ("RM", "gt-dvr-rm1-test.hdf5", "gt-dvr-rm1-test-screen8x.hdf5"),
            ("Thorax", "gt-dvr-thorax2-test.hdf5", "gt-dvr-thorax2-test-screen8x.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-dvr-modeldir/"
        NETWORKS = [  # suffixed with _importance.pt and _recon.pt
            # title, file prefix
            #("v5-temp001", "adapDvr5-rgb-temp001-epoch300"),
            #("v5-temp010", "adapDvr5-rgb-temp010-epoch300"),
            #("v5-temp100", "adapDvr5-rgb-temp100-epoch300"),
            #("v5-temp001-perc", "adapDvr5-rgb-temp001-perc01-epoch300"),
            ("v5-perc01+bn", "adapDvr5-rgb-perc01-bn-epoch500"),
            ("v5-perc01-bn", "adapDvr5-rgb-temp001-perc01-epoch500")
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-dvr-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}, has_temporal
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
            ("v6pr2-noTemp", "ejecta6pr2-plastic05-lpips-noTempCon-epoch500_recon.pt", "fast", False),
            ("v6pr2-tl2-100", "ejecta6pr2-plastic05-lpips-tl2-100-epoch500_recon.pt", "fast", True)
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.02, 0.05, 0.1, 0.2]  # [0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 1  # 2

    #########################
    # LOADING
    #########################

    device = torch.device("cuda")

    # Load Networks
    IMPORTANCE_BASELINE1 = "ibase1"
    IMPORTANCE_BASELINE2 = "ibase2"
    RECON_BASELINE = "rbase"

    # load importance model
    print("load importance networks")

    class ImportanceModel:
        def __init__(self, file):
            if file == IMPORTANCE_BASELINE1:
                self._net = importance.UniformImportanceMap(1, 0.5)
                self._upscaling = 1
                self._name = "constant"
                self.disableTemporal = True
                self._requiresPrevious = False
            elif file == IMPORTANCE_BASELINE2:
                self._net = importance.GradientImportanceMap(1, (0, 1), (1, 1), (2, 1))
                self._upscaling = 1
                self._name = "curvature"
                self.disableTemporal = True
                self._requiresPrevious = False
            else:
                self._name = file[0]
                file = os.path.join(NETWORK_DIR, file[1] + "_importance.pt")
                extra_files = torch._C.ExtraFilesMap()
                extra_files['settings.json'] = ""
                self._net = torch.jit.load(file, map_location=device,
                                           _extra_files=extra_files)
                settings = json.loads(extra_files['settings.json'])
                self._upscaling = settings['networkUpscale']
                self._requiresPrevious = settings.get("requiresPrevious", False)
                self.disableTemporal = settings.get("disableTemporal", True)

        def networkUpscaling(self):
            return self._upscaling

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            if self._requiresPrevious:
                input = torch.cat([
                    input,
                    models.VideoTools.flatten_high(prev_warped_out, self._upscaling)
                ], dim=1)
            input = F.pad(input, [IMPORTANCE_BORDER] * 4, 'constant', 0)
            output = self._net(input)  # the network call
            output = F.pad(output, [-IMPORTANCE_BORDER * self._upscaling] * 4, 'constant', 0)
            return output

    importanceBaseline1 = ImportanceModel(IMPORTANCE_BASELINE1)
    importanceBaseline2 = ImportanceModel(IMPORTANCE_BASELINE2)
    importanceModels = [ImportanceModel(f) for f in NETWORKS]

    # load reconstruction networks
    print("load reconstruction networks")

    class ReconstructionModel:
        def __init__(self, file):
            if file == RECON_BASELINE:
                class Inpainting(nn.Module):
                    def forward(self, x, mask):
                        input = x[:, 0:4, :, :].contiguous()  # rgba, don't use normal xyz, depth
                        mask = x[:, 8, :, :].contiguous()
                        return torch.ops.renderer.fast_inpaint(mask, input)
                self._net = Inpainting()
                self._upscaling = 1
                self._name = "inpainting"
                self.disableTemporal = True
            else:
                self._name = file[0]
                file = os.path.join(NETWORK_DIR, file[1] + "_recon.pt")
                extra_files = torch._C.ExtraFilesMap()
                extra_files['settings.json'] = ""
                self._net = torch.jit.load(file, map_location=device,
                                           _extra_files=extra_files)
                self._settings = json.loads(extra_files['settings.json'])
                self.disableTemporal = False
                requiresMask = self._settings.get('expectMask', False)
                if self._settings.get("interpolateInput", False):
                    self._originalNet = self._net

                    class Inpainting2(nn.Module):
                        def __init__(self, originalNet, requiresMask):
                            super().__init__()
                            self._n = originalNet
                            self._requiresMask = requiresMask

                        def forward(self, x, mask):
                            input = x[:, 0:8, :, :].contiguous()  # rgba, normal xyz, depth
                            mask = x[:, 8, :, :].contiguous()
                            inpainted = torch.ops.renderer.fast_inpaint(mask, input)
                            x[:, 0:8, :, :] = inpainted
                            if self._requiresMask:
                                return self._n(x, mask)
                            else:
                                return self._n(x)
                    self._net = Inpainting2(self._originalNet, requiresMask)

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask, prev_warped_out):
            input = torch.cat([input, prev_warped_out], dim=1)
            output = self._net(input, mask)
            return output

    class ReconstructionModelPostTrain:
        """
        Reconstruction models that are trained as dense reconstruction networks
        after the adaptive training.
        They don't receive the sampling mask as input, but can start with
        PDE-based inpainting.
        """

        def __init__(self, name: str, model_path: str, inpainting: str, has_temporal: bool):
            assert inpainting == 'fast' or inpainting == 'pde', \
                "inpainting must be either 'fast' or 'pde', but got %s" % inpainting
            self._inpainting = inpainting
            self._name = name
            file = os.path.join(POSTTRAIN_NETWORK_DIR, model_path)
            extra_files = torch._C.ExtraFilesMap()
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(file, map_location=device,
                                       _extra_files=extra_files)
            self._settings = json.loads(extra_files['settings.json'])
            assert self._settings.get('upscale_factor', None) == 1, \
                "selected file is not a 1x SRNet"
            self.disableTemporal = not has_temporal

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask, prev_warped_out):
            # no sampling and no AO
            input_no_sampling = input[:, 0:8, :, :].contiguous()  # rgba, normal xyz, depth
            sampling_mask = mask[:, 0, :, :].contiguous()
            # perform inpainting
            if self._inpainting == 'pde':
                inpainted = torch.ops.renderer.pde_inpaint(
                    sampling_mask, input_no_sampling,
                    200, 1e-4, 5, 2,  # m0, epsilon, m1, m2
                    0,  # mc -> multigrid recursion count. =0 disables the multigrid hierarchy
                    9, 0)  # ms, m3
            else:
                inpainted = torch.ops.renderer.fast_inpaint(
                    sampling_mask, input_no_sampling)
            # run network
            if self.disableTemporal:
                prev_warped_out = torch.zeros_like(prev_warped_out)
            input = torch.cat([inpainted, prev_warped_out], dim=1)
            output = self._net(input)
            if isinstance(output, tuple):
                output = output[0]
            return output

    reconBaseline = ReconstructionModel(RECON_BASELINE)
    reconModels = [ReconstructionModel(f) for f in NETWORKS]
    reconPostModels = [
        ReconstructionModelPostTrain(name, file, inpainting, has_temporal)
        for (name, file, inpainting, has_temporal) in POSTTRAIN_NETWORKS
    ]
    allReconModels = reconModels + reconPostModels

    NETWORK_COMBINATIONS = \
        [(importanceBaseline1, reconBaseline), (importanceBaseline2, reconBaseline)] + \
        [(importanceBaseline1, reconNet) for reconNet in allReconModels] + \
        [(importanceBaseline2, reconNet) for reconNet in allReconModels] + \
        [(importanceNet, reconBaseline) for importanceNet in importanceModels] + \
        list(zip(importanceModels, reconModels)) + \
        [(importanceNet, reconPostModel)
         for importanceNet in importanceModels
         for reconPostModel in reconPostModels]
    #NETWORK_COMBINATIONS = list(zip(importanceModels, reconModels))
    print("Network combinations:")
    for (i, r) in NETWORK_COMBINATIONS:
        print(" %s - %s" % (i.name(), r.name()))

    # load sampling patterns
    print("load sampling patterns")
    with h5py.File(SAMPLING_FILE, 'r') as f:
        sampling_pattern = dict([
            (name, torch.from_numpy(f[name][...]).to(device))
            for name in SAMPLING_PATTERNS])

    # heatmap
    HEATMAP_CFG = [(min, mean)
                   for min in HEATMAP_MIN
                   for mean in HEATMAP_MEAN
                   if min < mean]
    print("heatmap configs:", HEATMAP_CFG)

    #########################
    # DEFINE STATISTICS
    #########################
    ssimLoss = SSIM(size_average=False)
    ssimLoss.to(device)
    psnrLoss = PSNR()
    psnrLoss.to(device)
    lpipsColor = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    MIN_FILLING = 0.05
    NUM_BINS = 200

    class Statistics:
        def __init__(self):
            self.histogram_color = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_alpha = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_counter = 0

        def create_datasets(self, hdf5_file: h5py.File,
                            stats_name: str, histo_name: str,
                            num_samples: int, extra_info: dict):
            self.stats_name = stats_name
            self.expected_num_samples = num_samples
            stats_shape = (num_samples, len(list(StatFieldDvr)))
            self.stats_file = hdf5_file.require_dataset(
                stats_name, stats_shape, dtype='f', exact=True)
            self.stats_file.attrs['NumFields'] = len(list(StatFieldDvr))
            for field in list(StatFieldDvr):
                self.stats_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.stats_file.attrs[key] = value
            self.stats_index = 0

            histo_shape = (NUM_BINS, len(list(HistoFieldDvr)))
            self.histo_file = hdf5_file.require_dataset(
                histo_name, histo_shape, dtype='f', exact=True)
            self.histo_file.attrs['NumFields'] = len(list(HistoFieldDvr))
            for field in list(HistoFieldDvr):
                self.histo_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.histo_file.attrs[key] = value

        def add_timestep_sample(self, pred_rgba, gt_rgba, sampling_mask):
            """
            adds a timestep sample:
            pred_rgba: prediction rgba
            gt_rgba: ground truth rgba
            """
            B = pred_rgba.shape[0]

            # apply border
            pred_rgba = pred_rgba[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]
            gt_rgba = gt_rgba[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]

            # PSNR
            psnr_color = psnrLoss(pred_rgba[:, 0:3, :, :], gt_rgba[:, 0:3, :, :]).cpu().numpy()
            psnr_alpha = psnrLoss(pred_rgba[:, 3:4, :, :], gt_rgba[:, 3:4, :, :]).cpu().numpy()

            # SSIM
            ssim_color = ssimLoss(pred_rgba[:, 0:3, :, :], gt_rgba[:, 0:3, :, :]).cpu().numpy()
            ssim_alpha = ssimLoss(pred_rgba[:, 3:4, :, :], gt_rgba[:, 3:4, :, :]).cpu().numpy()

            # Perceptual
            lpips_color = torch.cat([
                lpipsColor(pred_rgba[b, 0:3, :, :], gt_rgba[b, 0:3, :, :], normalize=True)
                for b in range(B)
            ], dim=0).cpu().numpy()

            # Samples
            samples = torch.mean(sampling_mask, dim=(1, 2, 3)).cpu().numpy()

            # Write samples to file
            for b in range(B):
                assert self.stats_index < self.expected_num_samples, \
                    "Adding more samples than specified"
                self.stats_file[self.stats_index, :] = np.array([
                    psnr_color[b], psnr_alpha[b],
                    ssim_color[b], ssim_alpha[b],
                    lpips_color[b],
                    samples[b]
                ], dtype='f')
                self.stats_index += 1

            # Histogram
            self.histogram_counter += 1

            color_diff = F.l1_loss(gt_rgba[:, 0:3, :, :], pred_rgba[:, 0:3, :, :],
                                   reduction='none').sum(dim=0) / 6
            histogram, _ = np.histogram(color_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_color += (histogram / (NUM_BINS * B) - self.histogram_color) / self.histogram_counter

            alpha_diff = F.l1_loss(gt_rgba[:, 3, :, :], pred_rgba[:, 3, :, :], reduction='none')
            histogram, _ = np.histogram(alpha_diff.cpu().numpy(), bins=NUM_BINS,
                                        range=(0, 1), density=True)
            self.histogram_alpha += (histogram / (NUM_BINS * B) - self.histogram_alpha) / self.histogram_counter

        def close_stats_file(self):
            self.stats_file.attrs['NumEntries'] = self.stats_index

        def write_histogram(self):
            """
            After every sample for the current dataset was processed,
            write a histogram of the errors into the file.
            """
            for i in range(NUM_BINS):
                self.histo_file[i, :] = np.array([
                    i / NUM_BINS, (i + 1) / NUM_BINS,
                    self.histogram_color[i],
                    self.histogram_alpha[i]
                ])

    #########################
    # DATASET
    #########################
    class FullResDataset(torch.utils.data.Dataset):
        def __init__(self, file_high, file_low):
            self.hdf5_file_high = h5py.File(file_high, 'r')
            self.dset_high = self.hdf5_file_high['gt']
            self.hdf5_file_low = h5py.File(file_low, 'r')
            self.dset_low = self.hdf5_file_low['gt']
            print("Dataset shape:", self.dset_high.shape)

        def __len__(self):
            return self.dset_high.shape[0]

        def num_timesteps(self):
            return self.dset_high.shape[1]

        def __getitem__(self, idx):
            return self.dset_high[idx, ...], self.dset_low[idx, ...]

    #########################
    # COMPUTE STATS for each dataset
    #########################
    for dataset_name, dataset_file_high, dataset_file_low in DATASETS:
        dataset_file_high = os.path.join(DATASET_PREFIX, dataset_file_high)
        dataset_file_low = os.path.join(DATASET_PREFIX, dataset_file_low)
        print("Compute statistics for", dataset_name)

        # create output file
        os.makedirs(OUTPUT_FOLDER, exist_ok=True)
        output_file = os.path.join(OUTPUT_FOLDER, dataset_name + '.hdf5')
        print("Save to", output_file)
        with h5py.File(output_file, 'a') as output_hdf5_file:

            # load dataset
            set = FullResDataset(dataset_file_high, dataset_file_low)
            data_loader = torch.utils.data.DataLoader(set, batch_size=BATCH_SIZE, shuffle=False)

            # define statistics
            StatsCfg = collections.namedtuple(
                "StatsCfg",
                "stats importance recon heatmin heatmean pattern")
            statistics = []
            for (inet, rnet) in NETWORK_COMBINATIONS:
                for (heatmin, heatmean) in HEATMAP_CFG:
                    for pattern in SAMPLING_PATTERNS:
                        stats_info = {
                            'importance': inet.name(),
                            'reconstruction': rnet.name(),
                            'heatmin': heatmin,
                            'heatmean': heatmean,
                            'pattern': pattern
                        }
                        stats_filename = "Stats_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 1000, heatmean * 1000, pattern)
                        histo_filename = "Histogram_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 1000, heatmean * 1000, pattern)
                        s = Statistics()
                        s.create_datasets(
                            output_hdf5_file,
                            stats_filename, histo_filename,
                            len(set) * set.num_timesteps(),
                            stats_info)
                        statistics.append(StatsCfg(
                            stats=s,
                            importance=inet,
                            recon=rnet,
                            heatmin=heatmin,
                            heatmean=heatmean,
                            pattern=pattern))
            print(len(statistics), "different combinations are performed per sample")

            # compute statistics
            try:
                with torch.no_grad():
                    num_minibatch = len(data_loader)
                    pg = ProgressBar(num_minibatch, 'Evaluation', length=50)
                    for iteration, (crop_high, crop_low) in enumerate(data_loader, 0):
                        pg.print_progress_bar(iteration)
                        crop_high = crop_high.to(device)
                        crop_low = crop_low.to(device)
                        B, T, C, H, W = crop_high.shape
                        _, _, _, Hlow, Wlow = crop_low.shape
                        assert Hlow * UPSCALING == H

                        # try out each combination
                        for s in statistics:
                            #print(s)
                            # get input to evaluation
                            importanceNetUpscale = s.importance.networkUpscaling()
                            importancePostUpscale = UPSCALING // importanceNetUpscale
                            pattern = sampling_pattern[s.pattern][:H, :W]

                            # loop over timesteps
                            pattern = pattern.unsqueeze(0).unsqueeze(0)
                            previous_importance = None
                            previous_output = None
                            reconstructions = []
                            for j in range(T):
                                # extract flow (always the last two channels of crop_high)
                                flow = crop_high[:, j, C - 2:, :, :]

                                # compute importance map
                                importance_input = crop_low[:, j, :8, :, :]
                                if j == 0 or s.importance.disableTemporal:
                                    previous_input = torch.zeros(
                                        B, 1,
                                        importance_input.shape[2] * importanceNetUpscale,
                                        importance_input.shape[3] * importanceNetUpscale,
                                        dtype=crop_high.dtype, device=crop_high.device)
                                else:
                                    flow_low = F.interpolate(flow, scale_factor=1 / importancePostUpscale)
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_importance, flow_low, 1, False)
                                importance_map = s.importance.call(importance_input, previous_input)
                                if len(importance_map.shape) == 3:
                                    importance_map = importance_map.unsqueeze(1)
                                previous_importance = importance_map

                                target_mean = s.heatmean
                                if USE_BINARY_SEARCH_ON_MEAN:
                                    # For regular sampling, the normalization does not work properly,
                                    # use binary search on the heatmean instead
                                    def f(x):
                                        postprocess = importance.PostProcess(
                                            s.heatmin, x, importancePostUpscale,
                                            LOSS_BORDER // importancePostUpscale,
                                            'basic')
                                        importance_map2 = postprocess(importance_map)[0].unsqueeze(1)
                                        sampling_mask = (importance_map2 >= pattern).to(dtype=importance_map.dtype)
                                        samples = torch.mean(sampling_mask).item()
                                        return samples
                                    target_mean = binarySearch(f, s.heatmean, s.heatmean, 10, 0, 1)
                                    #print("Binary search for #samples, mean start={}, result={} with samples={}, original={}".
                                    #      format(s.heatmean, s.heatmean, f(target_mean), f(s.heatmean)))

                                # normalize and upscale importance map
                                postprocess = importance.PostProcess(
                                    s.heatmin, target_mean, importancePostUpscale,
                                    LOSS_BORDER // importancePostUpscale,
                                    'basic')
                                importance_map = postprocess(importance_map)[0].unsqueeze(1)
                                #print("mean:", torch.mean(importance_map).item())

                                # create samples
                                sample_mask = (importance_map >= pattern).to(dtype=importance_map.dtype)
                                reconstruction_input = torch.cat((
                                    sample_mask * crop_high[:, j, 0:8, :, :],  # rgba, normal xyz, depth
                                    sample_mask),  # sample mask
                                    dim=1)

                                # warp previous output
                                if j == 0 or s.recon.disableTemporal:
                                    previous_input = torch.zeros(
                                        B, 4, H, W,
                                        dtype=crop_high.dtype, device=crop_high.device)
                                else:
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_output, flow, 1, False)

                                # run reconstruction network
                                reconstruction = s.recon.call(reconstruction_input, sample_mask, previous_input)

                                # clamp
                                reconstruction_clamped = torch.clamp(reconstruction, 0, 1)
                                reconstructions.append(reconstruction_clamped)

                                ## test
                                #if j == 0:
                                #    plt.figure()
                                #    plt.imshow(reconstruction_clamped[0, 0:3, :, :].cpu().numpy().transpose((1, 2, 0)))
                                #    plt.title(s.stats.stats_name)
                                #    plt.show()

                                # save for next frame
                                previous_output = reconstruction_clamped
                            #endfor: timesteps

                            # compute statistics
                            reconstructions = torch.cat(reconstructions, dim=0)
                            crops_high = torch.cat([crop_high[:, j, :8, :, :] for j in range(T)], dim=0)
                            sample_masks = torch.cat([sample_mask] * T, dim=0)
                            s.stats.add_timestep_sample(reconstructions, crops_high, sample_masks)
                        # endfor: statistic
                    # endfor: batch
                    pg.print_progress_bar(num_minibatch)
                # end no_grad()
            finally:
                # close files
                for s in statistics:
                    s.stats.write_histogram()
                    s.stats.close_stats_file()
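# Assumed script entry point; the original file's __main__ guard is not shown.
if __name__ == '__main__':
    run()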