def __init__(self,
                 start_date=date(1990, 1, 1),
                 end_date=date(1990, 12, 31),
                 regin="AUS",
                 lr_transform=None,
                 hr_transform=None,
                 shuffle=True,
                 args=None):
        '''
        Index the ACCESS-S1 precipitation files and probe BARRA-R lat/lon.

        Parameters
        ----------
        start_date, end_date : datetime.date
            Range passed to ``self.date_range`` (stored on ``self.dates``).
        regin : str
            Region tag, stored as ``self.regin``.
        lr_transform, hr_transform : callable or None
            Stored as ``self.lr_transform`` / ``self.hr_transform``.
        shuffle : bool
            Shuffle ``self.filename_list`` in place after indexing.
        args : namespace
            Required despite the ``None`` default. Read here:
            ``file_BARRA_dir``, ``file_ACCESS_dir``, ``leading_time_we_use``,
            ``ensemble`` (member count), ``dem`` and, when ``dem`` is set,
            ``file_DEM_dir``.

        Raises
        ------
        ValueError
            If ``args`` is None or no ACCESS-S1 pr files were indexed.
        '''
        if args is None:
            raise ValueError(
                "args is required: it supplies the data directories and options")

        print("=> BARRA_R & ACCESS_S1 loading")
        print("=> from " + start_date.strftime("%Y/%m/%d") + " to " +
              end_date.strftime("%Y/%m/%d") + "")
        self.file_BARRA_dir = args.file_BARRA_dir
        self.file_ACCESS_dir = args.file_ACCESS_dir
        self.args = args

        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

        self.start_date = start_date
        self.end_date = end_date

        self.regin = regin
        self.leading_time_we_use = args.leading_time_we_use

        # ACCESS-S1 member names; the first `args.ensemble` of them are used
        # (IndexError if more members are requested than listed here).
        self.ensemble_access = [
            'e01', 'e02', 'e03', 'e04', 'e05', 'e06', 'e07', 'e08', 'e09',
            'e10', 'e11'
        ]
        self.ensemble = [self.ensemble_access[i] for i in range(args.ensemble)]

        self.dates = self.date_range(start_date, end_date)

        pr_dir = args.file_ACCESS_dir + "pr/daily/"
        # Diagnose a missing/unreadable directory *before* listing it; checking
        # afterwards meant the message never appeared if the listing failed.
        if not os.path.exists(pr_dir):
            print(pr_dir)
            print("no file or no permission")
        self.filename_list = self.get_filename_with_time_order(pr_dir)
        if len(self.filename_list) == 0:
            raise ValueError("no ACCESS-S1 pr files found under " + pr_dir)

        # Take the first entry before shuffling so the lat/lon probe below
        # always uses the same (unshuffled) entry.
        _, _, date_for_BARRA, _ = self.filename_list[0]
        if shuffle:
            random.shuffle(self.filename_list)

        # Check the configured BARRA directory, not one hard-coded file.
        if not os.path.exists(self.file_BARRA_dir):
            print(self.file_BARRA_dir)
            print("no file or no permission!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        data_high = dpt.read_barra_data_fc_get_lat_lon(self.file_BARRA_dir,
                                                       date_for_BARRA)
        # NOTE(review): assumes the helper returns (data, lat, lon); the
        # original assigned index 1 to both lat and lon — confirm.
        self.lat = data_high[1]
        self.lon = data_high[2]
        self.shape = (79, 94)
        if self.args.dem:
            data_dem = dpt.add_lat_lon(
                dpt.read_dem(args.file_DEM_dir + "dem-9s1.tif"))
            self.dem_data = dpt.interp_tensor_2d(
                dpt.map_aust_old(data_dem, xrarray=False), self.shape)
示例#2
0
    def __getitem__(self, idx):
        '''
        Load one (low-res ACCESS-S1, high-res BARRA-R) sample by index.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, date_for_BARRA, time_leading)``.

        Returns
        -------
        tuple
            Without ``self.transform``: ``(lr * 86400, label, date, lead)``.
            With ``self.transform`` and ``self.args.channels == 2``:
            ``(T(lr * 86400), T(self.dem_data), T(label), date, lead)``.
            ``date`` is a tensor holding the int ``YYYYmmdd`` BARRA date and
            ``lead`` a tensor holding ``time_leading``.

        Raises
        ------
        NotImplementedError
            For ``channels`` 27 or 5: this loader does not read the psl/zg/
            tasmax/tasmin fields those layouts need (the original referenced
            undefined locals there and failed with NameError).
        ValueError
            For any other ``channels`` value (the original returned None).
        '''
        access_filename_pr, date_for_BARRA, time_leading = self.filename_list[
            idx]

        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        # Use the instance's args (set in __init__), not a module-level global.
        lr_raw = dpt.map_aust(data_low, domain=self.args.domain, xrarray=False)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)
        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))

        # 86400 = seconds per day; presumably turns a per-second rate into a
        # daily total — confirm units.
        lr_daily = lr * 86400
        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        if not self.transform:
            return lr_daily, label, date_tensor, lead_tensor

        channels = self.args.channels
        if channels == 2:
            return (self.transform(lr_daily), self.transform(self.dem_data),
                    self.transform(label), date_tensor, lead_tensor)
        if channels in (27, 5):
            raise NotImplementedError(
                "channels=%d needs auxiliary ACCESS fields that this loader "
                "does not read" % channels)
        raise ValueError("unsupported channels value: %r" % (channels,))
示例#3
0
    def __getitem__(self, idx):
        '''
        Load one low-res ACCESS-S1 precipitation sample by index.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename, date_for_BARRA, time_leading)``.

        Returns
        -------
        With ``self.transform`` set: only the transformed low-res field
        (as in the original). Otherwise ``(lr, label, date, lead)`` where
        ``lr``/``label`` gain a trailing singleton axis, ``date`` is a tensor
        of the int ``YYYYmmdd`` BARRA date and ``lead`` a tensor of
        ``time_leading``.
        '''
        access_filename, date_for_BARRA, time_leading = self.filename_list[idx]
        data_low = dpt.read_access_data(access_filename, idx=time_leading)
        # Use the instance's args, not a module-level global.
        lr_raw = dpt.map_aust(data_low, domain=self.args.domain, xrarray=False)

        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))
        # axis=-1 appends the channel axis. The original axis=3 raises
        # numpy.AxisError for a 2-D array and equals axis=-1 for a 3-D one.
        # 86400 = seconds per day (presumably rate -> daily total; confirm).
        lr = np.expand_dims(lr, axis=-1) * 86400

        if self.transform:
            return self.transform(lr)

        # `label` was never defined in the original; load it the same way the
        # sibling BARRA-R loaders do.
        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)
        return lr, np.expand_dims(label, axis=-1), torch.tensor(
            int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(time_leading)
示例#4
0
    def __getitem__(self, idx):
        '''
        Load one (low-res ACCESS-S1, high-res BARRA-R) sample by index.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, en, access_date, date_for_BARRA, time_leading)``.
        Precipitation, and each auxiliary ACCESS field enabled in
        ``self.args`` (zg, psl, tasmax, tasmin), is cropped to a fixed index
        window and interpolated to ``self.shape``.

        Returns
        -------
        tuple
            Without ``self.transform``: ``(lr, label, date, lead)``.
            ``channels == 27``: ``(T(lr), T(dem), T(psl), T(zg), T(tasmax),
            T(tasmin), T(label), date, lead)``.
            ``channels == 2``: ``(T(lr), T(dem), T(label), date, lead)``.
            ``date`` is a tensor of the int ``YYYYmmdd`` BARRA date and
            ``lead`` a tensor of ``time_leading``.

        Raises
        ------
        ValueError
            If ``self.transform`` is set and ``self.args.channels`` is neither
            27 nor 2 (the original silently returned None).
        '''
        access_filename_pr, en, access_date, date_for_BARRA, time_leading = self.filename_list[
            idx]

        def aux_filename(var):
            # <file_ACCESS_dir><var>/daily/<en>/da_<var>_<YYYYmmdd>_<en>.nc
            return (self.args.file_ACCESS_dir + var + "/daily/" + en + "/" +
                    "da_" + var + "_" + access_date.strftime("%Y%m%d") + "_" +
                    en + ".nc")

        def read_aux_2d(var):
            # Same crop window and target grid as the precipitation field.
            field = dpt.read_access_data(aux_filename(var),
                                         var_name=var,
                                         idx=time_leading).data[82:144,
                                                                134:188]
            return dpt.interp_tensor_2d(field, self.shape)

        # 86400 = seconds per day (presumably rate -> daily total; confirm).
        # This is the ONLY place lr is scaled: the original scaled it a second
        # time in the 2-channel and no-transform return paths.
        lr = dpt.read_access_data(
            access_filename_pr, idx=time_leading).data[82:144, 134:188] * 86400
        lr = np.expand_dims(dpt.interp_tensor_2d(lr, self.shape), axis=2)
        # astype converts the values; the original `lr.dtype = "float32"`
        # reinterpreted the raw bytes, which garbles any non-float32 array.
        lr = lr.astype("float32")

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)

        if self.args.zg:
            # NOTE(review): this window (83:145, 135:188) differs from the one
            # used for every other field; if zg is (level, lat, lon) it also
            # slices the level and lat axes rather than lat and lon — confirm.
            lr_zg = dpt.read_access_zg(aux_filename("zg"),
                                       idx=time_leading).data[83:145, 135:188]
            lr_zg = dpt.interp_tensor_3d(lr_zg, self.shape)
        if self.args.psl:
            lr_psl = read_aux_2d("psl")
        if self.args.tasmax:
            lr_tasmax = read_aux_2d("tasmax")
        if self.args.tasmin:
            lr_tasmin = read_aux_2d("tasmin")

        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        if not self.transform:
            return lr, label, date_tensor, lead_tensor

        # TODO: the per-channel output layouts still need consolidating.
        if self.args.channels == 27:
            return (self.transform(lr), self.transform(self.dem_data),
                    self.transform(lr_psl), self.transform(lr_zg),
                    self.transform(lr_tasmax), self.transform(lr_tasmin),
                    self.transform(label), date_tensor, lead_tensor)
        if self.args.channels == 2:
            return (self.transform(lr), self.transform(self.dem_data),
                    self.transform(label), date_tensor, lead_tensor)
        raise ValueError("unsupported channels value: %r" %
                         (self.args.channels,))
示例#5
0
def main():
    """Crop and regrid ACCESS-S1 daily precipitation to a 79x94 grid.

    For each entry ``(path, ensemble_member, date, ...)`` returned by
    ``get_filename_with_time_order``, read the first ``time_leading`` lead
    days of ``var_name`` over index window [82:144, 134:188], interpolate
    each day with ``dpt.interp_tensor_2d`` and write the result to
    ``../data/<var>/daily/<member>/da_<var>_<YYYYmmdd>_<member>.nc``.

    The written ``lat``/``lon``/``time`` coordinate variables are
    zero-filled placeholders, exactly as before.
    """
    import netCDF4 as nc

    system = platform.system()  # renamed: `sys` shadowed the stdlib module
    start_date = date(1990, 1, 1)
    # Per the original note: subtract one day here if nine2nine is used.
    end_date = date(2012, 12, 25)
    if system == "Windows":
        file_ACCESS_dir = "H:/climate/access-s1/"
    else:
        file_ACCESS_dir = "/g/data/ub7/access-s1/hc/raw_model/atmos/"

    var_name = "pr"
    ensemble = ['e10', 'e11']
    time_leading = 7  # number of lead days kept from each file
    out_shape = (79, 94)  # target (lat, lon) grid size

    dates = [
        start_date + timedelta(x)
        for x in range((end_date - start_date).days + 1)
    ]
    file_list = get_filename_with_time_order(
        file_ACCESS_dir + var_name + '/daily/', ensemble, dates, var_name)

    for access_path, en, access_date, *_ in file_list:
        # `with` closes the file even if reading fails.
        with nc.Dataset(access_path, 'r') as src:
            var = src[var_name][:time_leading, 82:144, 134:188]

        result = np.zeros((time_leading, ) + out_shape)
        for lead, field in enumerate(var):
            result[lead] = dpt.interp_tensor_2d(field, out_shape)

        out_dir = os.path.join('../data', var_name, 'daily', en)
        # makedirs also creates missing parents; os.mkdir did not.
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(
            out_dir, "da_" + var_name + "_" + access_date.strftime("%Y%m%d") +
            "_" + en + '.nc')

        with nc.Dataset(out_path, 'w', format='NETCDF4') as f_w:
            f_w.createDimension('lat', out_shape[0])
            f_w.createDimension('lon', out_shape[1])
            f_w.createDimension('time', time_leading)

            f_w.createVariable('lat', np.float32, ('lat', ))
            f_w.createVariable('lon', np.float32, ('lon', ))
            # np.int (an alias of int) was removed in NumPy 1.24.
            f_w.createVariable('time', np.int64, ('time', ))

            f_w.variables['lat'][:] = np.zeros(out_shape[0])
            f_w.variables['lon'][:] = np.zeros(out_shape[1])
            f_w.variables['time'][:] = np.zeros(time_leading)

            f_w.createVariable(var_name, np.float32, ('time', 'lat', 'lon'))
            f_w.variables[var_name][:] = result
示例#6
0
    def __getitem__(self, idx):
        '''
        Load one (low-res ACCESS-S1, high-res BARRA-R) sample by index.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, access_date, date_for_BARRA, time_leading)``.
        Auxiliary ACCESS fields (zg/psl/tasmax/tasmin) are read when the
        matching ``self.args`` flag is set, but, as before, they are not
        part of the returned tuple.

        Returns
        -------
        tuple
            ``(lr * 86400, label, date, lead)``. ``lr`` and ``label`` are
            passed through ``self.transform`` when it is set. ``date`` is a
            tensor of the int ``YYYYmmdd`` BARRA date and ``lead`` a tensor
            of ``time_leading``.
        '''
        access_filename_pr, access_date, date_for_BARRA, time_leading = self.filename_list[
            idx]

        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        lr_raw = dpt.map_aust(data_low, domain=self.args.domain, xrarray=False)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)
        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))

        def aux_filename(var):
            # The original used an undefined `en`. This assumes pr files live
            # at <dir>/pr/daily/<en>/da_pr_<date>_<en>.nc (the layout used for
            # the auxiliary fields) and takes <en> from the parent directory.
            # NOTE(review): confirm against the file indexer.
            en = os.path.basename(os.path.dirname(access_filename_pr))
            return (self.args.file_ACCESS_dir + var + "/daily/" + en + "/" +
                    "da_" + var + "_" + access_date.strftime("%Y%m%d") + "_" +
                    en + ".nc")

        def read_aux_2d(var):
            # var_name is passed explicitly: without it the reader falls back
            # to its default variable (the sibling loader passes it too).
            data = dpt.read_access_data(aux_filename(var),
                                        var_name=var,
                                        idx=time_leading)
            return dpt.interp_tensor_2d(dpt.map_aust(data, xrarray=False),
                                        (78, 100))

        if self.args.zg:
            data_zg = dpt.read_access_zg(aux_filename("zg"), idx=time_leading)
            lr_zg = dpt.interp_tensor_3d(dpt.map_aust(data_zg, xrarray=False),
                                         (78, 100))
        if self.args.psl:
            lr_psl = read_aux_2d("psl")
        if self.args.tasmax:
            lr_tasmax = read_aux_2d("tasmax")
        if self.args.tasmin:  # was mistakenly gated on args.tasmax
            lr_tasmin = read_aux_2d("tasmin")

        # 86400 = seconds per day (presumably rate -> daily total; confirm).
        lr_daily = lr * 86400
        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        # TODO: auxiliary channels are loaded but not yet returned.
        if self.transform:
            return (self.transform(lr_daily), self.transform(label),
                    date_tensor, lead_tensor)
        return lr_daily, label, date_tensor, lead_tensor