예제 #1
0
    def __init__(self,
                 start_date=date(1990, 1, 1),
                 end_date=date(1990, 12, 31),
                 regin="AUS",
                 transform=None,
                 train=True,
                 args=None):
        """Index ACCESS-S1 forecast files and BARRA-R targets for a date range.

        Args:
            start_date: First date of the range (printed and passed to
                ``self.date_range``).
            end_date: Last date of the range.
            regin: Region tag, stored as ``self.regin``.
            transform: Optional callable, stored for use by ``__getitem__``.
            train: Accepted for interface compatibility; unused here.
            args: Options namespace. Read here: ``file_BARRA_dir``,
                ``file_ACCESS_dir``, ``scale``, ``leading_time_we_use``,
                ``ensemble``, ``domain``, ``dem`` and, when ``dem`` is
                truthy, ``file_DEM_dir``.

        Raises:
            ValueError: If ``args`` is None or the ACCESS ``pr/daily/``
                listing is empty.
        """
        if args is None:
            raise ValueError("args is required to locate the input data")

        print("=> BARRA_R & ACCESS_S1 loading")
        print("=> from " + start_date.strftime("%Y/%m/%d") + " to " +
              end_date.strftime("%Y/%m/%d") + "")
        self.file_BARRA_dir = args.file_BARRA_dir
        self.file_ACCESS_dir = args.file_ACCESS_dir
        self.args = args

        self.transform = transform
        self.start_date = start_date
        self.end_date = end_date

        self.scale = args.scale[0]
        self.regin = regin
        # NOTE(review): 217 is presumably the number of lead days in one
        # ACCESS-S1 file; the origin of the constant is not shown here.
        self.leading_time = 217
        self.leading_time_we_use = args.leading_time_we_use

        self.ensemble_access = [
            'e01', 'e02', 'e03', 'e04', 'e05', 'e06', 'e07', 'e08', 'e09',
            'e10', 'e11'
        ]
        # Index (rather than slice) so that asking for more members than
        # exist still raises IndexError instead of silently truncating.
        self.ensemble = [self.ensemble_access[i] for i in range(args.ensemble)]

        self.dates = self.date_range(start_date, end_date)

        pr_dir = args.file_ACCESS_dir + "pr/daily/"
        # Check before listing so the diagnostic appears even if the
        # listing itself fails on a missing/unreadable directory.
        if not os.path.exists(pr_dir):
            print(pr_dir)
            print("no file or no permission")
        self.filename_list = self.get_filename_with_time_order(pr_dir)
        if not self.filename_list:
            raise ValueError("no ACCESS files found under " + pr_dir)

        # Only the BARRA date of the first entry is needed to probe the grid.
        _, _, _, date_for_BARRA, _ = self.filename_list[0]
        # Check the configured BARRA directory (previously a hard-coded file).
        if not os.path.exists(self.file_BARRA_dir):
            print(self.file_BARRA_dir)
            print("no file or no permission!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        data_exp = dpt.map_aust(data_high, domain=args.domain, xrarray=True)
        self.lat = data_exp["lat"]
        self.lon = data_exp["lon"]
        self.shape = (79, 94)
        # Always define dem_data so the attribute exists even without DEM.
        self.dem_data = None
        if self.args.dem:
            data_dem = dpt.add_lat_lon(
                dpt.read_dem(args.file_DEM_dir + "dem-9s1.tif"))
            self.dem_data = dpt.interp_tensor_2d(
                dpt.map_aust_old(data_dem, xrarray=False), self.shape)
예제 #2
0
    def __getitem__(self, idx):
        """Load one (ACCESS forecast, BARRA target) sample.

        ``self.filename_list[idx]`` is unpacked as
        ``(en, access_date, barra_date, time_leading)``.

        Returns:
            With ``self.transform``: ``(transform(lr), transform(dem_data),
            transform(label), date, lead)``. Without it:
            ``(lr * 86400, label, date, lead)``. ``date`` is a tensor holding
            the BARRA date as an int ``YYYYMMDD``; ``lead`` wraps
            ``time_leading``.
        """
        en, access_date, barra_date, time_leading = self.filename_list[idx]

        lr = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                  time_leading, "pr")
        lr = np.expand_dims(lr, axis=2)  # add a trailing channel axis

        label = dpt.read_barra_data_fc(self.file_BARRA_dir, barra_date)

        # NOTE(review): the auxiliary fields below are read when their flag
        # is set but are not yet part of the returned sample.
        if self.args.zg:
            lr_zg = np.expand_dims(dpt.read_access_data(
                self.file_ACCESS_dir, en, access_date, time_leading, "zg"),
                                   axis=2)

        if self.args.psl:
            lr_psl = dpt.read_access_data(self.file_ACCESS_dir, en,
                                          access_date, time_leading, "psl")

        if self.args.tasmax:
            lr_tasmax = np.expand_dims(dpt.read_access_data(
                self.file_ACCESS_dir, en, access_date, time_leading, "tasmax"),
                                       axis=2)

        if self.args.tasmin:
            lr_tasmin = dpt.read_access_data(self.file_ACCESS_dir, en,
                                             access_date, time_leading,
                                             "tasmin")

        # barra_date is the date unpacked above (date_for_BARRA was undefined).
        date_tensor = torch.tensor(int(barra_date.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        if self.transform:  # TODO: the channel layout still needs sorting out
            return (self.transform(lr), self.transform(self.dem_data),
                    self.transform(label), date_tensor, lead_tensor)
        return lr * 86400, label, date_tensor, lead_tensor
예제 #3
0
    def __getitem__(self, idx):
        """Load one ACCESS precipitation field and its BARRA target.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, date_for_BARRA, time_leading)``. Both fields
        are cropped with ``self.args.domain``; the ACCESS field is then
        interpolated to ``(78, 100)``.

        Returns:
            Without a transform: ``(lr * 86400, label, date, lead)``.
            With a transform and ``args.channels == 2``:
            ``(transform(lr * 86400), transform(dem_data), transform(label),
            date, lead)``. ``date`` is the BARRA date as an int ``YYYYMMDD``
            tensor, ``lead`` the lead-time tensor.

        Raises:
            NotImplementedError: For ``channels`` 5 or 27 with a transform --
                this loader never reads the psl/zg/tasmax/tasmin fields those
                layouts require.
            ValueError: For any other ``channels`` value with a transform.
        """
        access_filename_pr, date_for_BARRA, time_leading = self.filename_list[
            idx]

        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        lr_raw = dpt.map_aust(data_low, domain=self.args.domain,
                              xrarray=False)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)
        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))

        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        if not self.transform:
            return lr * 86400, label, date_tensor, lead_tensor

        channels = self.args.channels
        if channels == 2:
            return (self.transform(lr * 86400), self.transform(self.dem_data),
                    self.transform(label), date_tensor, lead_tensor)
        if channels in (5, 27):
            raise NotImplementedError(
                "channels=%r needs auxiliary ACCESS fields that this loader "
                "does not read" % (channels,))
        raise ValueError("unsupported channel count: %r" % (channels,))
예제 #4
0
    def __init__(self,
                 start_date=date(1990, 1, 1),
                 end_date=date(1990, 12, 31),
                 regin="AUS",
                 transform=None,
                 train=True,
                 args=None):
        """Index ACCESS forecast files and probe the BARRA grid.

        Args:
            start_date: First date of the range (passed to
                ``self.date_range``).
            end_date: Last date of the range.
            regin: Region tag, stored as ``self.regin``.
            transform: Optional callable, stored for ``__getitem__``.
            train: Accepted for interface compatibility; unused here.
            args: Options namespace. Read here: ``file_BARRA_dir``,
                ``file_ACCESS_dir``, ``scale`` and ``domain``.

        Raises:
            ValueError: If ``args`` is None or the ACCESS ``pr/daily/``
                listing is empty.
        """
        if args is None:
            # Raise rather than exit the interpreter (exit(0) also reported
            # success), so callers can see and handle the misconfiguration.
            raise ValueError("args is required to locate the input data")
        self.file_BARRA_dir = args.file_BARRA_dir
        self.file_ACCESS_dir = args.file_ACCESS_dir
        # Stored so that code reading self.args later finds the options.
        self.args = args

        self.transform = transform
        self.start_date = start_date
        self.end_date = end_date

        self.scale = args.scale[0]
        self.regin = regin
        self.leading_time = 217
        self.leading_time_we_use = 7

        self.dates = self.date_range(start_date, end_date)

        pr_dir = args.file_ACCESS_dir + "pr/daily/"
        self.filename_list = self.get_filename_with_time_order(pr_dir)
        if not self.filename_list:
            raise ValueError("no ACCESS files found under " + pr_dir)
        # Only the BARRA date of the first entry is needed to probe the grid.
        _, date_for_BARRA, _ = self.filename_list[0]

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        data_exp = dpt.map_aust(data_high, domain=args.domain, xrarray=True)
        self.lat = data_exp["lat"]
        self.lon = data_exp["lon"]
    def __getitem__(self, idx):
        """Load one sample as transformed PIL images plus metadata tensors.

        ``self.filename_list[idx]`` is unpacked as
        ``(en, access_date, barra_date, time_leading)``.

        Returns:
            ``(lr, hr, zg, t, member, date, lead)``. ``lr``, ``zg`` and ``t``
            go through ``self.lr_transform``, ``hr`` through
            ``self.hr_transform``. ``member`` is ``int(en[1:])``, ``date`` the
            BARRA date as int ``YYYYMMDD``, ``lead`` the lead time (all
            tensors).
        """
        en, access_date, barra_date, time_leading = self.filename_list[idx]

        lr = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                  time_leading, "pr")
        label = dpt.read_barra_data_fc(self.file_BARRA_dir, barra_date)

        # Auxiliary fields are passed through self.to01 (presumably a
        # rescale to [0, 1]) and then scaled by the precipitation maximum.
        lr_max = np.max(lr)

        lr_zg = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                     time_leading, "zg")
        lr_zg = self.to01(lr_zg) * lr_max

        # NOTE(review): despite its name, lr_t is read from "psl"; an earlier
        # tasmax/tasmin variant suggests a temperature field may have been
        # intended -- confirm.
        lr_t = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                    time_leading, "psl")
        lr_t = self.to01(lr_t) * lr_max

        return (self.lr_transform(Image.fromarray(lr)),
                self.hr_transform(Image.fromarray(label)),
                self.lr_transform(Image.fromarray(lr_zg)),
                self.lr_transform(Image.fromarray(lr_t)),
                torch.tensor(int(en[1:])),
                torch.tensor(int(barra_date.strftime("%Y%m%d"))),
                torch.tensor(time_leading))
예제 #6
0
    def __getitem__(self, idx):
        """Load one ACCESS/BARRA sample plus optional auxiliary predictors.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, en, access_date, date_for_BARRA, time_leading)``.
        Precipitation is multiplied by 86400 exactly once, when it is loaded.

        Returns:
            With ``self.transform`` and ``args.channels == 27``:
            ``(lr, dem, psl, zg, tasmax, tasmin, label, date, lead)``.
            With ``channels == 2``: ``(lr, dem, label, date, lead)``.
            All fields in both cases pass through ``self.transform``.
            Without a transform: ``(lr, label, date, lead)``.
            ``date`` is the BARRA date as an int ``YYYYMMDD`` tensor and
            ``lead`` the lead-time tensor.

        Raises:
            ValueError: If a transform is set and ``args.channels`` is
                neither 2 nor 27.
        """
        (access_filename_pr, en, access_date, date_for_BARRA,
         time_leading) = self.filename_list[idx]

        # Window applied to 2-D ACCESS fields before interpolation.
        rows, cols = slice(82, 144), slice(134, 188)

        def access_path(var):
            """Build <ACCESS dir><var>/daily/<en>/da_<var>_<YYYYMMDD>_<en>.nc."""
            return (self.args.file_ACCESS_dir + var + "/daily/" + en + "/" +
                    "da_" + var + "_" + access_date.strftime("%Y%m%d") +
                    "_" + en + ".nc")

        def read_2d(var):
            """Read, crop and interpolate one 2-D auxiliary ACCESS field."""
            data = dpt.read_access_data(access_path(var), var_name=var,
                                        idx=time_leading).data[rows, cols]
            return dpt.interp_tensor_2d(data, self.shape)

        lr = dpt.read_access_data(
            access_filename_pr, idx=time_leading).data[rows, cols] * 86400
        lr = np.expand_dims(dpt.interp_tensor_2d(lr, self.shape), axis=2)
        # Convert the values. Assigning to ``lr.dtype`` would reinterpret
        # the raw bytes instead, giving garbage for float64 input.
        lr = lr.astype(np.float32)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)

        if self.args.zg:
            # NOTE(review): zg uses a crop shifted by one (83:145, 135:188)
            # relative to the other fields; kept as-is -- confirm intent.
            lr_zg = dpt.read_access_zg(access_path("zg"),
                                       idx=time_leading).data[83:145, 135:188]
            lr_zg = dpt.interp_tensor_3d(lr_zg, self.shape)

        if self.args.psl:
            lr_psl = read_2d("psl")

        if self.args.tasmax:
            lr_tasmax = read_2d("tasmax")

        if self.args.tasmin:
            lr_tasmin = read_2d("tasmin")

        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        # lr is already in the ×86400 units here; do not scale it again.
        if not self.transform:
            return lr, label, date_tensor, lead_tensor

        # TODO: the channel layout still needs to be tidied up.
        if self.args.channels == 27:
            return (self.transform(lr), self.transform(self.dem_data),
                    self.transform(lr_psl), self.transform(lr_zg),
                    self.transform(lr_tasmax), self.transform(lr_tasmin),
                    self.transform(label), date_tensor, lead_tensor)
        if self.args.channels == 2:
            return (self.transform(lr), self.transform(self.dem_data),
                    self.transform(label), date_tensor, lead_tensor)
        raise ValueError("unsupported channel count: %r" %
                         (self.args.channels,))
예제 #7
0
    def __getitem__(self, idx):
        """Load one ACCESS precipitation field and its BARRA target.

        ``self.filename_list[idx]`` is unpacked as
        ``(access_filename_pr, access_date, date_for_BARRA, time_leading)``.
        All fields are cropped with ``self.args.domain``. ACCESS fields are
        interpolated to ``(78, 100)``.

        Returns:
            ``(lr * 86400, label, date, lead)``. When ``self.transform`` is
            set, the first two entries are passed through it. ``date`` is
            the BARRA date as an int ``YYYYMMDD`` tensor and ``lead`` the
            lead-time tensor.
        """
        import os

        access_filename_pr, access_date, date_for_BARRA, time_leading = \
            self.filename_list[idx]
        domain = self.args.domain
        target_shape = (78, 100)

        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        lr_raw = dpt.map_aust(data_low, domain=domain, xrarray=False)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=domain, xrarray=False)
        lr = dpt.interp_tensor_2d(lr_raw, target_shape)

        # The ensemble member is not part of this filename_list entry.
        # Assumes the pr file sits in a per-member directory
        # (".../pr/daily/<en>/..."), like the auxiliary paths below -- TODO
        # confirm.
        en = os.path.basename(os.path.dirname(access_filename_pr))

        def read_aux(var, reader, **kwargs):
            """Read one auxiliary ACCESS variable and crop it to ``domain``."""
            path = (self.args.file_ACCESS_dir + var + "/daily/" + en + "/" +
                    "da_" + var + "_" + access_date.strftime("%Y%m%d") +
                    "_" + en + ".nc")
            data = reader(path, idx=time_leading, **kwargs)
            return dpt.map_aust(data, domain=domain, xrarray=False)

        # NOTE(review): the auxiliary fields are computed but not yet part
        # of the returned sample. var_name mirrors the sibling loader that
        # reads these files through the same read_access_data API.
        if self.args.zg:
            lr_zg = dpt.interp_tensor_3d(read_aux("zg", dpt.read_access_zg),
                                         target_shape)

        if self.args.psl:
            lr_psl = dpt.interp_tensor_2d(
                read_aux("psl", dpt.read_access_data, var_name="psl"),
                target_shape)

        if self.args.tasmax:
            lr_tasmax = dpt.interp_tensor_2d(
                read_aux("tasmax", dpt.read_access_data, var_name="tasmax"),
                target_shape)

        # Guarded by tasmin (this block was previously keyed on tasmax).
        if self.args.tasmin:
            lr_tasmin = dpt.interp_tensor_2d(
                read_aux("tasmin", dpt.read_access_data, var_name="tasmin"),
                target_shape)

        date_tensor = torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        lead_tensor = torch.tensor(time_leading)

        if self.transform:  # TODO: the channel layout still needs sorting out
            return (self.transform(lr * 86400), self.transform(label),
                    date_tensor, lead_tensor)
        return lr * 86400, label, date_tensor, lead_tensor