def main(): # pre_train_path="./model/save/temp01/"+0+".pth" init_date=date(1970, 1, 1) start_date=date(1990, 1, 2) end_date=date(1990,12,25) # end_date=date(2012,12,25) #if 929 is true we should substract 1 day sys = platform.system() if sys == "Windows": init_date=date(1970, 1, 1) start_date=date(1990, 1, 2) end_date=date(1990,12,15) #if 929 is true we should substract 1 day args.file_ACCESS_dir="H:/climate/access-s1/" args.file_BARRA_dir="D:/dataset/accum_prcp/" # args.file_ACCESS_dir="E:/climate/access-s1/" # args.file_BARRA_dir="C:/Users/JIA059/barra/" args.file_DEM_dir="../DEM/" else: args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/" args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/" # training_name="temp01" args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/" args.channels=0 if args.pr: args.channels+=1 if args.zg: args.channels+=1 if args.psl: args.channels+=1 if args.tasmax: args.channels+=1 if args.tasmin: args.channels+=1 if args.dem: args.channels+=1 access_rgb_mean= 2.9067910245780248e-05*86400 pre_train_path="./model/save/"+args.train_name+"/last_"+str(args.channels)+".pth" leading_time=217 args.leading_time_we_use=1 args.ensemble=1 print(access_rgb_mean) print("training statistics:") print(" ------------------------------") print(" trainning name | %s"%args.train_name) print(" ------------------------------") print(" num of channels | %5d"%args.channels) print(" ------------------------------") print(" num of threads | %5d"%args.n_threads) print(" ------------------------------") print(" batch_size | %5d"%args.batch_size) print(" ------------------------------") print(" using cpu only? | %5d"%args.cpu) ############################################################################################ train_transforms = transforms.Compose([ # transforms.Resize(IMG_SIZE), # transforms.RandomResizedCrop(IMG_SIZE), # transforms.RandomHorizontalFlip(), # transforms.RandomRotation(30), transforms.ToTensor() # transforms.Normalize(IMG_MEAN, IMG_STD) ]) data_set=ACCESS_BARRA_v4(start_date,end_date,transform=train_transforms,args=args) train_data,val_data=random_split(data_set,[int(len(data_set)*0.8),len(data_set)-int(len(data_set)*0.8)]) print("Dataset statistics:") print(" ------------------------------") print(" total | %5d"%len(data_set)) print(" ------------------------------") print(" train | %5d"%len(train_data)) print(" ------------------------------") print(" val | %5d"%len(val_data)) ###################################################################################set a the dataLoader train_dataloders =DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.n_threads) val_dataloders =DataLoader(val_data, batch_size=args.batch_size, shuffle=True, num_workers=args.n_threads) ## def prepare( l, volatile=False): def _prepare(tensor): if args.precision == 'half': tensor = tensor.half() if args.precision == 'single': tensor = tensor.float() return tensor.to(device) return [_prepare(_l) for _l in l] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") checkpoint = utility.checkpoint(args) net = model.Model(args, checkpoint) # net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True) my_net=my_model.Modify_RCAN(net,args,checkpoint) # net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=args.cpu) args.lr=0.001 criterion = nn.L1Loss() optimizer_my = optim.SGD(my_net.parameters(), lr=args.lr, momentum=0.9) # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1) scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9) # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1) if args.resume==1: print("continue last train") model_checkpoint = torch.load(pre_train_path,map_location=device) else: print("restart train") model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device) my_net.load_state_dict(model_checkpoint['model']) optimizer_my.load_state_dict(model_checkpoint['optimizer']) epoch = model_checkpoint['epoch'] if torch.cuda.device_count() > 1: write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!") # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs my_net = nn.DataParallel(my_net) else: write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!") # my_net = torch.nn.DataParallel(my_net) my_net.to(device) ##########################################################################training if args.channels==1: write_log("start") max_error=np.inf for e in range(args.epochs): #train my_net.train() loss=0 start=time.time() for batch, (pr,hr,_,_) in enumerate(train_dataloders): write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time())) # start=time.time() pr,hr= prepare([pr,hr]) optimizer_my.zero_grad() with torch.set_grad_enabled(True): sr = my_net(pr) print(pr.shape) print(sr.shape) running_loss =criterion(sr, hr) running_loss.backward() optimizer_my.step() loss+=running_loss #.copy()? if batch%10==0: state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/last.pth") write_log("Train done,train time cost %f s,learning rate:%f, loss: %f"%(start-time.time(),optimizer_my.state_dict()['param_groups'][0]['lr'] ,running_loss.item() )) start=time.time() #validation my_net.eval() start=time.time() with torch.no_grad(): eval_psnr=0 eval_ssim=0 # tqdm_val = tqdm(val_dataloders, ncols=80) for batch, (pr,hr,_,_) in enumerate(val_dataloders): pr,hr = prepare([pr,hr]) sr = my_net(pr,dem) val_loss=criterion(sr, hr) for ssr,hhr in zip(sr,hr): eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( e, time.time()-start, optimizer_my.state_dict()['param_groups'][0]['lr'], loss.item()/len(train_data), val_loss )) # print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( # e, # time.time()-start, # optimizer_my.state_dict()['param_groups'][0]['lr'], # loss.item()/len(train_data), # val_loss # )) if running_loss<max_error: max_error=running_loss # torch.save(net,train_loss"_"+str(e)+".pkl") if not os.path.exists("./model/save/"+args.train_name+"/"): os.mkdir("./model/save/"+args.train_name+"/") write_log("saving") state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/"+str(e)+".pth") # torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl") if args.channels==2: write_log("start") max_error=np.inf for e in range(args.epochs): #train my_net.train() loss=0 start=time.time() for batch, (pr,dem,hr,_,_) in enumerate(train_dataloders): write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time())) start=time.time() pr,dem,hr= prepare([pr,dem,hr]) optimizer_my.zero_grad() with torch.set_grad_enabled(True): sr = my_net(pr,dem) running_loss =criterion(sr, hr) running_loss.backward() optimizer_my.step() loss+=running_loss #.copy()? if batch%10==0: state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/last_"+str(args.channels)+".pth") write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item() )) start=time.time() #validation my_net.eval() start=time.time() with torch.no_grad(): eval_psnr=0 eval_ssim=0 # tqdm_val = tqdm(val_dataloders, ncols=80) for idx_img, (pr,dem,hr,_,_) in enumerate(val_dataloders): pr,dem,hr = prepare([pr,dem,hr]) sr = my_net(pr,dem) val_loss=criterion(sr, hr) for ssr,hhr in zip(sr,hr): eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( e, time.time()-start, optimizer_my.state_dict()['param_groups'][0]['lr'], loss.item()/len(train_data), val_loss )) # print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( # e, # time.time()-start, # optimizer_my.state_dict()['param_groups'][0]['lr'], # loss.item()/len(train_data), # val_loss # )) if running_loss<max_error: max_error=running_loss # torch.save(net,train_loss"_"+str(e)+".pkl") if not os.path.exists("./model/save/"+args.train_name+"/"): os.mkdir("./model/save/"+args.train_name+"/") write_log("saving") state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/"+str(e)+".pth") # torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl") if args.channels==3: write_log("start") max_error=np.inf for e in range(args.epochs): #train my_net.train() loss=0 start=time.time() for batch, (pr,dem,tasmax,hr,_,_) in enumerate(train_dataloders): write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time())) start=time.time() pr,dem,tasmax,hr= prepare([pr,dem,tasmax,hr]) optimizer_my.zero_grad() with torch.set_grad_enabled(True): sr = my_net(pr,dem) running_loss =criterion(sr, hr,tasmax=tasmax) running_loss.backward() optimizer_my.step() loss+=running_loss #.copy()? if batch%10==0: state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/last.pth") write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item() )) start=time.time() #validation my_net.eval() start=time.time() with torch.no_grad(): eval_psnr=0 eval_ssim=0 # tqdm_val = tqdm(val_dataloders, ncols=80) for idx_img, (pr,dem,psl,zg,tasmax,tasmin, hr,_,_) in enumerate(val_dataloders): pr,dem,psl,zg,tasmax,tasmin, hr = prepare([pr,dem,psl,zg,tasmax,tasmin, hr]) sr = my_net(pr,dem,psl,zg,tasmax,tasmin) val_loss=criterion(sr, hr) for ssr,hhr in zip(sr,hr): eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( e, time.time()-start, optimizer_my.state_dict()['param_groups'][0]['lr'], loss.item()/len(train_data), val_loss )) # print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( # e, # time.time()-start, # optimizer_my.state_dict()['param_groups'][0]['lr'], # loss.item()/len(train_data), # val_loss # )) if running_loss<max_error: max_error=running_loss # torch.save(net,train_loss"_"+str(e)+".pkl") if not os.path.exists("./model/save/"+args.train_name+"/"): os.mkdir("./model/save/"+args.train_name+"/") write_log("saving") state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/"+str(e)+".pth") # torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl") else: write_log("start") max_error=np.inf for e in range(args.epochs): #train my_net.train() loss=0 start=time.time() for batch, (pr,dem,psl,zg,tasmax,tasmin, hr,_,_) in enumerate(train_dataloders): write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time())) start=time.time() pr,dem,psl,zg,tasmax,tasmin, hr = prepare([pr,dem,psl,zg,tasmax,tasmin, hr]) optimizer_my.zero_grad() with torch.set_grad_enabled(True): sr = my_net(pr,dem,psl,zg,tasmax,tasmin) running_loss =criterion(sr, hr) running_loss.backward() optimizer_my.step() loss+=running_loss #.copy()? if batch%10==0: state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/last.pth") write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item() )) start=time.time() #validation my_net.eval() start=time.time() with torch.no_grad(): eval_psnr=0 eval_ssim=0 # tqdm_val = tqdm(val_dataloders, ncols=80) for idx_img, (pr,dem,psl,zg,tasmax,tasmin, hr,_,_) in enumerate(val_dataloders): pr,dem,psl,zg,tasmax,tasmin, hr = prepare([pr,dem,psl,zg,tasmax,tasmin, hr]) sr = my_net(pr,dem,psl,zg,tasmax,tasmin) val_loss=criterion(sr, hr) for ssr,hhr in zip(sr,hr): eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() ) write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( e, time.time()-start, optimizer_my.state_dict()['param_groups'][0]['lr'], loss.item()/len(train_data), val_loss )) # print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%( # e, # time.time()-start, # optimizer_my.state_dict()['param_groups'][0]['lr'], # loss.item()/len(train_data), # val_loss # )) if running_loss<max_error: max_error=running_loss # torch.save(net,train_loss"_"+str(e)+".pkl") if not os.path.exists("./model/save/"+args.train_name+"/"): os.mkdir("./model/save/"+args.train_name+"/") write_log("saving") state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e} torch.save(state, "./model/save/temp01/"+str(e)+".pth")
if args.precision == 'half': tensor = tensor.half() if args.precision == 'single': tensor = tensor.float() return tensor.to(device) return [_prepare(_l) for _l in l] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") checkpoint = utility.checkpoint(args) net = model.Model(args, checkpoint) net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=args.cpu) my_net = my_model.Modify_RCAN(net, args, checkpoint) args.lr = 0.001 criterion = nn.L1Loss() optimizer_my = optim.SGD(my_net.parameters(), lr=args.lr, momentum=0.9) # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1) scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9) # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1) # if torch.cuda.device_count() > 1: # write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!") # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs # net = nn.DataParallel(net) # else: # write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!") # net.to(device)
def main(): pre_train_path=args.continue_train init_date=date(1970, 1, 1) start_date=date(1990, 1, 2) end_date=date(2011,12,25) # end_date=date(2012,12,25) #if 929 is true we should substract 1 day sys = platform.system() args.file_ACCESS_dir="../data/" args.file_BARRA_dir="../data/barra_aus/" # if sys == "Windows": # init_date=date(1970, 1, 1) # start_date=date(1990, 1, 2) # end_date=date(1990,12,15) #if 929 is true we should substract 1 day # args.cpu=True # # args.file_ACCESS_dir="E:/climate/access-s1/" # # args.file_BARRA_dir="C:/Users/JIA059/barra/" # args.file_DEM_dir="../DEM/" # else: # args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/" # args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/" # # training_name="temp01" # args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/" args.channels=0 if args.pr: args.channels+=1 if args.zg: args.channels+=1 if args.psl: args.channels+=1 if args.tasmax: args.channels+=1 if args.tasmin: args.channels+=1 if args.dem: args.channels+=1 leading_time=217 args.leading_time_we_use=1 args.ensemble=11 pre_train_path="./model/prprpr/best.pth" ## def prepare( l, volatile=False): def _prepare(tensor): if args.precision == 'half': tensor = tensor.half() if args.precision == 'single': tensor = tensor.float() return tensor.to(device) return [_prepare(_l) for _l in l] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") checkpoint = utility.checkpoint(args) net = model.Model(args, checkpoint) # net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True) if not args.prprpr: print("no prprprprprrpprprpprpprrp") net=my_model.Modify_RCAN(net,args,checkpoint) args.lr=0.00001 criterion = nn.L1Loss() # criterion=nn.MSELoss() optimizer_my = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9) # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1) scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9) if pre_train_path!=".": write_log("load last train from"+pre_train_path) model_checkpoint = torch.load(pre_train_path,map_location=device) net.load_state_dict(model_checkpoint['model']) # net.load(pre_train_path) optimizer_my.load_state_dict(model_checkpoint['optimizer']) epoch = model_checkpoint['epoch'] scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9,last_epoch=epoch) print(scheduler.state_dict()) #demo input precipitation data(pr) def add_lat_lon_data(data,domain=[112.9, 154.00, -43.7425, -9.0],xarray=True): "data: is the something you want to add lat and lon, with first demenstion is lat,second dimention is lon,domain is DEM domain " new_lon=np.linspace(domain[0],domain[1],data.shape[1]) new_lat=np.linspace(domain[2],domain[3],data.shape[0]) if xarray: return xr.DataArray(data,coords=[new_lat,new_lon],dims=["lat","lon"]) else: return data,new_lat,new_lon demo_date=date(1990,1,25) idx=0 ensamble_demo="e01" file="../data/" pr=np.expand_dims(np.repeat(np.expand_dims(dpt.read_access_data(file,ensamble_demo,demo_date,idx),axis=0),3,axis=0),axis=0) print(pr.shape) pr=prepare([torch.tensor(pr)]) hr=net(pr[0],0).cpu().detach().numpy() print(np.squeeze(hr[:,1]).shape) title="test \n date: "+(demo_date+timedelta(idx)).strftime("%Y%m%d") # prec_in=dpt.read_access_data(filename,idx=idx)*86400 hr,lat,lon=add_lat_lon_data(np.squeeze(hr[:,1]),xarray=False) # print(hr) dpt.draw_aus(hr,lat,lon,title=title,save=True,path="test") # print(prec_in.shape[0],prec_in.shape[1]) # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1) # if args.resume==1: # print("continue last train") # model_checkpoint = torch.load(pre_train_path,map_location=device) # else: # print("restart train") # model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device) # my_net.load_state_dict(model_checkpoint['model']) # optimizer_my.load_state_dict(model_checkpoint['optimizer']) # epoch = model_checkpoint['epoch'] if torch.cuda.device_count() > 1: write_log("!!!!!!!!!!!!!Let's use"+str(torch.cuda.device_count())+"GPUs!") # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs net = nn.DataParallel(net,range(torch.cuda.device_count())) else: write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!") # my_net = torch.nn.DataParallel(my_net) net.to(device)