Example #1
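Constructor of a climate dataset (evidently the `__init__` of the ClimateDataset class shown in full in Example #8): it reads the grid dimensions from a netCDF file, splits grid cells and time steps into train/val/test, and computes the dataset length for the requested phase.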
    def __init__(self, opt, phase):
        self.root = opt.dataroot
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.n_test = opt.n_test
        self.n_val = opt.n_val
        # cell_size is the edge length of one cell extracted from the netCDF file;
        # it is later cropped to fine_size
        self.cell_size = opt.fine_size + self.scale_factor
        with Nc4Dataset(os.path.join(self.root, "dataset.nc4"), "r", format="NETCDF4") as file:
            check_dimensions(file, self.cell_size, self.scale_factor)
            cols = file['lon'].size//self.cell_size
            self.rows = file['lat'].size//self.cell_size

        self.upscaler = Upscale(size=self.fine_size+2*self.scale_factor, scale_factor=self.scale_factor)

        # remove n_test + n_val 40x40 boxes from each block of 40 rows to create the training set
        self.lat_lon_list = create_lat_lon_indices(self.rows, cols, self.n_test, self.n_val,
                                                   seed=opt.seed, phase=phase)
        # remove 3 years for test and 3 years for val
        self.time_list = create_time_list(seed=opt.seed, phase=phase)

        t = len(self.time_list)
        if phase == 'train':
            self.length = self.rows * (cols - self.n_test - self.n_val) * t
        elif phase == 'val':
            self.length = self.rows * self.n_val * t
        elif phase == 'test':
            self.length = self.rows * self.n_test * t
Example #2
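Decoder forward pass of the VAE: starting from the latent code z, each layer doubles the spatial resolution via nearest-neighbour upsampling and concatenates the coarse predictors (precipitation, winds, sea-level pressure) plus, optionally, orography rescaled to the matching grid; the head emits Bernoulli-Gamma parameters (gamma_vae) or a plain field (mse_vae).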
    def forward(self, z, coarse_pr,
                coarse_uas, coarse_vas, orog, coarse_psl):
        # layer 1
        hidden_state1 = self.layer1(torch.nn.Upsample(scale_factor=3, mode='nearest')(z))

        # layer 2 6x6
        layer2_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state1)]
        if self.use_orog:
            upscale6 = Upscale(size=48, scale_factor=8, device=self.device)
            layer2_input.append(upscale6.upscale(orog))
        layer2_input += [coarse_pr, coarse_uas, coarse_vas, coarse_psl]
        hidden_state2 = self.layer2(torch.cat(layer2_input, 1))

        # layer 3 12x12
        layer3_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state2)]
        if self.use_orog:
            upscale12 = Upscale(size=48, scale_factor=4, device=self.device)
            layer3_input.append(upscale12.upscale(orog))
        upsample12 = torch.nn.Upsample(scale_factor=2, mode='nearest')
        layer3_input += [upsample12(coarse_pr), upsample12(coarse_uas),
                         upsample12(coarse_vas), upsample12(coarse_psl)]
        hidden_state3 = self.layer3(torch.cat(layer3_input, 1))

        # layer 4 24x24
        layer4_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state3)]
        if self.use_orog:
            upscale24 = Upscale(size=48, scale_factor=2, device=self.device)
            layer4_input.append(upscale24.upscale(orog))
        upsample24 = torch.nn.Upsample(scale_factor=4, mode='nearest')
        layer4_input += [upsample24(coarse_pr), upsample24(coarse_uas),
                         upsample24(coarse_vas), upsample24(coarse_psl)]
        hidden_state4 = self.layer4(torch.cat(layer4_input, 1))

        # layer 5 48x48
        layer5_input = [torch.nn.Upsample(scale_factor=2, mode='nearest')(hidden_state4)]
        if self.use_orog:
            layer5_input.append(orog)
        upsample48 = torch.nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
        layer5_input += [upsample48(coarse_pr), upsample48(coarse_uas),
                         upsample48(coarse_vas), upsample48(coarse_psl)]
        hidden_state5 = self.layer5(torch.cat(layer5_input, 1))

        # layer 6
        hidden_state6 = self.layer6(hidden_state5)

        # output layer
        if self.model == 'gamma_vae':
            p = self.p_layer(hidden_state6)
            alpha = self.alpha_layer(hidden_state6)
            beta = self.beta_layer(hidden_state6)
            output = {'p': p, 'alpha': alpha, 'beta': beta}
        elif self.model == 'mse_vae':
            output = self.output_layer(hidden_state6)
        else:
            raise ValueError("model {} is not implemented".format(self.model))
        return output
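
A minimal sketch (dummy tensors, not the model itself) of the nearest-neighbour upsampling pyramid this forward pass is built around: the 1x1 latent grid grows to 3x3 in layer 1 and then doubles through layers 2-5, ending at the 48x48 fine grid.

import torch

z = torch.randn(2, 10, 1, 1)  # [batch, nz, 1, 1], as fed to the decoder
x = torch.nn.Upsample(scale_factor=3, mode='nearest')(z)  # 1x1 -> 3x3 (layer 1)
for _ in range(4):  # layers 2-5 each double the grid: 6, 12, 24, 48
    x = torch.nn.Upsample(scale_factor=2, mode='nearest')(x)
print(x.shape)  # torch.Size([2, 10, 48, 48])
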
Example #3
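Encoder and reparameterization forward pass of the VAE (the same method appears inside the full SDVAE class in Example #7): unless running in regression mode, four strided convolution blocks produce mu and log_var, a latent z is drawn via the reparameterization trick, and the decoder reconstructs the fine-scale precipitation field.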
    def forward(self, fine_pr, coarse_pr, orog, coarse_uas, coarse_vas, coarse_psl):
        if not self.regression:
            # layer 1
            input_layer_input = [fine_pr]
            if self.use_orog:
                input_layer_input.append(orog)

            upsample48 = torch.nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
            input_layer_input += [upsample48(coarse_pr), upsample48(coarse_uas),
                                  upsample48(coarse_vas), upsample48(coarse_psl)]
            h_layer1 = self.h_layer1(torch.cat(input_layer_input, 1))

            # layer 2
            layer2_input = [h_layer1]
            if self.use_orog:
                upscale24 = Upscale(size=48, scale_factor=2, device=self.device)
                layer2_input.append(upscale24.upscale(orog))
            upsample24 = torch.nn.Upsample(scale_factor=4, mode='nearest')
            layer2_input += [upsample24(coarse_pr), upsample24(coarse_uas),
                             upsample24(coarse_vas), upsample24(coarse_psl)]
            h_layer2 = self.h_layer2(torch.cat(layer2_input, 1))

            # layer 3
            layer3_input = [h_layer2]
            if self.use_orog:
                upscale12 = Upscale(size=48, scale_factor=4, device=self.device)
                layer3_input.append(upscale12.upscale(orog))
            upsample12 = torch.nn.Upsample(scale_factor=2, mode='nearest')
            layer3_input += [upsample12(coarse_pr), upsample12(coarse_uas),
                             upsample12(coarse_vas), upsample12(coarse_psl)]
            h_layer3 = self.h_layer3(torch.cat(layer3_input, 1))

            # layer 4
            layer4_input = [h_layer3]
            if self.use_orog:
                upscale6 = Upscale(size=48, scale_factor=8, device=self.device)
                layer4_input.append(upscale6.upscale(orog))
            layer4_input += [coarse_pr, coarse_uas, coarse_vas, coarse_psl]
            h_layer4 = self.h_layer4(torch.cat(layer4_input, 1))

            # output layer
            mu = self.mu(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
            log_var = self.log_var(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
        else:
            mu = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
            log_var = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
        # reparameterization
        z = self._reparameterize(mu, log_var)

        # decode
        recon_pr = self.decode(z=z, coarse_pr=coarse_pr, orog=orog,
                               coarse_uas=coarse_uas, coarse_vas=coarse_vas, coarse_psl=coarse_psl)
        return recon_pr, mu.view(-1, self.nz), log_var.view(-1, self.nz)
Example #4
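SDVAE constructor (duplicated inside the full class in Example #7): four strided convolution blocks for the encoder, linear heads for mu and log_var, and the decoder module.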
    def __init__(self, opt, device):
        super(SDVAE, self).__init__()
        # variables
        self.nz = opt.nz
        self.input_size = opt.fine_size ** 2
        self.upscaler = Upscale(size=48, scale_factor=8, device=device)
        self.use_orog = not opt.no_orog
        self.no_dropout = opt.no_dropout
        self.nf_encoder = opt.nf_encoder
        self.model = opt.model
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.device = device
        self.regression = opt.regression


        # dimensions below assume batch_size=64, nf_encoder=16, fine_size=32, nz=10, orog=True
        # input: 64x6x48x48
        self.h_layer1 = self._down_conv(in_channels=1 + self.use_orog + 4,
                                        out_channels=self.nf_encoder,
                                        kernel_size=4, padding=1, stride=2)
        # 64x21x24x24
        self.h_layer2 = self._down_conv(in_channels=self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 2,
                                        kernel_size=4, padding=1, stride=2)
        # 64x37x12x12
        self.h_layer3 = self._down_conv(in_channels=self.nf_encoder*2 + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 4,
                                        kernel_size=4, padding=1, stride=2)
        # 64x69x6x6
        self.h_layer4 = self._down_conv(in_channels=4 * self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 8,
                                        kernel_size=4, padding=1, stride=2)
        # 64x128x3x3

        # mu
        self.mu = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1

        # log_var
        self.log_var = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1

        self.decode = Decoder(opt, device)
Example #5
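Inference script that downscales entire 40x40 tiles of the domain: it loads a trained SDVAE checkpoint, copies metadata and variables from the input netCDF file into one output file, and for each tile draws n_samples decoder reconstructions plus a bilinear-interpolation baseline.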
def main():
    # set variables, create directories
    opt = BaseOptions().parse()
    n_samples = opt.n_samples
    device = torch.device("cuda" if len(opt.gpu_ids) > 0 else "cpu")
    upscaler = Upscale(size=48, scale_factor=opt.scale_factor, device=device)
    load_root = os.path.join('checkpoints', opt.name)
    load_epoch = opt.load_epoch if opt.load_epoch >= 0 else 'latest'
    load_name = "epoch_{}.pth".format(load_epoch)
    load_dir = os.path.join(load_root, load_name)
    outdir = os.path.join(opt.results_dir, opt.name,
                          'times_%s_%s' % (opt.phase, load_epoch))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    climate_data = ClimateDataset(opt=opt, phase=opt.phase)

    # large_cell = 48x48, cell = 40x40, small_cell = 32x32
    cell = 40

    # load the model
    model = SDVAE(opt=opt, device=device).to(device)
    model.load_state_dict(torch.load(load_dir, map_location='cpu'))
    model.eval()

    input_dataset = Dataset(os.path.join(opt.dataroot, 'dataset.nc4'),
                            "r",
                            format="NETCDF4")

    basename = "val.nc4"
    print("start validation")

    # create output file
    output_dataset_path = os.path.join(outdir, basename)
    output_dataset = Dataset(output_dataset_path, "w", format="NETCDF4")
    output_dataset.setncatts(
        {k: input_dataset.getncattr(k)
         for k in input_dataset.ncattrs()})
    # add own metadata
    output_dataset.creators = "Simon Treu (EDGAN, [email protected])\n" + output_dataset.creators
    output_dataset.history = datetime.date.today().isoformat(
    ) + " Added Downscaled images in python with pix2pix-edgan\n" + output_dataset.history

    output_dataset.createDimension("time", None)
    output_dataset.createDimension("lon", 720)
    output_dataset.createDimension("lat", 120)

    # Copy variables
    for v_name, varin in input_dataset.variables.items():
        outVar = output_dataset.createVariable(v_name, varin.datatype,
                                               varin.dimensions)
        # Copy variable attributes
        outVar.setncatts({k: varin.getncattr(k) for k in varin.ncattrs()})
    # Create variable for downscaling
    for k in range(n_samples):
        downscaled_pr = output_dataset.createVariable(
            "downscaled_pr_{}".format(k), output_dataset['pr'].datatype,
            output_dataset['pr'].dimensions)
        downscaled_pr.setncatts({
            k: output_dataset['pr'].getncattr(k)
            for k in output_dataset['pr'].ncattrs()
        })
        downscaled_pr.standard_name += '_downscaled'
        downscaled_pr.long_name += '_downscaled'
        downscaled_pr.comment = 'downscaled ' + downscaled_pr.comment

        if opt.model == 'gamma_vae':
            # p
            p = output_dataset.createVariable('p_{}'.format(k),
                                              output_dataset['pr'].datatype,
                                              output_dataset['pr'].dimensions)
            p.setncatts({
                k: output_dataset['pr'].getncattr(k)
                for k in output_dataset['pr'].ncattrs()
            })
            p.standard_name += '_p'
            p.long_name += '_p'
            p.comment = 'p ' + p.comment
            # alpha
            alpha = output_dataset.createVariable(
                'alpha_{}'.format(k), output_dataset['pr'].datatype,
                output_dataset['pr'].dimensions)
            alpha.setncatts({
                k: output_dataset['pr'].getncattr(k)
                for k in output_dataset['pr'].ncattrs()
            })
            alpha.standard_name += '_alpha'
            alpha.long_name += '_alpha'
            alpha.comment = 'alpha ' + alpha.comment
            # beta
            beta = output_dataset.createVariable(
                'beta_{}'.format(k), output_dataset['pr'].datatype,
                output_dataset['pr'].dimensions)
            beta.setncatts({
                k: output_dataset['pr'].getncattr(k)
                for k in output_dataset['pr'].ncattrs()
            })
            beta.standard_name += '_beta'
            beta.long_name += '_beta'
            beta.comment = 'beta ' + beta.comment
            # mean_downscaled_pr_{}
            mean_downscaled_pr = output_dataset.createVariable(
                'mean_downscaled_pr_{}'.format(k),
                output_dataset['pr'].datatype, output_dataset['pr'].dimensions)
            mean_downscaled_pr.setncatts({
                k: output_dataset['pr'].getncattr(k)
                for k in output_dataset['pr'].ncattrs()
            })
            mean_downscaled_pr.standard_name += '_mean_downscaled_pr'
            mean_downscaled_pr.long_name += '_mean_downscaled_pr'
            mean_downscaled_pr.comment = 'mean_downscaled_pr ' + mean_downscaled_pr.comment

    bilinear_downscaled_pr = output_dataset.createVariable(
        'bilinear_downscaled_pr', output_dataset['pr'].datatype,
        output_dataset['pr'].dimensions)
    bilinear_downscaled_pr.setncatts({
        k: output_dataset['pr'].getncattr(k)
        for k in output_dataset['pr'].ncattrs()
    })
    bilinear_downscaled_pr.standard_name += '_bilinear_downscaled'
    bilinear_downscaled_pr.long_name += '_bilinear_downscaled'
    bilinear_downscaled_pr.comment = 'bilinear_downscaled ' + bilinear_downscaled_pr.comment

    # set variable values
    output_dataset['lat'][:] = input_dataset['lat'][8:-8]
    output_dataset['lon'][:] = input_dataset['lon'][:]
    output_dataset['time'][:] = input_dataset['time'][climate_data.time_list]

    output_dataset['orog'][:] = input_dataset['orog'][8:-8, :]
    output_dataset['pr'][:] = input_dataset['pr'][climate_data.time_list,
                                                  8:-8, :]

    output_dataset['uas'][:] = input_dataset['uas'][climate_data.time_list,
                                                    8:-8, :]
    output_dataset['vas'][:] = input_dataset['vas'][climate_data.time_list,
                                                    8:-8, :]
    output_dataset['psl'][:] = input_dataset['psl'][climate_data.time_list,
                                                    8:-8, :]

    for idx_lat in range(3):
        for idx_lon in range(18):

            # lat with index 0 is 34 N.
            large_cell_lats = [
                i for i in range(idx_lat * 40 + 4, (idx_lat + 1) * 40 + 12)
            ]
            lats = [i for i in range(idx_lat * 40, (idx_lat + 1) * 40)]
            # longitudes might cross the prime meridian
            large_cell_lons = [
                i % 720
                for i in range(idx_lon * 40 - 4, (idx_lon + 1) * 40 + 4)
            ]
            lons = [i % 720 for i in range(idx_lon * 40, (idx_lon + 1) * 40)]

            pr_tensor = torch.tensor(
                input_dataset['pr'][climate_data.time_list, large_cell_lats,
                                    large_cell_lons],
                dtype=torch.float32,
                device=device).unsqueeze(1)
            orog_tensor = torch.tensor(
                input_dataset['orog'][large_cell_lats, large_cell_lons],
                dtype=torch.float32,
                device=device).expand(len(climate_data.time_list), 1, 48, 48)
            uas_tensor = torch.tensor(
                input_dataset['uas'][climate_data.time_list, large_cell_lats,
                                     large_cell_lons],
                dtype=torch.float32,
                device=device).unsqueeze(1)
            vas_tensor = torch.tensor(
                input_dataset['vas'][climate_data.time_list, large_cell_lats,
                                     large_cell_lons],
                dtype=torch.float32,
                device=device).unsqueeze(1)
            psl_tensor = torch.tensor(
                input_dataset['psl'][climate_data.time_list, large_cell_lats,
                                     large_cell_lons],
                dtype=torch.float32,
                device=device).unsqueeze(1)

            coarse_pr = upscaler.upscale(pr_tensor)
            coarse_uas = upscaler.upscale(uas_tensor)
            coarse_vas = upscaler.upscale(vas_tensor)
            coarse_psl = upscaler.upscale(psl_tensor)

            for k in range(n_samples):
                with torch.no_grad():
                    z = torch.randn(len(climate_data.time_list),
                                    opt.nz, 1, 1, device=device)
                    recon_pr = model.decode(z=z,
                                            coarse_pr=coarse_pr,
                                            coarse_uas=coarse_uas,
                                            coarse_vas=coarse_vas,
                                            orog=orog_tensor,
                                            coarse_psl=coarse_psl)
                if opt.model == "mse_vae":
                    output_dataset['downscaled_pr_{}'.format(
                        k)][:, lats, lons] = recon_pr[:, 0, 4:-4, 4:-4]
                elif opt.model == "gamma_vae":
                    output_dataset['downscaled_pr_{}'.format(k)][:, lats, lons] = \
                        (torch.distributions.bernoulli.Bernoulli(recon_pr['p']).sample() *
                         torch.distributions.gamma.Gamma(recon_pr['alpha'],1/recon_pr['beta']).sample())[:, 0,4:-4,4:-4]

                    output_dataset['mean_downscaled_pr_{}'.format(k)][:,lats,lons] = \
                    torch.nn.Threshold(0.035807043601739474, 0)(recon_pr['p'] * recon_pr['alpha'] * recon_pr['beta'])[:, 0,4:-4,4:-4]

                    output_dataset['p_{}'.format(
                        k)][:, lats, lons] = recon_pr['p'][:, 0, 4:-4, 4:-4]
                    output_dataset['alpha_{}'.format(
                        k)][:, lats, lons] = recon_pr['alpha'][:, 0, 4:-4,
                                                               4:-4]
                    output_dataset['beta_{}'.format(
                        k)][:, lats, lons] = recon_pr['beta'][:, 0, 4:-4, 4:-4]
                    #todo don't hardcode threshold
                else:
                    raise ValueError("model {} is not implemented".format(
                        opt.model))

            bilinear_pr = torch.nn.functional.interpolate(
                coarse_pr,
                scale_factor=opt.scale_factor,
                mode='bilinear',
                align_corners=True)
            output_dataset['bilinear_downscaled_pr'][:, lats, lons] = \
                bilinear_pr[:, 0, 4:-4, 4:-4]
            print('Progress = {:>5.1f} %'.format(
                (idx_lat * 18 + idx_lon + 1) * 100 / (3 * 18)))
    output_dataset.close()
    input_dataset.close()
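
For reference, a self-contained sketch of the Bernoulli-Gamma sampling used above, with dummy parameter tensors (note that torch.distributions.Gamma takes a rate, hence the 1/beta): occurrence is Bernoulli(p), intensity is Gamma(alpha, rate=1/beta), and the expected value per pixel is p * alpha * beta.

import torch

p = torch.full((2, 1, 4, 4), 0.5)      # probability of rain
alpha = torch.full((2, 1, 4, 4), 2.0)  # Gamma shape
beta = torch.full((2, 1, 4, 4), 3.0)   # Gamma scale, so rate = 1 / beta
sample = (torch.distributions.bernoulli.Bernoulli(p).sample()
          * torch.distributions.gamma.Gamma(alpha, 1 / beta).sample())
mean = p * alpha * beta                # expected precipitation per pixel
print(sample.shape, mean[0, 0, 0, 0].item())  # torch.Size([2, 1, 4, 4]) 3.0
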
Example #6
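Variant of the inference script from Example #5 that writes one small netCDF file per validation cell and loops over individual time steps instead of batching them.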
def main():
    # set variables, create directories
    opt = BaseOptions().parse()
    n_samples = opt.n_samples
    device = torch.device("cuda" if len(opt.gpu_ids) > 0 else "cpu")
    upscaler = Upscale(size=48, scale_factor=opt.scale_factor, device=device)
    load_root = os.path.join('checkpoints', opt.name)
    load_epoch = opt.load_epoch if opt.load_epoch >= 0 else 'latest'
    load_name = "epoch_{}.pth".format(load_epoch)
    load_dir = os.path.join(load_root, load_name)
    outdir = os.path.join(opt.results_dir, opt.name,
                          '%s_%s' % (opt.phase, load_epoch))
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    climate_data = ClimateDataset(opt=opt, phase=opt.phase)

    # large_cell = 48x48, cell = 40x40, small_cell = 32x32
    large_cell = opt.fine_size + 2 * opt.scale_factor

    # load the model
    model = SDVAE(opt=opt, device=device).to(device)
    model.load_state_dict(torch.load(load_dir, map_location='cpu'))
    model.eval()

    # Iterate over the validation cells and compute n_samples reconstructions for each.
    index = 0
    # read input file
    input_dataset = Dataset(os.path.join(opt.dataroot, 'dataset.nc4'),
                            "r",
                            format="NETCDF4")
    for idx_lat in range(climate_data.rows):
        for idx_lon in climate_data.lat_lon_list[idx_lat]:
            # calculate upper left index for cell with boundary values to downscale
            anchor_lat = idx_lat * climate_data.cell_size + climate_data.scale_factor
            anchor_lon = idx_lon * climate_data.cell_size
            # select indices for a 48 x 48 box around the 32 x 32 box to be downscaled (with boundary values)
            large_cell_lats = [
                i for i in range(
                    anchor_lat - climate_data.scale_factor, anchor_lat +
                    climate_data.fine_size + climate_data.scale_factor)
            ]
            # longitudes might cross the prime meridian
            large_cell_lons = [
                i % 720 for i in range(
                    anchor_lon - climate_data.scale_factor, anchor_lon +
                    climate_data.fine_size + climate_data.scale_factor)
            ]
            # create output path
            basename = "val.lat{}_lon{}.nc4".format(anchor_lat, anchor_lon)
            print("test file nr. {} name: {}".format(index, basename))

            # create output file
            output_dataset_path = os.path.join(outdir, basename)
            output_dataset = Dataset(output_dataset_path,
                                     "w",
                                     format="NETCDF4")
            output_dataset.setncatts({
                k: input_dataset.getncattr(k)
                for k in input_dataset.ncattrs()
            })
            # add own metadata
            output_dataset.creators = "Simon Treu (EDGAN, [email protected])\n" + output_dataset.creators
            output_dataset.history = datetime.date.today().isoformat(
            ) + " Added Downscaled images in python with pix2pix-edgan\n" + output_dataset.history

            output_dataset.createDimension("time", None)
            output_dataset.createDimension("lon", large_cell)
            output_dataset.createDimension("lat", large_cell)

            # Copy variables
            for v_name, varin in input_dataset.variables.items():
                outVar = output_dataset.createVariable(v_name, varin.datatype,
                                                       varin.dimensions)
                # Copy variable attributes
                outVar.setncatts(
                    {k: varin.getncattr(k)
                     for k in varin.ncattrs()})
            # Create variable for downscaling
            for k in range(n_samples):
                downscaled_pr = output_dataset.createVariable(
                    "downscaled_pr_{}".format(k),
                    output_dataset['pr'].datatype,
                    output_dataset['pr'].dimensions)
                downscaled_pr.setncatts({
                    k: output_dataset['pr'].getncattr(k)
                    for k in output_dataset['pr'].ncattrs()
                })
                downscaled_pr.standard_name += '_downscaled'
                downscaled_pr.long_name += '_downscaled'
                downscaled_pr.comment = 'downscaled ' + downscaled_pr.comment

                if opt.model == 'gamma_vae':
                    # p
                    p = output_dataset.createVariable(
                        'p_{}'.format(k), output_dataset['pr'].datatype,
                        output_dataset['pr'].dimensions)
                    p.setncatts({
                        k: output_dataset['pr'].getncattr(k)
                        for k in output_dataset['pr'].ncattrs()
                    })
                    p.standard_name += '_p'
                    p.long_name += '_p'
                    p.comment = 'p ' + p.comment
                    # alpha
                    alpha = output_dataset.createVariable(
                        'alpha_{}'.format(k), output_dataset['pr'].datatype,
                        output_dataset['pr'].dimensions)
                    alpha.setncatts({
                        k: output_dataset['pr'].getncattr(k)
                        for k in output_dataset['pr'].ncattrs()
                    })
                    alpha.standard_name += '_alpha'
                    alpha.long_name += '_alpha'
                    alpha.comment = 'alpha ' + alpha.comment
                    # beta
                    beta = output_dataset.createVariable(
                        'beta_{}'.format(k), output_dataset['pr'].datatype,
                        output_dataset['pr'].dimensions)
                    beta.setncatts({
                        k: output_dataset['pr'].getncattr(k)
                        for k in output_dataset['pr'].ncattrs()
                    })
                    beta.standard_name += '_beta'
                    beta.long_name += '_beta'
                    beta.comment = 'beta ' + beta.comment
                    # mean_downscaled_pr_{}
                    mean_downscaled_pr = output_dataset.createVariable(
                        'mean_downscaled_pr_{}'.format(k),
                        output_dataset['pr'].datatype,
                        output_dataset['pr'].dimensions)
                    mean_downscaled_pr.setncatts({
                        k: output_dataset['pr'].getncattr(k)
                        for k in output_dataset['pr'].ncattrs()
                    })
                    mean_downscaled_pr.standard_name += '_mean_downscaled_pr'
                    mean_downscaled_pr.long_name += '_mean_downscaled_pr'
                    mean_downscaled_pr.comment = 'mean_downscaled_pr ' + mean_downscaled_pr.comment

            bilinear_downscaled_pr = output_dataset.createVariable(
                'bilinear_downscaled_pr', output_dataset['pr'].datatype,
                output_dataset['pr'].dimensions)
            bilinear_downscaled_pr.setncatts({
                k: output_dataset['pr'].getncattr(k)
                for k in output_dataset['pr'].ncattrs()
            })
            bilinear_downscaled_pr.standard_name += '_bilinear_downscaled'
            bilinear_downscaled_pr.long_name += '_bilinear_downscaled'
            bilinear_downscaled_pr.comment = 'bilinear_downscaled ' + bilinear_downscaled_pr.comment

            # set variable values
            output_dataset['lat'][:] = input_dataset['lat'][large_cell_lats]
            output_dataset['lon'][:] = input_dataset['lon'][large_cell_lons]
            output_dataset['time'][:] = input_dataset['time'][:]

            output_dataset['orog'][:] = input_dataset['orog'][large_cell_lats,
                                                              large_cell_lons]
            output_dataset['pr'][:] = input_dataset['pr'][:, large_cell_lats,
                                                          large_cell_lons]
            for k in range(n_samples):
                output_dataset['downscaled_pr_{}'.format(
                    k)][:] = input_dataset['pr'][:, large_cell_lats,
                                                 large_cell_lons]
            output_dataset['bilinear_downscaled_pr'][:] = input_dataset[
                'pr'][:, large_cell_lats, large_cell_lons]

            output_dataset['uas'][:] = input_dataset['uas'][:, large_cell_lats,
                                                            large_cell_lons]
            output_dataset['vas'][:] = input_dataset['vas'][:, large_cell_lats,
                                                            large_cell_lons]
            output_dataset['psl'][:] = input_dataset['psl'][:, large_cell_lats,
                                                            large_cell_lons]

            # read out the variables similar to construct_datasets.py
            pr = output_dataset['pr'][:]

            uas = output_dataset['uas'][:]
            vas = output_dataset['vas'][:]
            psl = output_dataset['psl'][:]
            orog = output_dataset['orog'][:]

            times = pr.shape[0]
            for t in range(times):
                pr_tensor = torch.tensor(pr[t, :, :],
                                         dtype=torch.float32,
                                         device=device)
                orog_tensor = torch.tensor(
                    orog[:, :], dtype=torch.float32,
                    device=device).unsqueeze(0).unsqueeze(0)
                uas_tensor = torch.tensor(uas[t, :, :],
                                          dtype=torch.float32,
                                          device=device)
                vas_tensor = torch.tensor(vas[t, :, :],
                                          dtype=torch.float32,
                                          device=device)
                psl_tensor = torch.tensor(psl[t, :, :],
                                          dtype=torch.float32,
                                          device=device)

                coarse_pr = upscaler.upscale(pr_tensor).unsqueeze(0).unsqueeze(0)
                coarse_uas = upscaler.upscale(uas_tensor).unsqueeze(0).unsqueeze(0)
                coarse_vas = upscaler.upscale(vas_tensor).unsqueeze(0).unsqueeze(0)
                coarse_psl = upscaler.upscale(psl_tensor).unsqueeze(0).unsqueeze(0)

                for k in range(n_samples):
                    with torch.no_grad():
                        z = torch.randn(1, opt.nz, 1, 1, device=device)
                        recon_pr = model.decode(z=z,
                                                coarse_pr=coarse_pr,
                                                coarse_uas=coarse_uas,
                                                coarse_vas=coarse_vas,
                                                orog=orog_tensor,
                                                coarse_psl=coarse_psl)
                    if opt.model == "mse_vae":
                        output_dataset['downscaled_pr_{}'.format(k)][
                            t, :, :] = recon_pr
                    elif opt.model == "gamma_vae":
                        output_dataset['downscaled_pr_{}'.format(k)][t, :, :] = \
                            (torch.distributions.bernoulli.Bernoulli(recon_pr['p']).sample() *
                             torch.distributions.gamma.Gamma(recon_pr['alpha'],1/recon_pr['beta']).sample())[0, 0,:,:]

                        output_dataset['mean_downscaled_pr_{}'.format(k)][t,:,:] = \
                        torch.nn.Threshold(0.035807043601739474, 0)(recon_pr['p'] * recon_pr['alpha'] * recon_pr['beta'])

                        output_dataset['p_{}'.format(k)][
                            t, :, :] = recon_pr['p']
                        output_dataset['alpha_{}'.format(k)][
                            t, :, :] = recon_pr['alpha']
                        output_dataset['beta_{}'.format(k)][
                            t, :, :] = recon_pr['beta']
                        #todo don't hardcode threshold
                    else:
                        raise ValueError("model {} is not implemented".format(
                            opt.model))

                bilinear_pr = torch.nn.functional.interpolate(
                    coarse_pr,
                    scale_factor=opt.scale_factor,
                    mode='bilinear',
                    align_corners=True)
                output_dataset['bilinear_downscaled_pr'][t, :, :] = \
                    bilinear_pr[0, 0, :, :]

            output_dataset.close()
            index += 1
    input_dataset.close()
Example #7
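The complete SDVAE module: encoder layers, reparameterization, the decoder hook, the combined NLPD + KL loss with an auxiliary cycle loss on the re-coarsened reconstruction, and the Bernoulli-Gamma negative log-likelihood used by the gamma_vae variant.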
class SDVAE(nn.Module):
    def __init__(self, opt, device):
        super(SDVAE, self).__init__()
        # variables
        self.nz = opt.nz
        self.input_size = opt.fine_size ** 2
        self.upscaler = Upscale(size=48, scale_factor=8, device=device)
        self.use_orog = not opt.no_orog
        self.no_dropout = opt.no_dropout
        self.nf_encoder = opt.nf_encoder
        self.model = opt.model
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.device = device
        self.regression = opt.regression


        # dimensions below assume batch_size=64, nf_encoder=16, fine_size=32, nz=10, orog=True
        # input: 64x6x48x48
        self.h_layer1 = self._down_conv(in_channels=1 + self.use_orog + 4,
                                        out_channels=self.nf_encoder,
                                        kernel_size=4, padding=1, stride=2)
        # 64x21x24x24
        self.h_layer2 = self._down_conv(in_channels=self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 2,
                                        kernel_size=4, padding=1, stride=2)
        # 64x37x12x12
        self.h_layer3 = self._down_conv(in_channels=self.nf_encoder*2 + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 4,
                                        kernel_size=4, padding=1, stride=2)
        # 64x69x6x6
        self.h_layer4 = self._down_conv(in_channels=4 * self.nf_encoder + self.use_orog + 4,
                                        out_channels=self.nf_encoder * 8,
                                        kernel_size=4, padding=1, stride=2)
        # 64x128x3x3

        # mu
        self.mu = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1

        # log_var
        self.log_var = nn.Sequential(nn.Linear(in_features=self.nf_encoder * 8 * 9, out_features=self.nz))
        # 64x10x1x1

        self.decode = Decoder(opt, device)

    def forward(self, fine_pr, coarse_pr, orog, coarse_uas, coarse_vas, coarse_psl):
        if not self.regression:
            # layer 1
            input_layer_input = [fine_pr]
            if self.use_orog:
                input_layer_input.append(orog)

            upsample48 = torch.nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
            input_layer_input += [upsample48(coarse_pr), upsample48(coarse_uas),
                                  upsample48(coarse_vas), upsample48(coarse_psl)]
            h_layer1 = self.h_layer1(torch.cat(input_layer_input, 1))

            # layer 2
            layer2_input = [h_layer1]
            if self.use_orog:
                upscale24 = Upscale(size=48, scale_factor=2, device=self.device)
                layer2_input.append(upscale24.upscale(orog))
            upsample24 = torch.nn.Upsample(scale_factor=4, mode='nearest')
            layer2_input += [upsample24(coarse_pr), upsample24(coarse_uas),
                             upsample24(coarse_vas), upsample24(coarse_psl)]
            h_layer2 = self.h_layer2(torch.cat(layer2_input, 1))

            # layer 3
            layer3_input = [h_layer2]
            if self.use_orog:
                upscale12 = Upscale(size=48, scale_factor=4, device=self.device)
                layer3_input.append(upscale12.upscale(orog))
            upsample12 = torch.nn.Upsample(scale_factor=2, mode='nearest')
            layer3_input += [upsample12(coarse_pr), upsample12(coarse_uas),
                             upsample12(coarse_vas), upsample12(coarse_psl)]
            h_layer3 = self.h_layer3(torch.cat(layer3_input, 1))

            # layer 4
            layer4_input = [h_layer3]
            if self.use_orog:
                upscale6 = Upscale(size=48, scale_factor=8, device=self.device)
                layer4_input.append(upscale6.upscale(orog))
            layer4_input += [coarse_pr, coarse_uas, coarse_vas, coarse_psl]
            h_layer4 = self.h_layer4(torch.cat(layer4_input, 1))

            # output layer
            mu = self.mu(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
            log_var = self.log_var(h_layer4.view(h_layer4.shape[0], -1)).unsqueeze(-1).unsqueeze(-1)
        else:
            mu = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
            log_var = torch.zeros(fine_pr.shape[0], self.nz, 1, 1, dtype=torch.float, device=self.device)
        # reparameterization
        z = self._reparameterize(mu, log_var)

        # decode
        recon_pr = self.decode(z=z, coarse_pr=coarse_pr, orog=orog,
                               coarse_uas=coarse_uas, coarse_vas=coarse_vas, coarse_psl=coarse_psl)
        return recon_pr, mu.view(-1, self.nz), log_var.view(-1, self.nz)

    def loss_function(self, recon_x, x, mu, log_var, coarse_pr):
        # negative log predictive density
        if self.model == 'gamma_vae':
            nlpd = self._neg_log_gamma_likelihood(x, recon_x['alpha'], recon_x['beta'], recon_x['p'])
        elif self.model == 'mse_vae':
            nlpd = nn.functional.mse_loss(recon_x, x, reduction='sum') / x.shape[0]
        else:
            raise ValueError("model {} is not implemented".format(self.model))

        # Kullback-Leibler Divergence
        kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), 1))
        loss = kld + nlpd

        # cycle loss
        # not added to the total loss
        if self.model == 'gamma_vae':
            coarse_recon = self.upscaler.upscale(recon_x['p'] * recon_x['alpha'] * recon_x['beta'])
        elif self.model == 'mse_vae':
            coarse_recon = self.upscaler.upscale(recon_x)
        cycle_loss = nn.functional.mse_loss(coarse_pr, coarse_recon, reduction='mean')

        return nlpd, kld, cycle_loss, loss

    def _reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def _neg_log_gamma_likelihood(self, x, alpha, beta, p):
        # negative log-likelihood of a Bernoulli-Gamma mixture: wet pixels (x > 0)
        # contribute log p plus the Gamma log-density, dry pixels contribute log(1 - p)
        result = torch.zeros(1, device=x.device)
        if (x > 0).any():
            result = - torch.sum(torch.log(p[x > 0])
                        + (alpha[x > 0] - 1) * torch.log(x[x > 0])
                        - alpha[x > 0] * torch.log(beta[x > 0])
                        - x[x > 0] / beta[x > 0]
                        - torch.lgamma(alpha[x > 0])
                        )
        if (x == 0).any():
            result -= torch.sum(torch.log(1 - p[x == 0]) + 0 * alpha[x == 0] + 0 * beta[x == 0])
        return result/x.shape[0]  # mean over batch size

    def _down_conv(self, in_channels, out_channels, kernel_size, padding, stride):
        if self.no_dropout:
            return nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, padding=padding, stride=stride),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU())
        else:
            return nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=kernel_size, padding=padding, stride=stride),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(),
                                 nn.Dropout())
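
A minimal, self-contained sketch of the reparameterization trick and the KL term from loss_function above (shapes [batch, nz]; illustrative only, not the project's test code).

import torch

def reparameterize(mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I); differentiable w.r.t. mu and log_var
    std = torch.exp(0.5 * log_var)
    return mu + torch.randn_like(std) * std

mu, log_var = torch.zeros(4, 10), torch.zeros(4, 10)
z = reparameterize(mu, log_var)
# KL(N(mu, sigma^2) || N(0, I)) averaged over the batch; zero for the standard normal
kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), 1))
print(z.shape, kld.item())  # torch.Size([4, 10]) 0.0
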
Example #8
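The complete ClimateDataset: __getitem__ maps a flat index to (time, lat, lon), applies a random crop offset, reads a 48x48 box (fine cell plus boundary) from the netCDF file, and returns fine and coarse tensors.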
class ClimateDataset(Dataset):
    def __init__(self, opt, phase):
        self.root = opt.dataroot
        self.scale_factor = opt.scale_factor
        self.fine_size = opt.fine_size
        self.n_test = opt.n_test
        self.n_val = opt.n_val
        # cell_size is the edge length of one cell extracted from the netCDF file;
        # it is later cropped to fine_size
        self.cell_size = opt.fine_size + self.scale_factor
        with Nc4Dataset(os.path.join(self.root, "dataset.nc4"), "r", format="NETCDF4") as file:
            check_dimensions(file, self.cell_size, self.scale_factor)
            cols = file['lon'].size//self.cell_size
            self.rows = file['lat'].size//self.cell_size

        self.upscaler = Upscale(size=self.fine_size+2*self.scale_factor, scale_factor=self.scale_factor)

        # remove n_test + n_val 40x40 boxes from each block of 40 rows to create the training set
        self.lat_lon_list = create_lat_lon_indices(self.rows, cols, self.n_test, self.n_val,
                                                   seed=opt.seed, phase=phase)
        # remove 3 years for test and 3 years for val
        self.time_list = create_time_list(seed=opt.seed, phase=phase)

        t = len(self.time_list)
        if phase == 'train':
            self.length = self.rows * (cols - self.n_test - self.n_val) * t
        elif phase == 'val':
            self.length = self.rows * self.n_val * t
        elif phase == 'test':
            self.length = self.rows * self.n_test * t

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        start_time = time.time()  # TODO: remove timing

        # ++ calculate lat lon and time from index ++ #
        s_lat = len(self.lat_lon_list)
        s_lon = len(self.lat_lon_list[0])

        t_idx = index // (s_lat * s_lon)
        lat = index % (s_lat * s_lon) // s_lon
        lon = index % s_lon
        # --------------------------------------------------------------------------------------------------------------
        # calculate a random offset from the upper left corner to crop the box
        w_offset = random.randint(0, self.scale_factor-1)
        h_offset = random.randint(0, self.scale_factor-1)
        # --------------------------------------------------------------------------------------------------------------
        # calculate the lat and lon indices of the upper left corner in the netcdf file;
        # add the scale factor because the first scale_factor latitude pixels only hold
        # boundary values --> for scale_factor=8, lat=0 maps to netCDF index 8
        anchor_lat = lat * self.cell_size + w_offset + self.scale_factor
        anchor_lon = self.lat_lon_list[lat][lon] * self.cell_size + h_offset

        # select indices for a 48 x 48 box around the 32 x 32 box to be downscaled (with boundary values)
        boundary_lats = [i for i in range(anchor_lat-self.scale_factor, anchor_lat+self.fine_size+self.scale_factor)]
        # longitudes might cross the prime meridian
        boundary_lons = [i % 720
                         for i in range(anchor_lon-self.scale_factor, anchor_lon+self.fine_size+self.scale_factor)]
        t = self.time_list[t_idx]
        # --------------------------------------------------------------------------------------------------------------
        # ++ read data ++ #
        with Nc4Dataset(os.path.join(self.root, "dataset.nc4"), "r", format="NETCDF4") as file:
            pr = torch.tensor(file['pr'][t, boundary_lats, boundary_lons], dtype=torch.float)
            orog = torch.tensor(file['orog'][boundary_lats, boundary_lons], dtype=torch.float)
            uas = torch.tensor(file['uas'][t, boundary_lats, boundary_lons], dtype=torch.float)
            vas = torch.tensor(file['vas'][t, boundary_lats, boundary_lons], dtype=torch.float)
            psl = torch.tensor(file['psl'][t, boundary_lats, boundary_lons], dtype=torch.float)
        # --------------------------------------------------------------------------------------------------------------
        coarse_pr = self.upscaler.upscale(pr)
        coarse_uas = self.upscaler.upscale(uas)
        coarse_vas = self.upscaler.upscale(vas)
        coarse_psl = self.upscaler.upscale(psl)

        # bring all tensors into shape [C, H, W] (channels, height, width)
        pr.unsqueeze_(0)
        orog.unsqueeze_(0)
        coarse_pr.unsqueeze_(0)
        coarse_uas.unsqueeze_(0)
        coarse_vas.unsqueeze_(0)
        coarse_psl.unsqueeze_(0)

        end_time = time.time() - start_time

        return {'fine_pr': pr, 'coarse_pr': coarse_pr,
                'orog': orog, 'coarse_uas': coarse_uas,
                'coarse_vas': coarse_vas, 'coarse_psl': coarse_psl,
                'time': end_time}
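
A hedged usage sketch: because ClimateDataset implements __len__ and __getitem__, it plugs directly into a PyTorch DataLoader. The shape comments assume fine_size=32 and scale_factor=8, as in the examples above.

from torch.utils.data import DataLoader

opt = BaseOptions().parse()  # assumed options object, as in Examples #5 and #6
dataset = ClimateDataset(opt=opt, phase='train')
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
for batch in loader:
    fine_pr = batch['fine_pr']      # [64, 1, 48, 48] fine grid incl. boundary
    coarse_pr = batch['coarse_pr']  # [64, 1, 6, 6] after 8x coarsening
    break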