def main():
    """Train the modified RCAN network to downscale ACCESS-S1 precipitation
    onto the BARRA grid.

    Relies on module-level globals defined elsewhere in this file: ``args``
    (parsed command-line options), ``write_log``, the dataset class
    ``ACCESS_BARRA_v4`` and the ``utility``/``model``/``my_model`` modules.
    Side effects: reads the dataset from disk, writes checkpoints under
    ``./model/save/<train_name>/`` and logs progress via ``write_log``.
    """
    # Hindcast date range used to build the training dataset.
    init_date = date(1970, 1, 1)
    start_date = date(1990, 1, 2)
    end_date = date(1990, 12, 25)
    # end_date=date(2012,12,25)  # if 929 is true we should subtract 1 day

    # NOTE: renamed from `sys` so the local does not shadow the `sys` module.
    os_name = platform.system()
    if os_name == "Windows":
        # Local (development) data layout.
        init_date = date(1970, 1, 1)
        start_date = date(1990, 1, 2)
        end_date = date(1990, 12, 15)  # if 929 is true we should subtract 1 day
        args.file_ACCESS_dir = "H:/climate/access-s1/"
        args.file_BARRA_dir = "D:/dataset/accum_prcp/"
        args.file_DEM_dir = "../DEM/"
    else:
        # Cluster (NCI /g/data) data layout.
        args.file_ACCESS_dir_pr = "/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
        args.file_ACCESS_dir = "/g/data/ub7/access-s1/hc/raw_model/atmos/"
        args.file_BARRA_dir = "/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    # One input channel per enabled predictor variable.
    args.channels = sum(
        1 for enabled in (args.pr, args.zg, args.psl,
                          args.tasmax, args.tasmin, args.dem) if enabled
    )

    # Mean ACCESS precipitation rate scaled by 86400 (seconds per day),
    # i.e. converted to a per-day amount.
    access_rgb_mean = 2.9067910245780248e-05 * 86400
    # Rolling checkpoint path; this is also the file --resume loads, so all
    # periodic saves below write to exactly this path.
    pre_train_path = ("./model/save/" + args.train_name
                      + "/last_" + str(args.channels) + ".pth")
    leading_time = 217  # total lead times available; only the first is used
    args.leading_time_we_use = 1
    args.ensemble = 1

    print(access_rgb_mean)

    print("training statistics:")
    print("  ------------------------------")
    print("  trainning name  |  %s" % args.train_name)
    print("  ------------------------------")
    print("  num of channels | %5d" % args.channels)
    print("  ------------------------------")
    print("  num of threads  | %5d" % args.n_threads)
    print("  ------------------------------")
    print("  batch_size     | %5d" % args.batch_size)
    print("  ------------------------------")
    print("  using cpu only? | %5d" % args.cpu)

    ############################################################################################

    train_transforms = transforms.Compose([
        transforms.ToTensor()
    ])

    data_set = ACCESS_BARRA_v4(start_date, end_date,
                               transform=train_transforms, args=args)
    # 80/20 random train/validation split.
    n_train = int(len(data_set) * 0.8)
    train_data, val_data = random_split(data_set,
                                        [n_train, len(data_set) - n_train])

    print("Dataset statistics:")
    print("  ------------------------------")
    print("  total | %5d" % len(data_set))
    print("  ------------------------------")
    print("  train | %5d" % len(train_data))
    print("  ------------------------------")
    print("  val   | %5d" % len(val_data))

    ################################################################### set up the dataloaders
    train_dataloders = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.n_threads)
    val_dataloders = DataLoader(val_data,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.n_threads)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def prepare(tensors, volatile=False):
        """Cast each tensor to the precision requested by args and move it
        to the compute device."""
        def _prepare(tensor):
            if args.precision == 'half':
                tensor = tensor.half()
            if args.precision == 'single':
                tensor = tensor.float()
            return tensor.to(device)
        return [_prepare(t) for t in tensors]

    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)
    my_net = my_model.Modify_RCAN(net, args, checkpoint)

    args.lr = 0.001
    criterion = nn.L1Loss()
    optimizer_my = optim.SGD(my_net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)

    if args.resume == 1:
        print("continue last train")
        model_checkpoint = torch.load(pre_train_path, map_location=device)
    else:
        print("restart train")
        model_checkpoint = torch.load(
            "./model/save/" + args.train_name + "/first_"
            + str(args.channels) + ".pth",
            map_location=device)

    my_net.load_state_dict(model_checkpoint['model'])
    optimizer_my.load_state_dict(model_checkpoint['optimizer'])
    epoch = model_checkpoint['epoch']

    if torch.cuda.device_count() > 1:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        my_net = nn.DataParallel(my_net)
    else:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")

    my_net.to(device)

    ########################################################################## training
    # BUGFIX: the original had four near-identical copies of this loop chosen
    # by `if args.channels == 1/2/3 ... else`.  Because the branches used
    # `if` instead of `elif`, the `else` loop also ran for channels 1-3, the
    # channels==1 validation crashed on an undefined `dem`, and channels==3
    # passed an unsupported `tasmax=` kwarg to nn.L1Loss.  The dataset yields
    # (<input tensors...>, hr, date, leading), so one generic loop below
    # handles every channel count: everything before the last two metadata
    # entries is model input, with `hr` as the final tensor.
    write_log("start")
    best_loss = np.inf  # lowest training loss seen so far (was "max_error")
    save_dir = "./model/save/" + args.train_name + "/"

    for e in range(args.epochs):
        # ---- training pass ----
        my_net.train()
        loss = 0
        start = time.time()
        for batch, sample in enumerate(train_dataloders):
            write_log("Train for batch %d,data loading time cost %f s"
                      % (batch, start - time.time()))
            start = time.time()
            # Last two entries of `sample` are metadata (date, leading) and
            # stay on the CPU; everything else goes through prepare().
            *inputs, hr = prepare(list(sample[:-2]))

            optimizer_my.zero_grad()
            with torch.set_grad_enabled(True):
                sr = my_net(*inputs)
                running_loss = criterion(sr, hr)
                running_loss.backward()
                optimizer_my.step()
            loss += running_loss
            if batch % 10 == 0:
                # Rolling checkpoint, written where --resume expects to find
                # it (the original wrote to a hard-coded "temp01" path that
                # the resume logic never read back).
                state = {'model': my_net.state_dict(),
                         'optimizer': optimizer_my.state_dict(),
                         'epoch': e}
                torch.save(state, pre_train_path)
            write_log("Train done,train time cost %f s,loss: %f"
                      % (start - time.time(), running_loss.item()))
            start = time.time()

        # ---- validation pass ----
        my_net.eval()
        start = time.time()
        with torch.no_grad():
            eval_psnr = 0
            eval_ssim = 0
            for batch, sample in enumerate(val_dataloders):
                *inputs, hr = prepare(list(sample[:-2]))
                sr = my_net(*inputs)
                val_loss = criterion(sr, hr)  # loss of the *last* val batch
                # Per-image PSNR/SSIM against the high-resolution target.
                for ssr, hhr in zip(sr, hr):
                    data_range = (hhr[0].cpu().max()
                                  - hhr[0].cpu().min()).item()
                    eval_psnr += compare_psnr(ssr[0].cpu().numpy(),
                                              hhr[0].cpu().numpy(),
                                              data_range=data_range)
                    eval_ssim += compare_ssim(ssr[0].cpu().numpy(),
                                              hhr[0].cpu().numpy(),
                                              data_range=data_range)

        write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "
                  % (e,
                     time.time() - start,
                     optimizer_my.state_dict()['param_groups'][0]['lr'],
                     loss.item() / len(train_data),
                     val_loss))

        # Snapshot whenever the last training batch's loss improved (kept
        # from the original; note this tracks a single batch, not the epoch
        # average).
        if running_loss < best_loss:
            best_loss = running_loss
            os.makedirs(save_dir, exist_ok=True)
            write_log("saving")
            state = {'model': my_net.state_dict(),
                     'optimizer': optimizer_my.state_dict(),
                     'epoch': e}
            # BUGFIX: the original created save_dir but then saved the epoch
            # snapshot into the hard-coded "temp01" directory.
            torch.save(state, save_dir + str(e) + ".pth")


# ---------------------------------------------------------------------------
# Module-level model setup, executed at import time.
# NOTE(review): this duplicates the device/checkpoint/net/optimizer/scheduler
# setup performed inside main() above and looks like a leftover fragment from
# concatenating examples — confirm whether it is actually needed.
# ---------------------------------------------------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pre-trained RCAN (BIX4) super-resolution model and wrap it in the
# project's modified variant.
checkpoint = utility.checkpoint(args)
net = model.Model(args, checkpoint)
net.load("./model/RCAN_BIX4.pt",
         pre_train="./model/RCAN_BIX4.pt",
         resume=args.resume,
         cpu=args.cpu)
my_net = my_model.Modify_RCAN(net, args, checkpoint)
args.lr = 0.001
criterion = nn.L1Loss()  # mean absolute error between SR output and target
optimizer_my = optim.SGD(my_net.parameters(), lr=args.lr, momentum=0.9)
# scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
# torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)

#     if torch.cuda.device_count() > 1:
#         write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")
#         # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
#         net = nn.DataParallel(net)
#     else:
#         write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")

# net.to(device)
# 예제 #3  ("Example #3" — separator left over from the scraped page this file
# was assembled from; commented out because bare text is a Python SyntaxError)
# 0
def main():
    pre_train_path=args.continue_train
    

    init_date=date(1970, 1, 1)
    start_date=date(1990, 1, 2)
    end_date=date(2011,12,25)
#     end_date=date(2012,12,25) #if 929 is true we should substract 1 day    
    sys = platform.system()
    args.file_ACCESS_dir="../data/"
    args.file_BARRA_dir="../data/barra_aus/"
#     if sys == "Windows":
#         init_date=date(1970, 1, 1)
#         start_date=date(1990, 1, 2)
#         end_date=date(1990,12,15) #if 929 is true we should substract 1 day   
#         args.cpu=True
# #         args.file_ACCESS_dir="E:/climate/access-s1/"
# #         args.file_BARRA_dir="C:/Users/JIA059/barra/"
#         args.file_DEM_dir="../DEM/"
#     else:
#         args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
#         args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/"
#         # training_name="temp01"
#         args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    args.channels=0
    if args.pr:
        args.channels+=1
    if args.zg:
        args.channels+=1
    if args.psl:
        args.channels+=1
    if args.tasmax:
        args.channels+=1
    if args.tasmin:
        args.channels+=1
    if args.dem:
        args.channels+=1
    leading_time=217
    args.leading_time_we_use=1
    args.ensemble=11
    pre_train_path="./model/prprpr/best.pth"
    
    ##
    def prepare( l, volatile=False):
        def _prepare(tensor):
            if args.precision == 'half': tensor = tensor.half()
            if args.precision == 'single': tensor = tensor.float()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)
#     net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True)
    if not args.prprpr:
        
        print("no prprprprprrpprprpprpprrp")
        net=my_model.Modify_RCAN(net,args,checkpoint)


    
    args.lr=0.00001
    criterion = nn.L1Loss()
#     criterion=nn.MSELoss()

    optimizer_my = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
    
    if pre_train_path!=".":
        write_log("load last train from"+pre_train_path)
        model_checkpoint = torch.load(pre_train_path,map_location=device)
        net.load_state_dict(model_checkpoint['model'])
#         net.load(pre_train_path)
        optimizer_my.load_state_dict(model_checkpoint['optimizer'])
        epoch = model_checkpoint['epoch']
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9,last_epoch=epoch)
        print(scheduler.state_dict())

#demo input precipitation data(pr)

        def add_lat_lon_data(data,domain=[112.9, 154.00, -43.7425, -9.0],xarray=True):
            "data: is the something you want to add lat and lon, with first demenstion is lat,second dimention is lon,domain is DEM domain "
            new_lon=np.linspace(domain[0],domain[1],data.shape[1])
            new_lat=np.linspace(domain[2],domain[3],data.shape[0])
            if xarray:
                return xr.DataArray(data,coords=[new_lat,new_lon],dims=["lat","lon"])
            else:
                return data,new_lat,new_lon

        demo_date=date(1990,1,25)
        idx=0
        ensamble_demo="e01"
        file="../data/"
        pr=np.expand_dims(np.repeat(np.expand_dims(dpt.read_access_data(file,ensamble_demo,demo_date,idx),axis=0),3,axis=0),axis=0)
        print(pr.shape)

        pr=prepare([torch.tensor(pr)])

        hr=net(pr[0],0).cpu().detach().numpy()
        print(np.squeeze(hr[:,1]).shape)


        title="test \n date: "+(demo_date+timedelta(idx)).strftime("%Y%m%d")
        # prec_in=dpt.read_access_data(filename,idx=idx)*86400
        hr,lat,lon=add_lat_lon_data(np.squeeze(hr[:,1]),xarray=False)
        # print(hr)
        dpt.draw_aus(hr,lat,lon,title=title,save=True,path="test")
        # print(prec_in.shape[0],prec_in.shape[1])        
        
    
    # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)
    
#     if args.resume==1:
#         print("continue last train")
#         model_checkpoint = torch.load(pre_train_path,map_location=device)
#     else:
#         print("restart train")
#         model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device)

#     my_net.load_state_dict(model_checkpoint['model'])
#     optimizer_my.load_state_dict(model_checkpoint['optimizer'])
#     epoch = model_checkpoint['epoch']
    
    if torch.cuda.device_count() > 1:
        write_log("!!!!!!!!!!!!!Let's use"+str(torch.cuda.device_count())+"GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        net = nn.DataParallel(net,range(torch.cuda.device_count()))
    else:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")

#     my_net = torch.nn.DataParallel(my_net)
    net.to(device)