    def forward(self, z, coarse_pr, coarse_uas, coarse_vas, orog, coarse_psl):
        # layer 1
        hidden_state1 = self.layer1(torch.nn.Upsample(scale_factor=3, mode='nearest')(z))
        # layer 2 6x6
        layer2_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state1)]
        if self.use_orog:
            upscale6 = Upscale(size=48, scale_factor=8, device=self.device)
            layer2_input.append(upscale6.upscale(orog))
        layer2_input += [coarse_pr, coarse_uas, coarse_vas, coarse_psl]
        hidden_state2 = self.layer2(torch.cat(layer2_input, 1))
        # layer 3 12x12
        layer3_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state2)]
        if self.use_orog:
            upscale12 = Upscale(size=48, scale_factor=4, device=self.device)
            layer3_input.append(upscale12.upscale(orog))
        upsample12 = torch.nn.Upsample(scale_factor=2, mode='nearest')
        layer3_input += [upsample12(coarse_pr), upsample12(coarse_uas),
                         upsample12(coarse_vas), upsample12(coarse_psl)]
        hidden_state3 = self.layer3(torch.cat(layer3_input, 1))
        # layer 4 24x24
        layer4_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state3)]
        if self.use_orog:
            upscale24 = Upscale(size=48, scale_factor=2, device=self.device)
            layer4_input.append(upscale24.upscale(orog))
        upsample24 = torch.nn.Upsample(scale_factor=4, mode='nearest')
        layer4_input += [upsample24(coarse_pr), upsample24(coarse_uas),
                         upsample24(coarse_vas), upsample24(coarse_psl)]
        hidden_state4 = self.layer4(torch.cat(layer4_input, 1))
        # layer 5 48x48
        layer5_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state4)]
        if self.use_orog:
            layer5_input.append(orog)
        upsample48 = torch.nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
        layer5_input += [upsample48(coarse_pr), upsample48(coarse_uas),
                         upsample48(coarse_vas), upsample48(coarse_psl)]
        hidden_state5 = self.layer5(torch.cat(layer5_input, 1))
        # layer 6
        hidden_state6 = self.layer6(hidden_state5)
        # output layer
        if self.model == 'gamma_vae':
            p = self.p_layer(hidden_state6)
            alpha = self.alpha_layer(hidden_state6)
            beta = self.beta_layer(hidden_state6)
            output = {'p': p, 'alpha': alpha, 'beta': beta}
        elif self.model == 'mse_vae':
            output = self.output_layer(hidden_state6)
        else:
            raise ValueError("model {} is not implemented".format(self.model))
        return output
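# Illustrative sketch (not part of the original module): the forward pass above is the
# decoder half of the SDVAE. It expands a 1x1 latent code step by step (3x3 -> 6x6 ->
# 12x12 -> 24x24 -> 48x48) and re-injects the coarse predictors and the orography at
# every resolution. The shapes below assume the defaults mentioned elsewhere in the
# code (fine_size=32, scale_factor=8, nz=10); `decoder` stands for a constructed
# Decoder instance, which is why the call itself is left commented out.
import torch

batch, nz = 4, 10
z = torch.randn(batch, nz, 1, 1)       # latent sample; layer 1 upsamples it to 3x3
coarse = torch.randn(batch, 1, 6, 6)   # coarse field on the 6x6 grid (48 / scale_factor)
orog = torch.randn(batch, 1, 48, 48)   # orography on the fine 48x48 grid

# out = decoder(z, coarse_pr=coarse, coarse_uas=coarse, coarse_vas=coarse,
#               orog=orog, coarse_psl=coarse)
# For model == 'gamma_vae', `out` is a dict with 'p', 'alpha' and 'beta', each of shape
# [batch, 1, 48, 48]; for model == 'mse_vae' it is a single tensor of that shape.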
def main():
    # set variables, create directories
    opt = BaseOptions().parse()
    n_samples = opt.n_samples
    device = torch.device("cuda" if len(opt.gpu_ids) > 0 else "cpu")
    upscaler = Upscale(size=48, scale_factor=opt.scale_factor, device=device)
    load_root = os.path.join('checkpoints', opt.name)
    load_epoch = opt.load_epoch if opt.load_epoch >= 0 else 'latest'
    load_name = "epoch_{}.pth".format(load_epoch)
    load_dir = os.path.join(load_root, load_name)
    outdir = os.path.join(opt.results_dir, opt.name, 'times_%s_%s' % (opt.phase, load_epoch))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    climate_data = ClimateDataset(opt=opt, phase=opt.phase)
    # large_cell = 48x48, cell = 40x40, small_cell = 32x32
    cell = 40

    # load the model
    model = SDVAE(opt=opt, device=device).to(device)
    model.load_state_dict(torch.load(load_dir, map_location='cpu'))
    model.eval()

    input_dataset = Dataset(os.path.join(opt.dataroot, 'dataset.nc4'), "r", format="NETCDF4")
    basename = "val.nc4"
    print("start validation")
    # create output file
    output_dataset_path = os.path.join(outdir, basename)
    output_dataset = Dataset(output_dataset_path, "w", format="NETCDF4")
    output_dataset.setncatts({k: input_dataset.getncattr(k) for k in input_dataset.ncattrs()})
    # add own metadata
    output_dataset.creators = "Simon Treu (EDGAN, [email protected])\n" + output_dataset.creators
    output_dataset.history = datetime.date.today().isoformat() \
        + " Added Downscaled images in python with pix2pix-edgan\n" + output_dataset.history
    output_dataset.createDimension("time", None)
    output_dataset.createDimension("lon", 720)
    output_dataset.createDimension("lat", 120)
    # Copy variables
    for v_name, varin in input_dataset.variables.items():
        outVar = output_dataset.createVariable(v_name, varin.datatype, varin.dimensions)
        # Copy variable attributes
        outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()})
    # Create variables for the downscaled output
    for k in range(n_samples):
        downscaled_pr = output_dataset.createVariable("downscaled_pr_{}".format(k),
                                                      output_dataset['pr'].datatype,
                                                      output_dataset['pr'].dimensions)
        downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                 for key in output_dataset['pr'].ncattrs()})
        downscaled_pr.standard_name += '_downscaled'
        downscaled_pr.long_name += '_downscaled'
        downscaled_pr.comment = 'downscaled ' + downscaled_pr.comment
        if opt.model == 'gamma_vae':
            # p
            p = output_dataset.createVariable('p_{}'.format(k),
                                              output_dataset['pr'].datatype,
                                              output_dataset['pr'].dimensions)
            p.setncatts({key: output_dataset['pr'].getncattr(key)
                         for key in output_dataset['pr'].ncattrs()})
            p.standard_name += '_p'
            p.long_name += '_p'
            p.comment = 'p ' + p.comment
            # alpha
            alpha = output_dataset.createVariable('alpha_{}'.format(k),
                                                  output_dataset['pr'].datatype,
                                                  output_dataset['pr'].dimensions)
            alpha.setncatts({key: output_dataset['pr'].getncattr(key)
                             for key in output_dataset['pr'].ncattrs()})
            alpha.standard_name += '_alpha'
            alpha.long_name += '_alpha'
            alpha.comment = 'alpha ' + alpha.comment
            # beta
            beta = output_dataset.createVariable('beta_{}'.format(k),
                                                 output_dataset['pr'].datatype,
                                                 output_dataset['pr'].dimensions)
            beta.setncatts({key: output_dataset['pr'].getncattr(key)
                            for key in output_dataset['pr'].ncattrs()})
            beta.standard_name += '_beta'
            beta.long_name += '_beta'
            beta.comment = 'beta ' + beta.comment
            # mean_downscaled_pr_{}
            mean_downscaled_pr = output_dataset.createVariable('mean_downscaled_pr_{}'.format(k),
                                                               output_dataset['pr'].datatype,
                                                               output_dataset['pr'].dimensions)
            mean_downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                          for key in output_dataset['pr'].ncattrs()})
            mean_downscaled_pr.standard_name += '_mean_downscaled_pr'
            mean_downscaled_pr.long_name += '_mean_downscaled_pr'
            mean_downscaled_pr.comment = 'mean_downscaled_pr ' + mean_downscaled_pr.comment

    bilinear_downscaled_pr = output_dataset.createVariable('bilinear_downscaled_pr',
                                                           output_dataset['pr'].datatype,
                                                           output_dataset['pr'].dimensions)
    bilinear_downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                      for key in output_dataset['pr'].ncattrs()})
    bilinear_downscaled_pr.standard_name += '_bilinear_downscaled'
    bilinear_downscaled_pr.long_name += '_bilinear_downscaled'
    bilinear_downscaled_pr.comment = 'bilinear_downscaled ' + bilinear_downscaled_pr.comment

    # set variable values
    output_dataset['lat'][:] = input_dataset['lat'][8:-8]
    output_dataset['lon'][:] = input_dataset['lon'][:]
    output_dataset['time'][:] = input_dataset['time'][climate_data.time_list]
    output_dataset['orog'][:] = input_dataset['orog'][8:-8, :]
    output_dataset['pr'][:] = input_dataset['pr'][climate_data.time_list, 8:-8, :]
    output_dataset['uas'][:] = input_dataset['uas'][climate_data.time_list, 8:-8, :]
    output_dataset['vas'][:] = input_dataset['vas'][climate_data.time_list, 8:-8, :]
    output_dataset['psl'][:] = input_dataset['psl'][climate_data.time_list, 8:-8, :]

    for idx_lat in range(3):
        for idx_lon in range(18):
            # lat with index 0 is 34 N.
            large_cell_lats = [i for i in range(idx_lat * 40 + 4, (idx_lat + 1) * 40 + 12)]
            lats = [i for i in range(idx_lat * 40, (idx_lat + 1) * 40)]
            # longitudes might cross the prime meridian
            large_cell_lons = [i % 720 for i in range(idx_lon * 40 - 4, (idx_lon + 1) * 40 + 4)]
            lons = [i % 720 for i in range(idx_lon * 40, (idx_lon + 1) * 40)]
            pr_tensor = torch.tensor(
                input_dataset['pr'][climate_data.time_list, large_cell_lats, large_cell_lons],
                dtype=torch.float32, device=device).unsqueeze(1)
            orog_tensor = torch.tensor(
                input_dataset['orog'][large_cell_lats, large_cell_lons],
                dtype=torch.float32, device=device).expand(len(climate_data.time_list), 1, 48, 48)
            uas_tensor = torch.tensor(
                input_dataset['uas'][climate_data.time_list, large_cell_lats, large_cell_lons],
                dtype=torch.float32, device=device).unsqueeze(1)
            vas_tensor = torch.tensor(
                input_dataset['vas'][climate_data.time_list, large_cell_lats, large_cell_lons],
                dtype=torch.float32, device=device).unsqueeze(1)
            psl_tensor = torch.tensor(
                input_dataset['psl'][climate_data.time_list, large_cell_lats, large_cell_lons],
                dtype=torch.float32, device=device).unsqueeze(1)
            coarse_pr = upscaler.upscale(pr_tensor)
            coarse_uas = upscaler.upscale(uas_tensor)
            coarse_vas = upscaler.upscale(vas_tensor)
            coarse_psl = upscaler.upscale(psl_tensor)
            for k in range(n_samples):
                with torch.no_grad():
                    recon_pr = model.decode(
                        z=torch.randn(len(climate_data.time_list), opt.nz, 1, 1, device=device),
                        coarse_pr=coarse_pr, coarse_uas=coarse_uas,
                        coarse_vas=coarse_vas, orog=orog_tensor, coarse_psl=coarse_psl)
                if opt.model == "mse_vae":
                    output_dataset['downscaled_pr_{}'.format(k)][:, lats, lons] = recon_pr[:, 0, 4:-4, 4:-4]
                elif opt.model == "gamma_vae":
                    output_dataset['downscaled_pr_{}'.format(k)][:, lats, lons] = \
                        (torch.distributions.bernoulli.Bernoulli(recon_pr['p']).sample()
                         * torch.distributions.gamma.Gamma(recon_pr['alpha'], 1 / recon_pr['beta']).sample())[:, 0, 4:-4, 4:-4]
                    # todo: don't hardcode threshold
                    output_dataset['mean_downscaled_pr_{}'.format(k)][:, lats, lons] = \
                        torch.nn.Threshold(0.035807043601739474, 0)(
                            recon_pr['p'] * recon_pr['alpha'] * recon_pr['beta'])[:, 0, 4:-4, 4:-4]
                    output_dataset['p_{}'.format(k)][:, lats, lons] = recon_pr['p'][:, 0, 4:-4, 4:-4]
                    output_dataset['alpha_{}'.format(k)][:, lats, lons] = recon_pr['alpha'][:, 0, 4:-4, 4:-4]
                    output_dataset['beta_{}'.format(k)][:, lats, lons] = recon_pr['beta'][:, 0, 4:-4, 4:-4]
                else:
                    raise ValueError("model {} is not implemented".format(opt.model))
            bilinear_pr = torch.nn.functional.upsample(coarse_pr, scale_factor=opt.scale_factor,
                                                       mode='bilinear', align_corners=True)
            output_dataset['bilinear_downscaled_pr'][:, lats, lons] = bilinear_pr[:, 0, 4:-4, 4:-4]
            print('Progress = {:>5.1f} %'.format((idx_lat * 18 + idx_lon + 1) * 100 / (3 * 18)))
    output_dataset.close()
    input_dataset.close()
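# Illustrative note (assumption, not part of the script above): torch.nn.functional.upsample
# is deprecated in newer PyTorch releases; torch.nn.functional.interpolate produces the same
# bilinear baseline and could be swapped in.
import torch

coarse = torch.rand(1, 1, 6, 6)
bilinear = torch.nn.functional.interpolate(coarse, scale_factor=8, mode='bilinear',
                                           align_corners=True)
print(bilinear.shape)  # torch.Size([1, 1, 48, 48])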
def main():
    # set variables, create directories
    opt = BaseOptions().parse()
    n_samples = opt.n_samples
    device = torch.device("cuda" if len(opt.gpu_ids) > 0 else "cpu")
    upscaler = Upscale(size=48, scale_factor=opt.scale_factor, device=device)
    load_root = os.path.join('checkpoints', opt.name)
    load_epoch = opt.load_epoch if opt.load_epoch >= 0 else 'latest'
    load_name = "epoch_{}.pth".format(load_epoch)
    load_dir = os.path.join(load_root, load_name)
    outdir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, load_epoch))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    climate_data = ClimateDataset(opt=opt, phase=opt.phase)
    # large_cell = 48x48, cell = 40x40, small_cell = 32x32
    large_cell = opt.fine_size + 2 * opt.scale_factor

    # load the model
    model = SDVAE(opt=opt, device=device).to(device)
    model.load_state_dict(torch.load(load_dir, map_location='cpu'))
    model.eval()

    # Iterate over the validation cells and compute n_samples reconstructions for each.
    index = 0
    # read input file
    input_dataset = Dataset(os.path.join(opt.dataroot, 'dataset.nc4'), "r", format="NETCDF4")
    for idx_lat in range(climate_data.rows):
        for idx_lon in climate_data.lat_lon_list[idx_lat]:
            # calculate upper left index for cell with boundary values to downscale
            anchor_lat = idx_lat * climate_data.cell_size + climate_data.scale_factor
            anchor_lon = idx_lon * climate_data.cell_size
            # select indices for a 48 x 48 box around the 32 x 32 box to be downscaled (with boundary values)
            large_cell_lats = [i for i in range(
                anchor_lat - climate_data.scale_factor,
                anchor_lat + climate_data.fine_size + climate_data.scale_factor)]
            # longitudes might cross the prime meridian
            large_cell_lons = [i % 720 for i in range(
                anchor_lon - climate_data.scale_factor,
                anchor_lon + climate_data.fine_size + climate_data.scale_factor)]
            # create output path
            basename = "val.lat{}_lon{}.nc4".format(anchor_lat, anchor_lon)
            print("test file nr. {} name: {}".format(index, basename))
            # create output file
            output_dataset_path = os.path.join(outdir, basename)
            output_dataset = Dataset(output_dataset_path, "w", format="NETCDF4")
            output_dataset.setncatts({k: input_dataset.getncattr(k)
                                      for k in input_dataset.ncattrs()})
            # add own metadata
            output_dataset.creators = "Simon Treu (EDGAN, [email protected])\n" + output_dataset.creators
            output_dataset.history = datetime.date.today().isoformat() \
                + " Added Downscaled images in python with pix2pix-edgan\n" + output_dataset.history
            output_dataset.createDimension("time", None)
            output_dataset.createDimension("lon", large_cell)
            output_dataset.createDimension("lat", large_cell)
            # Copy variables
            for v_name, varin in input_dataset.variables.items():
                outVar = output_dataset.createVariable(v_name, varin.datatype, varin.dimensions)
                # Copy variable attributes
                outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()})
            # Create variables for the downscaled output
            for k in range(n_samples):
                downscaled_pr = output_dataset.createVariable("downscaled_pr_{}".format(k),
                                                              output_dataset['pr'].datatype,
                                                              output_dataset['pr'].dimensions)
                downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                         for key in output_dataset['pr'].ncattrs()})
                downscaled_pr.standard_name += '_downscaled'
                downscaled_pr.long_name += '_downscaled'
                downscaled_pr.comment = 'downscaled ' + downscaled_pr.comment
                if opt.model == 'gamma_vae':
                    # p
                    p = output_dataset.createVariable('p_{}'.format(k),
                                                      output_dataset['pr'].datatype,
                                                      output_dataset['pr'].dimensions)
                    p.setncatts({key: output_dataset['pr'].getncattr(key)
                                 for key in output_dataset['pr'].ncattrs()})
                    p.standard_name += '_p'
                    p.long_name += '_p'
                    p.comment = 'p ' + p.comment
                    # alpha
                    alpha = output_dataset.createVariable('alpha_{}'.format(k),
                                                          output_dataset['pr'].datatype,
                                                          output_dataset['pr'].dimensions)
                    alpha.setncatts({key: output_dataset['pr'].getncattr(key)
                                     for key in output_dataset['pr'].ncattrs()})
                    alpha.standard_name += '_alpha'
                    alpha.long_name += '_alpha'
                    alpha.comment = 'alpha ' + alpha.comment
                    # beta
                    beta = output_dataset.createVariable('beta_{}'.format(k),
                                                         output_dataset['pr'].datatype,
                                                         output_dataset['pr'].dimensions)
                    beta.setncatts({key: output_dataset['pr'].getncattr(key)
                                    for key in output_dataset['pr'].ncattrs()})
                    beta.standard_name += '_beta'
                    beta.long_name += '_beta'
                    beta.comment = 'beta ' + beta.comment
                    # mean_downscaled_pr_{}
                    mean_downscaled_pr = output_dataset.createVariable('mean_downscaled_pr_{}'.format(k),
                                                                       output_dataset['pr'].datatype,
                                                                       output_dataset['pr'].dimensions)
                    mean_downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                                  for key in output_dataset['pr'].ncattrs()})
                    mean_downscaled_pr.standard_name += '_mean_downscaled_pr'
                    mean_downscaled_pr.long_name += '_mean_downscaled_pr'
                    mean_downscaled_pr.comment = 'mean_downscaled_pr ' + mean_downscaled_pr.comment

            bilinear_downscaled_pr = output_dataset.createVariable('bilinear_downscaled_pr',
                                                                   output_dataset['pr'].datatype,
                                                                   output_dataset['pr'].dimensions)
            bilinear_downscaled_pr.setncatts({key: output_dataset['pr'].getncattr(key)
                                              for key in output_dataset['pr'].ncattrs()})
            bilinear_downscaled_pr.standard_name += '_bilinear_downscaled'
            bilinear_downscaled_pr.long_name += '_bilinear_downscaled'
            bilinear_downscaled_pr.comment = 'bilinear_downscaled ' + bilinear_downscaled_pr.comment

            # set variable values
            output_dataset['lat'][:] = input_dataset['lat'][large_cell_lats]
            output_dataset['lon'][:] = input_dataset['lon'][large_cell_lons]
            output_dataset['time'][:] = input_dataset['time'][:]
            output_dataset['orog'][:] = input_dataset['orog'][large_cell_lats, large_cell_lons]
            output_dataset['pr'][:] = input_dataset['pr'][:, large_cell_lats, large_cell_lons]
            for k in range(n_samples):
                output_dataset['downscaled_pr_{}'.format(k)][:] = \
                    input_dataset['pr'][:, large_cell_lats, large_cell_lons]
            output_dataset['bilinear_downscaled_pr'][:] = \
                input_dataset['pr'][:, large_cell_lats, large_cell_lons]
            output_dataset['uas'][:] = input_dataset['uas'][:, large_cell_lats, large_cell_lons]
            output_dataset['vas'][:] = input_dataset['vas'][:, large_cell_lats, large_cell_lons]
            output_dataset['psl'][:] = input_dataset['psl'][:, large_cell_lats, large_cell_lons]

            # read out the variables similar to construct_datasets.py
            pr = output_dataset['pr'][:]
            uas = output_dataset['uas'][:]
            vas = output_dataset['vas'][:]
            psl = output_dataset['psl'][:]
            orog = output_dataset['orog'][:]
            times = pr.shape[0]
            for t in range(times):
                pr_tensor = torch.tensor(pr[t, :, :], dtype=torch.float32, device=device)
                orog_tensor = torch.tensor(orog[:, :], dtype=torch.float32,
                                           device=device).unsqueeze(0).unsqueeze(0)
                uas_tensor = torch.tensor(uas[t, :, :], dtype=torch.float32, device=device)
                vas_tensor = torch.tensor(vas[t, :, :], dtype=torch.float32, device=device)
                psl_tensor = torch.tensor(psl[t, :, :], dtype=torch.float32, device=device)
                coarse_pr = upscaler.upscale(pr_tensor).unsqueeze(0).unsqueeze(0)
                coarse_uas = upscaler.upscale(uas_tensor).unsqueeze(0).unsqueeze(0)
                coarse_vas = upscaler.upscale(vas_tensor).unsqueeze(0).unsqueeze(0)
                coarse_psl = upscaler.upscale(psl_tensor).unsqueeze(0).unsqueeze(0)
                for k in range(n_samples):
                    with torch.no_grad():
                        recon_pr = model.decode(z=torch.randn(1, opt.nz, 1, 1, device=device),
                                                coarse_pr=coarse_pr, coarse_uas=coarse_uas,
                                                coarse_vas=coarse_vas, orog=orog_tensor,
                                                coarse_psl=coarse_psl)
                    if opt.model == "mse_vae":
                        output_dataset['downscaled_pr_{}'.format(k)][t, :, :] = recon_pr
                    elif opt.model == "gamma_vae":
                        output_dataset['downscaled_pr_{}'.format(k)][t, :, :] = \
                            (torch.distributions.bernoulli.Bernoulli(recon_pr['p']).sample()
                             * torch.distributions.gamma.Gamma(recon_pr['alpha'], 1 / recon_pr['beta']).sample())[0, 0, :, :]
                        # todo: don't hardcode threshold
                        output_dataset['mean_downscaled_pr_{}'.format(k)][t, :, :] = \
                            torch.nn.Threshold(0.035807043601739474, 0)(
                                recon_pr['p'] * recon_pr['alpha'] * recon_pr['beta'])
                        output_dataset['p_{}'.format(k)][t, :, :] = recon_pr['p']
                        output_dataset['alpha_{}'.format(k)][t, :, :] = recon_pr['alpha']
                        output_dataset['beta_{}'.format(k)][t, :, :] = recon_pr['beta']
                    else:
                        raise ValueError("model {} is not implemented".format(opt.model))
                bilinear_pr = torch.nn.functional.upsample(coarse_pr, scale_factor=opt.scale_factor,
                                                           mode='bilinear', align_corners=True)
                output_dataset['bilinear_downscaled_pr'][t, :, :] = bilinear_pr[0, 0, :, :]
            # read out the variables
            # create reanalysis result (cropped to 32*32)
            # compute downscaled images
            # save them in result dataset
            output_dataset.close()
            index += 1
    input_dataset.close()
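# Illustrative sketch (not part of the script above): one of the per-cell validation files
# written by main() can be inspected like this. The experiment name, phase/epoch folder and
# cell coordinates in the path are placeholders.
import os
from netCDF4 import Dataset

path = os.path.join("results", "my_experiment", "val_latest", "val.lat8_lon0.nc4")
with Dataset(path, "r") as ds:
    pr = ds["pr"][:]                    # reanalysis precipitation, shape (time, 48, 48)
    sample0 = ds["downscaled_pr_0"][:]  # first stochastic reconstruction
    bilinear = ds["bilinear_downscaled_pr"][:]
    print(pr.shape, sample0.shape, bilinear.shape)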
class SDVAE(nn.Module):
    def __init__(self, opt, device):
        super(SDVAE, self).__init__()
        # variables
        self.nz = opt.nz
        self.input_size = opt.fine_size ** 2
        self.upscaler = Upscale(size=48, scale_factor=8, device=device)
        self.use_orog = not opt.no_orog
        self.no_dropout = opt.no_dropout
        self.nf_encoder = opt.nf_encoder
        self.model = opt.model
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.device = device
        self.regression = opt.regression

        # dimensions for batch_size=64, nf_encoder=16, fine_size=32, nz=10, orog=True
        # 64x5x48x48
        self.h_layer1 = self._down_conv(in_channels=1 + self.use_orog + 4,
                                        out_channels=self.nf_encoder,
                                        kernel_size=4, padding=1, stride=2)
        # 64x20x24x24
        self.h_layer2 = self._down_conv(in_channels=self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 2,
                                        kernel_size=4, padding=1, stride=2)
        # 64x20x12x12
        self.h_layer3 = self._down_conv(in_channels=self.nf_encoder * 2 + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 4,
                                        kernel_size=4, padding=1, stride=2)
        # 64x35x6x6
        self.h_layer4 = self._down_conv(in_channels=4 * self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 8,
                                        kernel_size=4, padding=1, stride=2)
        # 64x48x3x3
        # mu
        self.mu = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1
        # log_var
        self.log_var = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1
        self.decode = Decoder(opt, device)

    def forward(self, fine_pr, coarse_pr, orog, coarse_uas, coarse_vas, coarse_psl):
        if not self.regression:
            # layer 1
            input_layer_input = [fine_pr]
            if self.use_orog:
                input_layer_input.append(orog)
            upsample48 = torch.nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
            input_layer_input += [upsample48(coarse_pr), upsample48(coarse_uas),
                                  upsample48(coarse_vas), upsample48(coarse_psl)]
            h_layer1 = self.h_layer1(torch.cat(input_layer_input, 1))
            # layer 2
            layer2_input = [h_layer1]
            if self.use_orog:
                upscale24 = Upscale(size=48, scale_factor=2, device=self.device)
                layer2_input.append(upscale24.upscale(orog))
            upsample24 = torch.nn.Upsample(scale_factor=4, mode='nearest')
            layer2_input += [upsample24(coarse_pr), upsample24(coarse_uas),
                             upsample24(coarse_vas), upsample24(coarse_psl)]
            h_layer2 = self.h_layer2(torch.cat(layer2_input, 1))
            # layer 3
            layer3_input = [h_layer2]
            if self.use_orog:
                upscale12 = Upscale(size=48, scale_factor=4, device=self.device)
                layer3_input.append(upscale12.upscale(orog))
            upsample12 = torch.nn.Upsample(scale_factor=2, mode='nearest')
            layer3_input += [upsample12(coarse_pr), upsample12(coarse_uas),
                             upsample12(coarse_vas), upsample12(coarse_psl)]
            h_layer3 = self.h_layer3(torch.cat(layer3_input, 1))
            # layer 4
            layer4_input = [h_layer3]
            if self.use_orog:
                upscale6 = Upscale(size=48, scale_factor=8, device=self.device)
                layer4_input.append(upscale6.upscale(orog))
            layer4_input += [coarse_pr, coarse_uas, coarse_vas, coarse_psl]
            h_layer4 = self.h_layer4(torch.cat(layer4_input, 1))
            # output layer
            mu = self.mu(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
            log_var = self.log_var(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
        else:
            mu = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
            log_var = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
        # reparameterization
        z = self._reparameterize(mu, log_var)
        # decode
        recon_pr = self.decode(z=z, coarse_pr=coarse_pr, orog=orog, coarse_uas=coarse_uas,
                               coarse_vas=coarse_vas, coarse_psl=coarse_psl)
        return recon_pr, mu.view(-1, self.nz), log_var.view(-1, self.nz)

    def loss_function(self, recon_x, x, mu, log_var, coarse_pr):
        # negative log predictive density
        if self.model == 'gamma_vae':
            nlpd = self._neg_log_gamma_likelihood(x, recon_x['alpha'], recon_x['beta'], recon_x['p'])
        elif self.model == 'mse_vae':
            nlpd = nn.functional.mse_loss(recon_x, x, size_average=False) / x.shape[0]
        # Kullback-Leibler divergence
        kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), 1))
        loss = kld + nlpd
        # cycle loss
        # not added to the total loss
        if self.model == 'gamma_vae':
            coarse_recon = self.upscaler.upscale(recon_x['p'] * recon_x['alpha'] * recon_x['beta'])
        elif self.model == 'mse_vae':
            coarse_recon = self.upscaler.upscale(recon_x)
        cycle_loss = nn.functional.mse_loss(coarse_pr, coarse_recon, size_average=True)
        return nlpd, kld, cycle_loss, loss

    def _reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _neg_log_gamma_likelihood(self, x, alpha, beta, p):
        result = torch.zeros(1, device=x.device)  # keep on the same device as x
        if (x > 0).any():
            result = - torch.sum(torch.log(p[x > 0])
                                 + (alpha[x > 0] - 1) * torch.log(x[x > 0])
                                 - alpha[x > 0] * torch.log(beta[x > 0])
                                 - x[x > 0] / beta[x > 0]
                                 - torch.lgamma(alpha[x > 0]))
        if (x == 0).any():
            result -= torch.sum(torch.log(1 - p[x == 0]) + 0 * alpha[x == 0] + 0 * beta[x == 0])
        return result / x.shape[0]  # mean over batch size

    def _down_conv(self, in_channels, out_channels, kernel_size, padding, stride):
        if self.no_dropout:
            return nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, padding=padding, stride=stride),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU())
        else:
            return nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, padding=padding, stride=stride),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(),
                                 nn.Dropout())
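# Illustrative sketch (not part of the module): the reparameterisation trick and the KL term
# used by SDVAE above, written out on dummy tensors so the formula
#   KL = mean_batch( -0.5 * sum_z( 1 + log_var - mu^2 - exp(log_var) ) )
# is explicit. The batch size and nz=10 are placeholders.
import torch

mu = torch.zeros(4, 10)        # [batch, nz]
log_var = torch.zeros(4, 10)

std = torch.exp(0.5 * log_var)
z = mu + torch.randn_like(std) * std   # equivalent to eps.mul(std).add_(mu) above

kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), 1))
print(z.shape, float(kld))             # kld == 0 for a standard-normal posterior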
class ClimateDataset(Dataset):
    def __init__(self, opt, phase):
        self.root = opt.dataroot
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.n_test = opt.n_test
        self.n_val = opt.n_val
        # cell_size is the size of one cell that is extracted from the netcdf file.
        # Afterwards this is cropped to fine_size.
        self.cell_size = opt.fine_size + self.scale_factor
        with Nc4Dataset(os.path.join(self.root, "dataset.nc4"), "r", format="NETCDF4") as file:
            check_dimensions(file, self.cell_size, self.scale_factor)
            cols = file['lon'].size // self.cell_size
            self.rows = file['lat'].size // self.cell_size
        self.upscaler = Upscale(size=self.fine_size + 2 * self.scale_factor,
                                scale_factor=self.scale_factor)
        # remove n_test + n_val 40x40 boxes from each block of 40 rows to create the training set
        self.lat_lon_list = create_lat_lon_indices(self.rows, cols, self.n_test, self.n_val,
                                                   seed=opt.seed, phase=phase)
        # remove 3 years for test and 3 years for val
        self.time_list = create_time_list(seed=opt.seed, phase=phase)
        t = len(self.time_list)
        if phase == 'train':
            self.length = self.rows * (cols - self.n_test - self.n_val) * t
        elif phase == 'val':
            self.length = self.rows * self.n_val * t
        elif phase == 'test':
            self.length = self.rows * self.n_test * t

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        start_time = time.time()  # todo remove timing
        # ++ calculate lat lon and time from index ++ #
        s_lat = len(self.lat_lon_list)
        s_lon = len(self.lat_lon_list[0])
        t_idx = index // (s_lat * s_lon)
        lat = index % (s_lat * s_lon) // s_lon
        lon = index % s_lon
        # ----------------------------------------------------------------------------------------
        # calculate a random offset from the upper left corner to crop the box
        w_offset = random.randint(0, self.scale_factor - 1)
        h_offset = random.randint(0, self.scale_factor - 1)
        # ----------------------------------------------------------------------------------------
        # calculate the lat and lon indices of the upper left corner in the netcdf file
        # add scale factor because first 8 pixels are only for boundary conditions
        # --> for lat=0 the index in the netcdf file is 8.
        anchor_lat = lat * self.cell_size + w_offset + self.scale_factor
        anchor_lon = self.lat_lon_list[lat][lon] * self.cell_size + h_offset
        # select indices for a 48 x 48 box around the 32 x 32 box to be downscaled (with boundary values)
        boundary_lats = [i for i in range(anchor_lat - self.scale_factor,
                                          anchor_lat + self.fine_size + self.scale_factor)]
        # longitudes might cross the prime meridian
        boundary_lons = [i % 720 for i in range(anchor_lon - self.scale_factor,
                                                anchor_lon + self.fine_size + self.scale_factor)]
        t = self.time_list[t_idx]
        # ----------------------------------------------------------------------------------------
        # ++ read data ++ #
        with Nc4Dataset(os.path.join(self.root, "dataset.nc4"), "r", format="NETCDF4") as file:
            pr = torch.tensor(file['pr'][t, boundary_lats, boundary_lons], dtype=torch.float)
            orog = torch.tensor(file['orog'][boundary_lats, boundary_lons], dtype=torch.float)
            uas = torch.tensor(file['uas'][t, boundary_lats, boundary_lons], dtype=torch.float)
            vas = torch.tensor(file['vas'][t, boundary_lats, boundary_lons], dtype=torch.float)
            psl = torch.tensor(file['psl'][t, boundary_lats, boundary_lons], dtype=torch.float)
        # ----------------------------------------------------------------------------------------
        coarse_pr = self.upscaler.upscale(pr)
        coarse_uas = self.upscaler.upscale(uas)
        coarse_vas = self.upscaler.upscale(vas)
        coarse_psl = self.upscaler.upscale(psl)
        # bring all into shape [C, W, H] (channels, width, height)
        pr.unsqueeze_(0)
        orog.unsqueeze_(0)
        coarse_pr.unsqueeze_(0)
        coarse_uas.unsqueeze_(0)
        coarse_vas.unsqueeze_(0)
        coarse_psl.unsqueeze_(0)
        end_time = time.time() - start_time
        return {'fine_pr': pr, 'coarse_pr': coarse_pr, 'orog': orog,
                'coarse_uas': coarse_uas, 'coarse_vas': coarse_vas,
                'coarse_psl': coarse_psl, 'time': end_time}
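# Illustrative sketch (not part of the module): how the dataset is typically consumed with a
# PyTorch DataLoader. It assumes dataset.nc4 exists under opt.dataroot and that the helpers
# used by this file (create_lat_lon_indices, create_time_list, Upscale, check_dimensions) are
# importable; the option values are placeholders, not the project's defaults.
import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader

opt = SimpleNamespace(dataroot='datasets/era5', scale_factor=8, fine_size=32,
                      n_test=4, n_val=4, seed=1234)

train_set = ClimateDataset(opt, phase='train')
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

batch = next(iter(loader))
# Shapes per __getitem__ above: fine_pr and orog are [B, 1, 48, 48],
# the coarse_* fields are [B, 1, 6, 6] for scale_factor=8.
print(batch['fine_pr'].shape, batch['coarse_pr'].shape)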