Example #1
0
    def __getitem__(self, idx):
        """Load one (low-res, high-res) training pair by index.

        When ``self.transform`` is set, returns
        ``(lr, dem, label, date_tensor, lead_tensor)``; otherwise returns
        ``(lr * 86400, label, date_tensor, lead_tensor)`` where the date
        tensor holds the BARRA date as an int ``YYYYMMDD``.
        """
        t = time.time()

        # Unpack the sample descriptor for this index.
        en, access_date, barra_date, time_leading = self.filename_list[idx]

        # Low-resolution ACCESS precipitation, with a trailing channel axis.
        lr = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                  time_leading, "pr")
        lr = np.expand_dims(lr, axis=2)

        # High-resolution BARRA target field.
        label = dpt.read_barra_data_fc(self.file_BARRA_dir, barra_date)

        # NOTE(review): the optional variables below are loaded but never
        # returned — confirm whether they should be concatenated into `lr`.
        if self.args.zg:
            lr_zg = np.expand_dims(dpt.read_access_data(
                self.file_ACCESS_dir, en, access_date, time_leading, "zg"),
                                   axis=2)

        if self.args.psl:
            lr_psl = dpt.read_access_data(self.file_ACCESS_dir, en,
                                          access_date, time_leading, "psl")

        if self.args.tasmax:
            lr_tasmax = np.expand_dims(dpt.read_access_data(
                self.file_ACCESS_dir, en, access_date, time_leading, "tasmax"),
                                       axis=2)

        if self.args.tasmin:
            lr_tasmin = dpt.read_access_data(self.file_ACCESS_dir, en,
                                             access_date, time_leading,
                                             "tasmin")

        if self.transform:  # channel count still needs tidying!!

            return self.transform(lr), self.transform(
                self.dem_data), self.transform(label), torch.tensor(
                    int(barra_date.strftime("%Y%m%d"))), torch.tensor(
                        time_leading)
        else:
            # BUG FIX: this branch referenced the undefined name
            # `date_for_BARRA`; the date unpacked above is `barra_date`.
            return lr * 86400, label, torch.tensor(
                int(barra_date.strftime("%Y%m%d"))), torch.tensor(
                    time_leading)
    def __getitem__(self, idx):
        """Fetch one sample: low-res input channels as PIL images, the
        high-res BARRA label, and metadata tensors (ensemble id, date,
        lead time)."""
        t = time.time()

        # Descriptor for this sample.
        en, access_date, barra_date, time_leading = self.filename_list[idx]

        # Low-res ACCESS precipitation and the BARRA ground truth.
        lr = dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                  time_leading, "pr")
        label = dpt.read_barra_data_fc(self.file_BARRA_dir, barra_date)

        # Geopotential height, normalised then rescaled to the
        # precipitation peak so all channels share a value range.
        peak = np.max(lr)
        lr_zg = self.to01(
            dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                 time_leading, "zg")) * peak

        if self.args.psl:
            lr_psl = dpt.read_access_data(self.file_ACCESS_dir, en,
                                          access_date, time_leading, "psl")

        # Sea-level pressure channel, rescaled the same way.
        lr_t = self.to01(
            dpt.read_access_data(self.file_ACCESS_dir, en, access_date,
                                 time_leading, "psl")) * peak

        ensemble_id = torch.tensor(int(en[1:]))
        date_id = torch.tensor(int(barra_date.strftime("%Y%m%d")))
        return (self.lr_transform(Image.fromarray(lr)),
                self.hr_transform(Image.fromarray(label)),
                self.lr_transform(Image.fromarray(lr_zg)),
                self.lr_transform(Image.fromarray(lr_t)),
                ensemble_id, date_id, torch.tensor(time_leading))
Example #3
0
    def __getitem__(self, idx):
        '''
        Load one training sample by index: low-res ACCESS precipitation
        (scaled to mm/day via * 86400), the high-res BARRA label, the BARRA
        date as an int tensor (YYYYMMDD), and the lead time.
        '''
        # Unpack sample descriptor: ACCESS pr filename, BARRA date, lead time.
        access_filename_pr, date_for_BARRA, time_leading = self.filename_list[
            idx]
        #         print(type(date_for_BARRA))
        #         low_filename,high_filename,time_leading=self.filename_list[idx]

        # NOTE(review): `args` here is a module-level name, unlike the
        # `self.args` used in sibling methods — confirm this is intentional.
        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        lr_raw = dpt.map_aust(data_low, domain=args.domain, xrarray=False)

        #         domain = [train_data.lon.data.min(), train_data.lon.data.max(), train_data.lat.data.min(), train_data.lat.data.max()]
        #         print(domain)

        # High-res BARRA target cropped to the Australian domain.
        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=args.domain,
                             xrarray=False)  #,domain=domain)
        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))
        if self.transform:  # channel count still needs tidying!!
            # NOTE(review): `lr_psl`, `lr_zg`, `lr_tasmax`, `lr_tasmin` are
            # never defined in this method — the channels==27 and channels==5
            # branches raise NameError if taken. The auxiliary-variable loads
            # present in sibling methods appear to be missing here.
            if self.args.channels == 27:
                return self.transform(lr * 86400), self.transform(
                    self.dem_data), self.transform(lr_psl), self.transform(
                        lr_zg), self.transform(lr_tasmax), self.transform(
                            lr_tasmin), self.transform(label), torch.tensor(
                                int(date_for_BARRA.strftime(
                                    "%Y%m%d"))), torch.tensor(time_leading)
            elif self.args.channels == 5:
                return self.transform(lr * 86400), self.transform(
                    self.dem_data), self.transform(lr_psl), self.transform(
                        lr_tasmax), self.transform(lr_tasmin), self.transform(
                            label), torch.tensor(
                                int(date_for_BARRA.strftime(
                                    "%Y%m%d"))), torch.tensor(time_leading)
            if self.args.channels == 2:
                return self.transform(lr * 86400), self.transform(
                    self.dem_data), self.transform(label), torch.tensor(
                        int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                            time_leading)

        else:
            return lr * 86400, label, torch.tensor(
                int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                    time_leading)
Example #4
0
    def __getitem__(self, idx):
        '''
        Load one low-resolution ACCESS sample by index, interpolated to
        (78, 100) and scaled to mm/day (* 86400).
        '''
        # Unpack sample descriptor: ACCESS filename, BARRA date, lead time.
        access_filename, date_for_BARRA, time_leading = self.filename_list[idx]
        data_low = dpt.read_access_data(access_filename, idx=time_leading)
        # NOTE(review): module-level `args` is used here, not `self.args` as
        # in sibling methods — confirm intentional.
        lr_raw = dpt.map_aust(data_low, domain=args.domain, xrarray=False)

        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))

        # NOTE(review): np.expand_dims with axis=3 on a 2-D array raises
        # AxisError — confirm `lr` really has 3 dimensions here (axis=2
        # would add a trailing channel axis to a 2-D field).
        if self.transform:
            return self.transform(np.expand_dims(lr, axis=3) * 86400)
        else:
            # NOTE(review): `label` is never defined in this method — this
            # branch raises NameError. The BARRA target load present in
            # sibling methods appears to be missing.
            return np.expand_dims(lr, axis=3) * 86400, np.expand_dims(
                label,
                axis=3), torch.tensor(int(date_for_BARRA.strftime(
                    "%Y%m%d"))), torch.tensor(time_leading)
Example #5
0
    def __getitem__(self, idx):
        '''
        Load one training sample: low-res ACCESS inputs (precipitation plus
        optional zg/psl/tasmax/tasmin channels) and the high-res BARRA label.

        Returns a tuple whose layout depends on self.args.channels and
        self.transform; the date is returned as an int tensor (YYYYMMDD).
        '''
        t = time.time()

        # Unpack the sample descriptor.
        access_filename_pr, en, access_date, date_for_BARRA, time_leading = self.filename_list[
            idx]
        #         print(type(date_for_BARRA))
        #         low_filename,high_filename,time_leading=self.filename_list[idx]

        # ACCESS precipitation over the fixed Australian crop, converted to
        # mm/day (* 86400), interpolated to self.shape, with a channel axis.
        lr = dpt.read_access_data(
            access_filename_pr, idx=time_leading).data[82:144, 134:188] * 86400
        #         lr=dpt.map_aust(lr,domain=self.args.domain,xrarray=False)
        lr = np.expand_dims(dpt.interp_tensor_2d(lr, self.shape), axis=2)
        # BUG FIX: the original assigned `lr.dtype = "float32"`, which
        # reinterprets the raw buffer in place (garbling values and doubling
        # the last axis when the data is float64). astype converts values.
        lr = lr.astype("float32")

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)  #,domain=domain)

        if self.args.zg:
            # NOTE(review): the zg crop ([83:145, 135:188]) is offset by one
            # from the other variables' crop ([82:144, 134:188]) — confirm
            # this is intentional.
            access_filename_zg = self.args.file_ACCESS_dir + "zg/daily/" + en + "/" + "da_zg_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            lr_zg = dpt.read_access_zg(access_filename_zg,
                                       idx=time_leading).data[:][83:145,
                                                                 135:188]
            lr_zg = dpt.interp_tensor_3d(lr_zg, self.shape)

        if self.args.psl:
            access_filename_psl = self.args.file_ACCESS_dir + "psl/daily/" + en + "/" + "da_psl_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            lr_psl = dpt.read_access_data(access_filename_psl,
                                          var_name="psl",
                                          idx=time_leading).data[82:144,
                                                                 134:188]
            lr_psl = dpt.interp_tensor_2d(lr_psl, self.shape)

        if self.args.tasmax:
            access_filename_tasmax = self.args.file_ACCESS_dir + "tasmax/daily/" + en + "/" + "da_tasmax_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            lr_tasmax = dpt.read_access_data(access_filename_tasmax,
                                             var_name="tasmax",
                                             idx=time_leading).data[82:144,
                                                                    134:188]
            lr_tasmax = dpt.interp_tensor_2d(lr_tasmax, self.shape)

        if self.args.tasmin:
            access_filename_tasmin = self.args.file_ACCESS_dir + "tasmin/daily/" + en + "/" + "da_tasmin_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            lr_tasmin = dpt.read_access_data(access_filename_tasmin,
                                             var_name="tasmin",
                                             idx=time_leading).data[82:144,
                                                                    134:188]
            lr_tasmin = dpt.interp_tensor_2d(lr_tasmin, self.shape)

        if self.transform:  # channel count still needs tidying!!
            if self.args.channels == 27:
                return self.transform(lr), self.transform(
                    self.dem_data), self.transform(lr_psl), self.transform(
                        lr_zg), self.transform(lr_tasmax), self.transform(
                            lr_tasmin), self.transform(label), torch.tensor(
                                int(date_for_BARRA.strftime(
                                    "%Y%m%d"))), torch.tensor(time_leading)
            if self.args.channels == 2:
                # BUG FIX: `lr` is already scaled to mm/day at load time; the
                # original multiplied by 86400 a second time here (86400^2).
                return self.transform(lr), self.transform(
                    self.dem_data), self.transform(label), torch.tensor(
                        int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                            time_leading)

        else:
            # BUG FIX: same double scaling removed (`lr` is already mm/day).
            return lr, label, torch.tensor(
                int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                    time_leading)
Example #6
0
def main():
    """Configure paths and channels, build the RCAN-based model, optionally
    resume from a checkpoint (running a demo forward pass when resuming),
    and set the model up for (multi-)GPU execution.

    NOTE(review): relies on module-level names defined elsewhere in the
    file: `args`, `write_log`, `utility`, `model`, `my_model`, `dpt`.
    """
    pre_train_path=args.continue_train
    

    # Date window of the training data.
    init_date=date(1970, 1, 1)
    start_date=date(1990, 1, 2)
    end_date=date(2011,12,25)
#     end_date=date(2012,12,25) #if 929 is true we should substract 1 day    
    sys = platform.system()
    args.file_ACCESS_dir="../data/"
    args.file_BARRA_dir="../data/barra_aus/"
#     if sys == "Windows":
#         init_date=date(1970, 1, 1)
#         start_date=date(1990, 1, 2)
#         end_date=date(1990,12,15) #if 929 is true we should substract 1 day   
#         args.cpu=True
# #         args.file_ACCESS_dir="E:/climate/access-s1/"
# #         args.file_BARRA_dir="C:/Users/JIA059/barra/"
#         args.file_DEM_dir="../DEM/"
#     else:
#         args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
#         args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/"
#         # training_name="temp01"
#         args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    # One input channel per enabled variable (pr/zg/psl/tasmax/tasmin/dem).
    args.channels=0
    if args.pr:
        args.channels+=1
    if args.zg:
        args.channels+=1
    if args.psl:
        args.channels+=1
    if args.tasmax:
        args.channels+=1
    if args.tasmin:
        args.channels+=1
    if args.dem:
        args.channels+=1
    leading_time=217
    args.leading_time_we_use=1
    args.ensemble=11
    # NOTE(review): this hard-coded path overwrites the value taken from
    # args.continue_train at the top, making that assignment dead — confirm
    # whether this is leftover debugging.
    pre_train_path="./model/prprpr/best.pth"
    
    ##
    def prepare( l, volatile=False):
        # Cast each tensor in `l` to the configured precision and move it
        # to the selected device.
        def _prepare(tensor):
            if args.precision == 'half': tensor = tensor.half()
            if args.precision == 'single': tensor = tensor.float()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)
#     net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True)
    if not args.prprpr:
        
        print("no prprprprprrpprprpprpprrp")
        net=my_model.Modify_RCAN(net,args,checkpoint)


    
    args.lr=0.00001
    criterion = nn.L1Loss()
#     criterion=nn.MSELoss()

    optimizer_my = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
    
    # Resume from checkpoint when a path is configured ("." means none).
    if pre_train_path!=".":
        write_log("load last train from"+pre_train_path)
        model_checkpoint = torch.load(pre_train_path,map_location=device)
        net.load_state_dict(model_checkpoint['model'])
#         net.load(pre_train_path)
        optimizer_my.load_state_dict(model_checkpoint['optimizer'])
        epoch = model_checkpoint['epoch']
        # Rebuild the scheduler so its decay continues from the saved epoch.
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9,last_epoch=epoch)
        print(scheduler.state_dict())

# Demo: run one precipitation (pr) sample through the network and plot it.
# NOTE(review): this whole demo section is nested inside the
# `if pre_train_path!="."` branch above — confirm the indentation is
# intentional (i.e. the demo should only run when resuming).

        def add_lat_lon_data(data,domain=[112.9, 154.00, -43.7425, -9.0],xarray=True):
            """Attach lat/lon coordinates to `data` (first dimension lat,
            second lon), spanning `domain` = [lon_min, lon_max, lat_min,
            lat_max].

            NOTE(review): mutable default argument — safe only while the
            list is never mutated; confirm.
            """
            new_lon=np.linspace(domain[0],domain[1],data.shape[1])
            new_lat=np.linspace(domain[2],domain[3],data.shape[0])
            if xarray:
                return xr.DataArray(data,coords=[new_lat,new_lon],dims=["lat","lon"])
            else:
                return data,new_lat,new_lon

        demo_date=date(1990,1,25)
        idx=0
        ensamble_demo="e01"
        file="../data/"
        # Repeat the single channel 3x and add a batch axis: (1, 3, H, W).
        pr=np.expand_dims(np.repeat(np.expand_dims(dpt.read_access_data(file,ensamble_demo,demo_date,idx),axis=0),3,axis=0),axis=0)
        print(pr.shape)

        pr=prepare([torch.tensor(pr)])

        hr=net(pr[0],0).cpu().detach().numpy()
        print(np.squeeze(hr[:,1]).shape)


        title="test \n date: "+(demo_date+timedelta(idx)).strftime("%Y%m%d")
        # prec_in=dpt.read_access_data(filename,idx=idx)*86400
        hr,lat,lon=add_lat_lon_data(np.squeeze(hr[:,1]),xarray=False)
        # print(hr)
        dpt.draw_aus(hr,lat,lon,title=title,save=True,path="test")
        # print(prec_in.shape[0],prec_in.shape[1])        
        
    
    # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)
    
#     if args.resume==1:
#         print("continue last train")
#         model_checkpoint = torch.load(pre_train_path,map_location=device)
#     else:
#         print("restart train")
#         model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device)

#     my_net.load_state_dict(model_checkpoint['model'])
#     optimizer_my.load_state_dict(model_checkpoint['optimizer'])
#     epoch = model_checkpoint['epoch']
    
    # Wrap in DataParallel when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        write_log("!!!!!!!!!!!!!Let's use"+str(torch.cuda.device_count())+"GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        net = nn.DataParallel(net,range(torch.cuda.device_count()))
    else:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")

#     my_net = torch.nn.DataParallel(my_net)
    net.to(device)
Example #7
0
    def __getitem__(self, idx):
        '''
        from filename idx get id
        return lr,hr
        '''
        #         t=time.time()

        #read_data filemame[idx]
        access_filename_pr, access_date, date_for_BARRA, time_leading = self.filename_list[
            idx]
        #         print(type(date_for_BARRA))
        #         low_filename,high_filename,time_leading=self.filename_list[idx]

        data_low = dpt.read_access_data(access_filename_pr, idx=time_leading)
        lr_raw = dpt.map_aust(data_low, domain=self.args.domain, xrarray=False)

        #         domain = [train_data.lon.data.min(), train_data.lon.data.max(), train_data.lat.data.min(), train_data.lat.data.max()]
        #         print(domain)

        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir,
                                           date_for_BARRA,
                                           nine2nine=True)
        label = dpt.map_aust(data_high, domain=self.args.domain,
                             xrarray=False)  #,domain=domain)
        lr = dpt.interp_tensor_2d(lr_raw, (78, 100))

        if self.args.zg:
            access_filename_zg = self.args.file_ACCESS_dir + "zg/daily/" + en + "/" + "da_zg_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            data_zg = dpt.read_access_zg(access_filename_zg, idx=time_leading)
            data_zg_aus = map_aust(data_zg, xrarray=False)
            lr_zg = dpt.interp_tensor_3d(data_zg_aus, (78, 100))

        if self.args.psl:
            access_filename_psl = self.args.file_ACCESS_dir + "psl/daily/" + en + "/" + "da_psl_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            data_psl = dpt.read_access_data(access_filename_psl,
                                            idx=time_leading)
            data_psl_aus = map_aust(data_psl, xrarray=False)
            lr_psl = dpt.interp_tensor_2d(data_psl_aus, (78, 100))
        if self.args.tasmax:
            access_filename_tasmax = self.args.file_ACCESS_dir + "tasmax/daily/" + en + "/" + "da_tasmax_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            data_tasmax = dpt.read_access_data(access_filename_tasmax,
                                               idx=time_leading)
            data_tasmax_aus = map_aust(data_tasmax, xrarray=False)
            lr_tasmax = dpt.interp_tensor_2d(data_tasmax_aus, (78, 100))

        if self.args.tasmax:
            access_filename_tasmin = self.args.file_ACCESS_dir + "tasmin/daily/" + en + "/" + "da_tasmin_" + access_date.strftime(
                "%Y%m%d") + "_" + en + ".nc"
            data_tasmin = dpt.read_access_data(access_filename_tasmin,
                                               idx=time_leading)
            data_tasmin_aus = map_aust(data_tasmin, xrarray=False)
            lr_tasmin = dpt.interp_tensor_2d(data_tasmin_aus, (78, 100))

#         print("end loading one data,time cost %f"%(time.time()-t))

        if self.transform:  #channel 数量需要整理!!
            return self.transform(
                lr * 86400), self.transform(label), torch.tensor(
                    int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                        time_leading)
        else:
            return lr * 86400, label, torch.tensor(
                int(date_for_BARRA.strftime("%Y%m%d"))), torch.tensor(
                    time_leading)