Example #1
    def load(self, args):

        if args.model_dir != "":
            loadedparams = torch.load(args.model_dir, map_location=self.device)
            self.agent = agent.Agent(args, chkpoint=loadedparams)
        else:
            self.agent = agent.Agent(args)
        self.SRmodels = []
        self.SRoptimizers = []
        self.schedulers = []
        # one SR model, with its own optimizer and scheduler, per action
        for i in range(args.action_space):

            #CREATE THE ARCH
            if args.model == 'ESRGAN':
                model = arch.RRDBNet(3, 3, 64, 23, gc=32)
            elif args.model == 'RCAN':
                torch.manual_seed(args.seed)
                checkpoint = utility.checkpoint(args)
                if checkpoint.ok:
                    module = import_module('model.rcan')
                    model = module.make_model(args).to(self.device)
                    kwargs = {}
                else:
                    print('error loading RCAN model. QUITTING')
                    quit()

            #LOAD THE WEIGHTS
            if args.model_dir != "":
                model.load_state_dict(loadedparams["sisr" + str(i)])
                print('continuing training')
            elif args.random:
                #model.apply(init_zero)
                print('random init')
            elif args.model == 'ESRGAN':
                model.load_state_dict(torch.load(args.ESRGAN_PATH), strict=True)
            elif args.model == 'RCAN':
                model.load_state_dict(torch.load(args.pre_train, **kwargs), strict=True)

            self.SRmodels.append(model)
            self.SRmodels[-1].to(self.device)
            self.SRoptimizers.append(torch.optim.Adam(model.parameters(), lr=1e-4))
            self.schedulers.append(
                torch.optim.lr_scheduler.StepLR(self.SRoptimizers[-1], 500, gamma=0.1))

        #INCREMENT SCHEDULERS TO THE CORRECT LOCATION
        for i in range(args.step):
            for s in self.schedulers:
                s.step()
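The final loop replays args.step scheduler steps so a resumed run continues at the learning rate it left off with. A minimal standalone sketch of that effect, using a hypothetical single-parameter model and the same StepLR settings as above:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.1)
for _ in range(1200):  # e.g. resuming from training step 1200
    sched.step()       # (PyTorch may warn that optimizer.step() was never called; harmless here)
print(opt.param_groups[0]['lr'])  # decayed at steps 500 and 1000: ~1e-6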
Example #2
    def load(self, args):
        self.SRmodels = []

        if not args.baseline:
            loadedparams = torch.load(args.model_dir, map_location=self.device)
            self.agent = Agent(args, chkpoint=loadedparams)
            self.agent.model.eval()

        for i in range(args.action_space):
            if args.model == 'ESRGAN':
                model = arch.RRDBNet(3, 3, 64, 23, gc=32,
                                     upsize=args.upsize)
            elif args.model == 'basic':
                model = arch.RRDBNet(3,
                                     3,
                                     32,
                                     args.d,
                                     gc=8,
                                     upsize=args.upsize)
            elif args.model == 'RCAN':
                torch.manual_seed(args.seed)
                checkpoint = utility.checkpoint(args)
                if checkpoint.ok:
                    module = import_module('model.' + args.model.lower())
                    model = module.make_model(args).to(self.device)
                    kwargs = {}
                else:
                    print('error loading RCAN model. QUITTING')
                    quit()

            if args.baseline and args.model == 'ESRGAN':
                model.load_state_dict(torch.load(args.ESRGAN_PATH),
                                      strict=True)
            elif args.baseline and args.model == 'RCAN':
                print('rcan loaded')
                model.load_state_dict(torch.load(args.pre_train), strict=False)
            elif args.model == 'bicubic':
                return
            else:
                model.load_state_dict(loadedparams["sisr" + str(i)])

            self.SRmodels.append(model)
            self.SRmodels[-1].to(self.device)
            self.SRmodels[-1].eval()
Example #3
def main(myargs=None):
    setup_args(args)
    # note: the default myargs=None would fail on the next line; an object
    # exposing myargs.args.outdir is expected
    checkpoint = utility.checkpoint(args, outdir=myargs.args.outdir)
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        t = VideoTester(args, model, checkpoint)
        t.test()
    else:
        if checkpoint.ok:
            loader = data.Data(args)
            _model = model.Model(args, checkpoint)
            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
            t = Trainer(args, loader, _model, _loss, checkpoint)
            while not t.terminate():
                t.train()
                t.test()

            checkpoint.done()
Example #4
File: train.py Project: yhu9/RCAN
    def load(self, args):
        if args.model == 'ESRGAN':
            model = arch.RRDBNet(3, 3, 64, 23, gc=32)
            model.load_state_dict(torch.load(self.SRMODEL_PATH), strict=True)
        elif args.model == 'random':
            model = arch.RRDBNet(3, 3, 64, 23, gc=32)
            model.apply(self.init_weights)
        elif args.model == 'RCAN':
            torch.manual_seed(args.seed)
            checkpoint = utility.checkpoint(args)
            if checkpoint.ok:
                module = import_module('model.' + args.model.lower())
                model = module.make_model(args).to(self.device)
                kwargs = {}
                model.load_state_dict(torch.load(args.pre_train, **kwargs),
                                      strict=False)
            else:
                print('error loading RCAN model. QUITTING')
                quit()
        return model
Example #5
args.cpu = False  # 'store_true'
args.n_GPUs = 2
args.test_only = True
# saving and loading models
args.save_every = 100
args.save_models = True  # saves all intermediate models
folder = "2019-03-12-10:04:18_1*SmoothL1"
model = "8000"
args.pre_train = args.root_dir + 'experiment/{}/model/model_{}.pt'.format(folder, model)
# file name to save; if '.' the name is date+time
if not args.test_only:
    args.save = args.loss
else:
    args.save = "{}_Testing".format(folder)
args.save_results = True
loader = dataloader.StereoMSIDatasetLoader(args)
checkpoint = utility.checkpoint(args, loader)
# make directories if they do not exist
test_dir = os.path.join(checkpoint.dir, 'test_results')
if not os.path.exists(test_dir):
    os.mkdir(test_dir)
test_model_dir = os.path.join(test_dir, 'model_{}/'.format(model))
os.makedirs(test_model_dir, exist_ok=True)
my_loss = loss.Loss(args, checkpoint) if not args.test_only else None
my_model = network.Model(args, checkpoint)
t = Trainer(args, loader, my_model, my_loss, checkpoint)
if args.test_only:
    psnr = t.test_model(test_model_dir)
    # psnr is only defined on the test path, so log it inside the guard
    with open(args.root_dir + "experiment/" + args.save + "/metrics.txt", 'a') as my_file:
        my_file.writelines("PSNR: {}".format(psnr))
Example #6
File: train.py Project: yhu9/RCAN
    def __init__(self, args=args):

        #RANDOM MODEL INITIALIZATION FUNCTION
        def init_weights(m):
            if isinstance(m, torch.nn.Linear) or isinstance(
                    m, torch.nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight.data)

        #INITIALIZE VARIABLES
        self.SR_COUNT = args.action_space
        SRMODEL_PATH = args.srmodel_path
        self.batch_size = args.batch_size
        self.TRAINING_LRPATH = glob.glob(
            os.path.join(args.training_lrpath, "*"))
        self.TRAINING_HRPATH = glob.glob(
            os.path.join(args.training_hrpath, "*"))
        self.TRAINING_LRPATH.sort()
        self.TRAINING_HRPATH.sort()
        self.PATCH_SIZE = args.patchsize
        self.patchinfo_dir = args.patchinfo
        self.TESTING_PATH = glob.glob(os.path.join(args.testing_path, "*"))
        self.LR = args.learning_rate
        self.UPSIZE = args.upsize
        self.step = 0
        self.name = args.name
        if args.name != 'none':
            self.logger = logger.Logger(
                args.name)  #create our logger for tensorboard in log directory
        else:
            self.logger = None
        self.device = torch.device(args.device)  #determine cpu/gpu

        #DEFAULT START OR START ON PREVIOUSLY TRAINED EPOCH
        if args.model_dir != "":
            self.load(args)
            print('continue training for model: ' + args.model_dir)
        else:
            self.SRmodels = []
            self.SRoptimizers = []
            self.schedulers = []
            #LOAD A COPY OF THE MODEL N TIMES
            for i in range(self.SR_COUNT):
                if args.model == 'ESRGAN':
                    model = arch.RRDBNet(3, 3, 64, 23, gc=32)
                    model.load_state_dict(torch.load(args.ESRGAN_PATH),
                                          strict=True)
                    print('ESRGAN loaded')
                elif args.model == 'random':
                    model = arch.RRDBNet(3, 3, 64, 23, gc=32)
                    model.apply(init_weights)
                    print('Model RRDB Loaded with random weights...')
                elif args.model == 'RCAN':
                    torch.manual_seed(args.seed)
                    checkpoint = utility.checkpoint(args)
                    if checkpoint.ok:
                        module = import_module('model.' + args.model.lower())
                        model = module.make_model(args).to(self.device)
                        kwargs = {}
                        model.load_state_dict(torch.load(
                            args.pre_train, **kwargs),
                                              strict=False)
                    else:
                        print('error loading RCAN model. QUITTING')
                self.SRmodels.append(model)
                self.SRmodels[-1].to(self.device)
                self.SRoptimizers.append(
                    torch.optim.Adam(model.parameters(), lr=1e-4))
                self.schedulers.append(
                    torch.optim.lr_scheduler.StepLR(self.SRoptimizers[-1],
                                                    10000,
                                                    gamma=0.1))

            #self.patchinfo = np.load(self.patchinfo_dir)
            self.agent = agent.Agent(args)
Example #7
def main():
    init_date = date(1970, 1, 1)
    start_date = date(1990, 1, 2)
    end_date = date(2012, 12, 25)  #if 929 is true we should subtract 1 day
    sys = platform.system()

    if sys == "Windows":
        init_date = date(1970, 1, 1)
        start_date = date(1990, 1, 2)
        end_date = date(1990, 12,
                        15)  #if 929 is true we should subtract 1 day
        #         args.file_ACCESS_dir="E:/climate/access-s1/"
        #         args.file_BARRA_dir="C:/Users/JIA059/barra/"
        args.file_DEM_dir = "../DEM/"

        args.file_ACCESS_dir = "H:/climate/access-s1/"
        args.file_BARRA_dir = "D:/dataset/accum_prcp/"
    else:
        args.file_ACCESS_dir_pr = "/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
        args.file_ACCESS_dir = "/g/data/ub7/access-s1/hc/raw_model/atmos/"
        # training_name="temp01"
        args.file_BARRA_dir = "/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    args.channels = 0
    if args.pr:
        args.channels += 1
    if args.zg:
        args.channels += 1
    if args.psl:
        args.channels += 1
    if args.tasmax:
        args.channels += 1
    if args.tasmin:
        args.channels += 1
    print(args.dem)
    if args.dem:
        args.channels += 1
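    # 86400 s per day: presumably scales a per-second mean precipitation to a daily total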
    access_rgb_mean = 2.9067910245780248e-05 * 86400

    leading_time = 217
    args.leading_time_we_use = 7
    args.ensemble = 2

    print(access_rgb_mean)

    print("training statistics:")
    print("  ------------------------------")
    print("  trainning name  |  %s" % args.train_name)
    print("  ------------------------------")
    print("  num of channels | %5d" % args.channels)
    print("  ------------------------------")
    print("  num of threads  | %5d" % args.n_threads)
    print("  ------------------------------")
    print("  batch_size     | %5d" % args.batch_size)
    print("  ------------------------------")
    print("  using cpu only? | %5d" % args.cpu)

    ############################################################################################

    train_transforms = transforms.Compose([
        #     transforms.Resize(IMG_SIZE),
        #     transforms.RandomResizedCrop(IMG_SIZE),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomRotation(30),
        transforms.ToTensor()
        #     transforms.Normalize(IMG_MEAN, IMG_STD)
    ])

    data_set = ACCESS_BARRA_v3(start_date,
                               end_date,
                               transform=train_transforms,
                               args=args)
    train_data, val_data = random_split(
        data_set,
        [int(len(data_set) * 0.8),
         len(data_set) - int(len(data_set) * 0.8)])

    print("Dataset statistics:")
    print("  ------------------------------")
    print("  total | %5d" % len(data_set))
    print("  ------------------------------")
    print("  train | %5d" % len(train_data))
    print("  ------------------------------")
    print("  val   | %5d" % len(val_data))

    ###################################################################################set a the dataLoader
    train_dataloders = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.n_threads)
    val_dataloders = DataLoader(val_data,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.n_threads)

    # move each tensor in a list to the target device, optionally casting to half precision
    def prepare(l, volatile=False):
        device = torch.device('cpu' if args.cpu else 'cuda')

        def _prepare(tensor):
            if args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]

    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint).double()
    args.lr = 0.001
    criterion = nn.L1Loss()
    optimizer_my = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
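    # gamma=0.9: each scheduler.step() multiplies the learning rate by 0.9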
    # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)

    ##########################################################################training
    #training

    max_error = np.inf
    for e in range(args.epochs):
        #train
        net.train()
        loss = 0
        start = time.time()
        for batch, (lr, hr, _, _) in enumerate(train_dataloders):
            print("Train for batch %d,data loading time cost %f s" %
                  (batch, start - time.time()))
            start = time.time()
            lr, hr = prepare([lr, hr])

            optimizer_my.zero_grad()
            with torch.set_grad_enabled(True):
                sr = net(lr, 0)
                running_loss = criterion(sr, hr)
                loss += running_loss.detach()  # detach so the graph is not kept across batches
            running_loss.backward()
            optimizer_my.step()
            print("Train done, train time cost %f s" % (time.time() - start))
            start = time.time()

        #validation
        net.eval()
        start = time.time()
        with torch.no_grad():
            eval_psnr = 0
            eval_ssim = 0
            tqdm_val = tqdm(val_dataloders, ncols=80)
            for idx_img, (lr, hr, _, _) in enumerate(tqdm_val):
                lr, hr = prepare([lr, hr])
                sr = net(lr, 0)
                val_loss = criterion(sr, hr)
                for ssr, hhr in zip(sr, hr):
                    eval_psnr += compare_psnr(
                        ssr[0].cpu().numpy(),
                        hhr[0].cpu().numpy(),
                        data_range=(hhr[0].cpu().max() -
                                    hhr[0].cpu().min()).item())
                    eval_ssim += compare_ssim(
                        ssr[0].cpu().numpy(),
                        hhr[0].cpu().numpy(),
                        data_range=(hhr[0].cpu().max() -
                                    hhr[0].cpu().min()).item())
        print(
            "epoch: %d, time cost %f s, lr: %f, train_loss: %f, validation loss: %f"
            % (e, time.time() - start,
               optimizer_my.state_dict()['param_groups'][0]['lr'],
               loss.item() / len(train_data), val_loss))
        if running_loss < max_error:
            max_error = running_loss  # max_error actually tracks the best (lowest) loss
            if not os.path.exists("./model/save/" + args.train_name + "/"):
                os.mkdir("./model/save/" + args.train_name + "/")
            torch.save(
                net, "./model/save/" + args.train_name + "/" + str(e) + ".pkl")
Example #8
def main():
    pre_train_path=args.continue_train
    

    init_date=date(1970, 1, 1)
    start_date=date(1990, 1, 2)
    end_date=date(2011,12,25)
#     end_date=date(2012,12,25) #if 929 is true we should substract 1 day    
    sys = platform.system()
    args.file_ACCESS_dir="../data/"
    args.file_BARRA_dir="../data/barra_aus/"
#     if sys == "Windows":
#         init_date=date(1970, 1, 1)
#         start_date=date(1990, 1, 2)
#         end_date=date(1990,12,15) #if 929 is true we should substract 1 day   
#         args.cpu=True
# #         args.file_ACCESS_dir="E:/climate/access-s1/"
# #         args.file_BARRA_dir="C:/Users/JIA059/barra/"
#         args.file_DEM_dir="../DEM/"
#     else:
#         args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
#         args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/"
#         # training_name="temp01"
#         args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    args.channels=0
    if args.pr:
        args.channels+=1
    if args.zg:
        args.channels+=1
    if args.psl:
        args.channels+=1
    if args.tasmax:
        args.channels+=1
    if args.tasmin:
        args.channels+=1
    if args.dem:
        args.channels+=1
    leading_time=217
    args.leading_time_we_use=1
    args.ensemble=11
    pre_train_path="./model/prprpr/best.pth"
    
    ##
    def prepare( l, volatile=False):
        def _prepare(tensor):
            if args.precision == 'half': tensor = tensor.half()
            if args.precision == 'single': tensor = tensor.float()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)
#     net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True)
    if not args.prprpr:
        print("args.prprpr not set: wrapping the base model")
        net = my_model.Modify_RCAN(net, args, checkpoint)

    args.lr=0.00001
    criterion = nn.L1Loss()
#     criterion=nn.MSELoss()

    optimizer_my = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
    
    if pre_train_path!=".":
        write_log("load last train from"+pre_train_path)
        model_checkpoint = torch.load(pre_train_path,map_location=device)
        net.load_state_dict(model_checkpoint['model'])
#         net.load(pre_train_path)
        optimizer_my.load_state_dict(model_checkpoint['optimizer'])
        epoch = model_checkpoint['epoch']
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9,last_epoch=epoch)
        print(scheduler.state_dict())

        # demo input: precipitation data (pr)

        def add_lat_lon_data(data, domain=[112.9, 154.00, -43.7425, -9.0], xarray=True):
            """Attach lat/lon coordinates to data: the first dimension is lat,
            the second is lon; domain is the DEM domain."""
            new_lon = np.linspace(domain[0], domain[1], data.shape[1])
            new_lat = np.linspace(domain[2], domain[3], data.shape[0])
            if xarray:
                return xr.DataArray(data, coords=[new_lat, new_lon], dims=["lat", "lon"])
            else:
                return data, new_lat, new_lon

        demo_date=date(1990,1,25)
        idx=0
        ensamble_demo="e01"
        file="../data/"
        pr=np.expand_dims(np.repeat(np.expand_dims(dpt.read_access_data(file,ensamble_demo,demo_date,idx),axis=0),3,axis=0),axis=0)
        print(pr.shape)

        pr=prepare([torch.tensor(pr)])

        hr=net(pr[0],0).cpu().detach().numpy()
        print(np.squeeze(hr[:,1]).shape)


        title="test \n date: "+(demo_date+timedelta(idx)).strftime("%Y%m%d")
        # prec_in=dpt.read_access_data(filename,idx=idx)*86400
        hr,lat,lon=add_lat_lon_data(np.squeeze(hr[:,1]),xarray=False)
        # print(hr)
        dpt.draw_aus(hr,lat,lon,title=title,save=True,path="test")
        # print(prec_in.shape[0],prec_in.shape[1])        
        
    
    # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)
    
#     if args.resume==1:
#         print("continue last train")
#         model_checkpoint = torch.load(pre_train_path,map_location=device)
#     else:
#         print("restart train")
#         model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device)

#     my_net.load_state_dict(model_checkpoint['model'])
#     optimizer_my.load_state_dict(model_checkpoint['optimizer'])
#     epoch = model_checkpoint['epoch']
    
    if torch.cuda.device_count() > 1:
        write_log("Let's use " + str(torch.cuda.device_count()) + " GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        net = nn.DataParallel(net, range(torch.cuda.device_count()))
    else:
        write_log("Let's use " + str(torch.cuda.device_count()) + " GPUs!")

#     my_net = torch.nn.DataParallel(my_net)
    net.to(device)
Example #9
import torch
import utility
import data
import model
import loss
from option import args
from trainer import Trainer

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(
        args)  # set up the log and the training information
    if checkpoint.ok:
        loader = data.Data(args)  ## data loader
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.train()

        checkpoint.done()
Example #10
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)  # load model from experiment/xxx


def main():
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        t = VideoTester(args, model, checkpoint)
        t.test()
    else:
        if checkpoint.ok:
            loader = data.Data(args)
            _model = model.Model(args, checkpoint)
            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
            t = Trainer(args, loader, _model, _loss, checkpoint)
            while not t.terminate():
                t.train()
                t.test()

            checkpoint.done()
Example #11
import torch

import utility

# the three package folders: data, model, loss
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)  # class that saves the model

if checkpoint.ok:
    # initialize the data-loading, model, loss, and trainer classes
    loader = data.Data(args)  # the Data class lives in __init__.py
    model = model.Model(args, checkpoint)  # the Model class lives in __init__.py
    loss = loss.Loss(
        args,
        checkpoint) if not args.test_only else None  # the Loss class lives in __init__.py
    t = Trainer(args, loader, model, loss, checkpoint)

    while not t.terminate():
        t.train()
        t.test()

    checkpoint.done()
Example #12
File: train.py Project: yhu9/RCAN
    def load(self, args):

        if args.model_dir != "":
            loadedparams = torch.load(args.model_dir, map_location=self.device)
            #self.agent = agent.Agent(args,chkpoint=loadedparams)
            self.agent = agent.Agent(args)
        else:
            self.agent = agent.Agent(args)
        self.SRmodels = []
        self.SRoptimizers = []
        self.schedulers = []
        for i in range(args.action_space):

            #CREATE THE ARCH
            if args.model == 'ESRGAN':
                model = arch.RRDBNet(3, 3, 64, 23, gc=32)
            elif args.model == 'basic':
                model = arch.RRDBNet(3,
                                     3,
                                     32,
                                     args.d,
                                     gc=8,
                                     upsize=args.upsize)
            elif args.model == 'RCAN':
                torch.manual_seed(args.seed)
                checkpoint = utility.checkpoint(args)
                if checkpoint.ok:
                    module = import_module('model.rcan')
                    model = module.make_model(args).to(self.device)
                    kwargs = {}
                else:
                    print('error loading RCAN model. QUITTING')
                    quit()

            #LOAD THE WEIGHTS
            if args.model_dir != "":
                model.load_state_dict(loadedparams["sisr" + str(i)])
                print('continuing training')
            elif args.random:
                print('random init')
            elif args.model == 'ESRGAN':
                model.load_state_dict(torch.load(args.ESRGAN_PATH),
                                      strict=True)
            elif args.model == 'RCAN':
                print('RCAN loaded!')
                model.load_state_dict(torch.load(args.pre_train, **kwargs),
                                      strict=True)
            elif args.model == 'basic':
                print('loading basic model')
                if args.d == 1:
                    model.load_state_dict(torch.load(args.basicpath_d1),
                                          strict=True)
                elif args.d == 2:
                    model.load_state_dict(torch.load(args.basicpath_d2),
                                          strict=True)
                elif args.d == 4:
                    model.load_state_dict(torch.load(args.basicpath_d4),
                                          strict=True)
                elif args.d == 8:
                    model.load_state_dict(torch.load(args.basicpath_d8),
                                          strict=True)
                else:
                    print(
                        'no pretrained model available. Random initialization of basic block'
                    )

            self.SRmodels.append(model)
            self.SRmodels[-1].to(self.device)

            self.SRoptimizers.append(
                torch.optim.Adam(model.parameters(), lr=1e-5))
            scheduler = torch.optim.lr_scheduler.StepLR(self.SRoptimizers[-1],
                                                        200,
                                                        gamma=0.8)

            self.schedulers.append(scheduler)

        #INCREMENT SCHEDULERS TO THE CORRECT LOCATION
        for i in range(args.step):
            for s in self.schedulers:
                s.step()
Example #13
def main():
    ck = util.checkpoint(args)
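    # fully seed python, numpy, torch and CUDA, and make cudnn deterministic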
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # note: reseeding with the literal 222 overrides args.seed for torch and numpy
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ck.write_log(str(args))
    # t = str(int(time.time()))
    # t = args.save_name
    # os.mkdir('./{}'.format(t))
    # (ch_out, ch_in, k, k, stride, padding)
    config = [('conv2d', [32, 16, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('relu', [True]),
              ('+1', [True]), ('conv2d', [3, 32, 3, 3, 1, 1])]
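    # per the (ch_out, ch_in, k, k, stride, padding) convention above,
    # ('conv2d', [32, 16, 3, 3, 1, 1]) corresponds to
    # nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)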

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)
    # (Dataset) calculate the number of trainable tensors
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    ck.write_log(str(maml))
    ck.write_log('Total trainable tensors: {}'.format(num))

    # (Dataset) batchsz here means total episode number
    DL_MSI = dl.StereoMSIDatasetLoader(args)
    db = DL_MSI.train_loader
    dv = DL_MSI.valid_loader

    psnr = []
    l1_loss = []
    psnr_valid = []
    for epoch, (spt_ms, spt_rgb, qry_ms, qry_rgb) in enumerate(db):
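        # spt_* are the support (inner-loop) pairs and qry_* the query
        # (outer-loop) pairs, in the usual MAML sense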

        if epoch // args.epoch: break  # integer division: stop once epoch reaches args.epoch
        spt_ms, spt_rgb, qry_ms, qry_rgb = (spt_ms.to(device),
                                            spt_rgb.to(device),
                                            qry_ms.to(device),
                                            qry_rgb.to(device))

        # optimization is carried out inside meta_learner class, maml.
        accs, train_loss = maml(spt_ms, spt_rgb, qry_ms, qry_rgb, epoch)
        maml.scheduler.step()

        if epoch % args.print_every == 0:
            log_epoch = 'epoch: {} \ttraining acc: {}'.format(epoch, accs)
            ck.write_log(log_epoch)
            psnr.append(accs)
            l1_loss.append(train_loss)
            ck.plot_loss(psnr, l1_loss, epoch, args.print_every)
            if epoch % args.save_every == 0:
                with torch.no_grad():
                    ck.save(maml.net, maml.meta_optim, epoch)
                    eval_psnr = 0  # psnr loss
                    for idx, (valid_ms, valid_rgb) in enumerate(dv):
                        #print('idx', idx)
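                        # prepare(...) is defined elsewhere; presumably it moves
                        # tensors to the device, as in the other examples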
                        valid_ms, valid_rgb = prepare([valid_ms, valid_rgb])
                        sr_rgb = maml.net(valid_ms)
                        sr_rgb = torch.clamp(sr_rgb, 0, 1)
                        eval_psnr += errors.find_psnr(valid_rgb, sr_rgb)
                        ############## TODO: plot per-image PSNR here ###########
                    psnr_valid.append(eval_psnr / 25)  # 25: apparently the fixed validation-set size
                    ck.plot_psnr(psnr_valid, epoch, args.save_every)
                    ck.write_log('Max PSNR is: {}'.format(max(psnr_valid)))
                    imsave(
                        './{}/validation/img_{}.png'.format(ck.dir, epoch),
                        np.uint8(sr_rgb[0, :, :, :].permute(
                            1, 2, 0).cpu().detach().numpy() * 255))
    ck.done()
Example #14
def main():
    
#     pre_train_path="./model/save/temp01/"+0+".pth"

    
    
    init_date=date(1970, 1, 1)
    start_date=date(1990, 1, 2)
    end_date=date(1990,12,25)
#     end_date=date(2012,12,25) #if 929 is true we should subtract 1 day
    sys = platform.system()
    
    if sys == "Windows":
        init_date=date(1970, 1, 1)
        start_date=date(1990, 1, 2)
        end_date=date(1990,12,15) #if 929 is true we should subtract 1 day
        args.file_ACCESS_dir="H:/climate/access-s1/" 
        args.file_BARRA_dir="D:/dataset/accum_prcp/"
#         args.file_ACCESS_dir="E:/climate/access-s1/"
#         args.file_BARRA_dir="C:/Users/JIA059/barra/"
        args.file_DEM_dir="../DEM/"
    else:
        args.file_ACCESS_dir_pr="/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
        args.file_ACCESS_dir="/g/data/ub7/access-s1/hc/raw_model/atmos/"
        # training_name="temp01"
        args.file_BARRA_dir="/g/data/ma05/BARRA_R/v1/forecast/spec/accum_prcp/"

    args.channels=0
    if args.pr:
        args.channels+=1
    if args.zg:
        args.channels+=1
    if args.psl:
        args.channels+=1
    if args.tasmax:
        args.channels+=1
    if args.tasmin:
        args.channels+=1
    if args.dem:
        args.channels+=1
    access_rgb_mean = 2.9067910245780248e-05 * 86400  # 86400 s/day: presumably a per-second mean scaled to a daily total
    pre_train_path="./model/save/"+args.train_name+"/last_"+str(args.channels)+".pth"
    leading_time=217
    args.leading_time_we_use=1
    args.ensemble=1


    print(access_rgb_mean)

    print("training statistics:")
    print("  ------------------------------")
    print("  trainning name  |  %s"%args.train_name)
    print("  ------------------------------")
    print("  num of channels | %5d"%args.channels)
    print("  ------------------------------")
    print("  num of threads  | %5d"%args.n_threads)
    print("  ------------------------------")
    print("  batch_size     | %5d"%args.batch_size)
    print("  ------------------------------")
    print("  using cpu only? | %5d"%args.cpu)

    ############################################################################################

    train_transforms = transforms.Compose([
    #     transforms.Resize(IMG_SIZE),
    #     transforms.RandomResizedCrop(IMG_SIZE),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.RandomRotation(30),
        transforms.ToTensor()
    #     transforms.Normalize(IMG_MEAN, IMG_STD)
    ])

    data_set=ACCESS_BARRA_v4(start_date,end_date,transform=train_transforms,args=args)
    train_data,val_data=random_split(data_set,[int(len(data_set)*0.8),len(data_set)-int(len(data_set)*0.8)])


    print("Dataset statistics:")
    print("  ------------------------------")
    print("  total | %5d"%len(data_set))
    print("  ------------------------------")
    print("  train | %5d"%len(train_data))
    print("  ------------------------------")
    print("  val   | %5d"%len(val_data))

    ###################################################################################set a the dataLoader
    train_dataloders = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.n_threads)
    val_dataloders = DataLoader(val_data,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.n_threads)
    # move each tensor in a list to the target device, casting to the requested precision
    def prepare(l, volatile=False):
        def _prepare(tensor):
            if args.precision == 'half': tensor = tensor.half()
            if args.precision == 'single': tensor = tensor.float()
            return tensor.to(device)

        return [_prepare(_l) for _l in l]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    checkpoint = utility.checkpoint(args)
    net = model.Model(args, checkpoint)
#     net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=True)
    my_net=my_model.Modify_RCAN(net,args,checkpoint)

#     net.load("./model/RCAN_BIX4.pt", pre_train="./model/RCAN_BIX4.pt", resume=args.resume, cpu=args.cpu)
    
    args.lr=0.001
    criterion = nn.L1Loss()
    optimizer_my = optim.SGD(my_net.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.StepLR(optimizer_my, step_size=7, gamma=0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer_my, gamma=0.9)
    # torch.optim.lr_scheduler.MultiStepLR(optimizer_my, milestones=[20,80], gamma=0.1)
    
    if args.resume==1:
        print("continue last train")
        model_checkpoint = torch.load(pre_train_path,map_location=device)
    else:
        print("restart train")
        model_checkpoint = torch.load("./model/save/"+args.train_name+"/first_"+str(args.channels)+".pth",map_location=device)

    my_net.load_state_dict(model_checkpoint['model'])
    optimizer_my.load_state_dict(model_checkpoint['optimizer'])
    epoch = model_checkpoint['epoch']
    
    if torch.cuda.device_count() > 1:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        my_net = nn.DataParallel(my_net)
    else:
        write_log("Let's use"+str(torch.cuda.device_count())+"GPUs!")

#     my_net = torch.nn.DataParallel(my_net)
    my_net.to(device)
    
    ##########################################################################training
    
    if args.channels==1:
        write_log("start")
        max_error=np.inf
        for e in range(args.epochs):
            #train
            my_net.train()
            loss=0
            start=time.time()
            for batch, (pr,hr,_,_) in enumerate(train_dataloders):
                write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time()))

    #             start=time.time()
                pr,hr= prepare([pr,hr])
                optimizer_my.zero_grad()
                with torch.set_grad_enabled(True):
                    sr = my_net(pr)
                    print(pr.shape)
                    print(sr.shape)

                    running_loss =criterion(sr, hr)

                    running_loss.backward()
                    optimizer_my.step()
                loss += running_loss.detach()  # detach so the graph is not kept across batches
                if batch%10==0:
                    state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                    torch.save(state, "./model/save/temp01/last.pth")
                write_log("Train done,train time cost %f s,learning rate:%f, loss: %f"%(start-time.time(),optimizer_my.state_dict()['param_groups'][0]['lr'] ,running_loss.item()  ))
                start=time.time()

            #validation
            my_net.eval()
            start=time.time()
            with torch.no_grad():
                eval_psnr=0
                eval_ssim=0
    #             tqdm_val = tqdm(val_dataloders, ncols=80)
                for batch, (pr,hr,_,_) in enumerate(val_dataloders):
                    pr,hr = prepare([pr,hr])
                    sr = my_net(pr)  # single-channel path: the model takes pr only (dem is undefined here)
                    val_loss=criterion(sr, hr)
                    for ssr,hhr in zip(sr,hr):
                        eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )
                        eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )

            write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
                      e,
                      time.time()-start,
                      optimizer_my.state_dict()['param_groups'][0]['lr'],
                      loss.item()/len(train_data),
                      val_loss
                 ))
    #         print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
    #                   e,
    #                   time.time()-start,
    #                   optimizer_my.state_dict()['param_groups'][0]['lr'],
    #                   loss.item()/len(train_data),
    #                   val_loss
    #              ))
            if running_loss<max_error:
                max_error=running_loss
        #         torch.save(net,train_loss"_"+str(e)+".pkl")
                if not os.path.exists("./model/save/"+args.train_name+"/"):
                    os.mkdir("./model/save/"+args.train_name+"/")
                write_log("saving")
                state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                torch.save(state, "./model/save/temp01/"+str(e)+".pth")
    #             torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl")


    
    elif args.channels==2:
        write_log("start")
        max_error=np.inf
        for e in range(args.epochs):
            #train
            my_net.train()
            loss=0
            start=time.time()
            for batch, (pr,dem,hr,_,_) in enumerate(train_dataloders):
                write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time()))
                start=time.time()
                pr,dem,hr= prepare([pr,dem,hr])

                optimizer_my.zero_grad()
                with torch.set_grad_enabled(True):
                    sr = my_net(pr,dem)
                    running_loss =criterion(sr, hr)

                    running_loss.backward()
                    optimizer_my.step()
                loss += running_loss.detach()  # detach so the graph is not kept across batches
                if batch%10==0:
                    state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                    torch.save(state, "./model/save/temp01/last_"+str(args.channels)+".pth")
                write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item()  ))
                start=time.time()

            #validation
            my_net.eval()
            start=time.time()
            with torch.no_grad():
                eval_psnr=0
                eval_ssim=0
    #             tqdm_val = tqdm(val_dataloders, ncols=80)
                for idx_img, (pr,dem,hr,_,_) in enumerate(val_dataloders):
                    pr,dem,hr = prepare([pr,dem,hr])
                    sr = my_net(pr,dem)
                    val_loss=criterion(sr, hr)
                    for ssr,hhr in zip(sr,hr):
                        eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )
                        eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )

            write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
                      e,
                      time.time()-start,
                      optimizer_my.state_dict()['param_groups'][0]['lr'],
                      loss.item()/len(train_data),
                      val_loss
                 ))
    #         print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
    #                   e,
    #                   time.time()-start,
    #                   optimizer_my.state_dict()['param_groups'][0]['lr'],
    #                   loss.item()/len(train_data),
    #                   val_loss
    #              ))
            if running_loss<max_error:
                max_error=running_loss
        #         torch.save(net,train_loss"_"+str(e)+".pkl")
                if not os.path.exists("./model/save/"+args.train_name+"/"):
                    os.mkdir("./model/save/"+args.train_name+"/")
                write_log("saving")
                state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                torch.save(state, "./model/save/temp01/"+str(e)+".pth")
    #             torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl")
    
    
    elif args.channels==3:
        write_log("start")
        max_error=np.inf
        for e in range(args.epochs):
            #train
            my_net.train()
            loss=0
            start=time.time()
            for batch, (pr,dem,tasmax,hr,_,_) in enumerate(train_dataloders):
                write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time()))
                start=time.time()
                pr,dem,tasmax,hr= prepare([pr,dem,tasmax,hr])

                optimizer_my.zero_grad()
                with torch.set_grad_enabled(True):
                    sr = my_net(pr,dem)
                    running_loss =criterion(sr, hr,tasmax=tasmax)

                    running_loss.backward()
                    optimizer_my.step()
                loss += running_loss.detach()  # detach so the graph is not kept across batches
                if batch%10==0:
                    state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                    torch.save(state, "./model/save/temp01/last.pth")
                write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item()  ))
                start=time.time()

            #validation
            my_net.eval()
            start=time.time()
            with torch.no_grad():
                eval_psnr=0
                eval_ssim=0
    #             tqdm_val = tqdm(val_dataloders, ncols=80)
                for idx_img, (pr,dem,tasmax,hr,_,_) in enumerate(val_dataloders):
                    pr,dem,tasmax,hr = prepare([pr,dem,tasmax,hr])
                    sr = my_net(pr,dem)  # unpack matches the 3-channel training loop above
                    val_loss=criterion(sr, hr)
                    for ssr,hhr in zip(sr,hr):
                        eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )
                        eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )

            write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
                      e,
                      time.time()-start,
                      optimizer_my.state_dict()['param_groups'][0]['lr'],
                      loss.item()/len(train_data),
                      val_loss
                 ))
    #         print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
    #                   e,
    #                   time.time()-start,
    #                   optimizer_my.state_dict()['param_groups'][0]['lr'],
    #                   loss.item()/len(train_data),
    #                   val_loss
    #              ))
            if running_loss<max_error:
                max_error=running_loss
        #         torch.save(net,train_loss"_"+str(e)+".pkl")
                if not os.path.exists("./model/save/"+args.train_name+"/"):
                    os.mkdir("./model/save/"+args.train_name+"/")
                write_log("saving")
                state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                torch.save(state, "./model/save/temp01/"+str(e)+".pth")
    #             torch.save(net,"./model/save/"+args.train_name+"/"+str(e)+".pkl")    
    
    
    
            
    else:
        write_log("start")
        max_error=np.inf
        for e in range(args.epochs):
            #train
            my_net.train()
            loss=0
            start=time.time()
            for batch, (pr,dem,psl,zg,tasmax,tasmin, hr,_,_) in enumerate(train_dataloders):
                write_log("Train for batch %d,data loading time cost %f s"%(batch,start-time.time()))
                start=time.time()
                pr,dem,psl,zg,tasmax,tasmin, hr = prepare([pr,dem,psl,zg,tasmax,tasmin, hr])

                optimizer_my.zero_grad()
                with torch.set_grad_enabled(True):
                    sr = my_net(pr,dem,psl,zg,tasmax,tasmin)
                    running_loss =criterion(sr, hr)

                    running_loss.backward()
                    optimizer_my.step()
                loss += running_loss.detach()  # detach so the graph is not kept across batches
                if batch%10==0:
                    state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                    torch.save(state, "./model/save/temp01/last.pth")
                write_log("Train done,train time cost %f s,loss: %f"%(start-time.time(),running_loss.item()  ))
                start=time.time()

            #validation
            my_net.eval()
            start=time.time()
            with torch.no_grad():
                eval_psnr=0
                eval_ssim=0
    #             tqdm_val = tqdm(val_dataloders, ncols=80)
                for idx_img, (pr,dem,psl,zg,tasmax,tasmin, hr,_,_) in enumerate(val_dataloders):
                    pr,dem,psl,zg,tasmax,tasmin, hr = prepare([pr,dem,psl,zg,tasmax,tasmin, hr])
                    sr = my_net(pr,dem,psl,zg,tasmax,tasmin)
                    val_loss=criterion(sr, hr)
                    for ssr,hhr in zip(sr,hr):
                        eval_psnr+=compare_psnr(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )
                        eval_ssim+=compare_ssim(ssr[0].cpu().numpy(),hhr[0].cpu().numpy(),data_range=(hhr[0].cpu().max()-hhr[0].cpu().min()).item() )

            write_log("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
                      e,
                      time.time()-start,
                      optimizer_my.state_dict()['param_groups'][0]['lr'],
                      loss.item()/len(train_data),
                      val_loss
                 ))
    #         print("epoche: %d,time cost %f s, lr: %f, train_loss: %f,validation loss:%f "%(
    #                   e,
    #                   time.time()-start,
    #                   optimizer_my.state_dict()['param_groups'][0]['lr'],
    #                   loss.item()/len(train_data),
    #                   val_loss
    #              ))
            if running_loss<max_error:
                max_error=running_loss
        #         torch.save(net,train_loss"_"+str(e)+".pkl")
                if not os.path.exists("./model/save/"+args.train_name+"/"):
                    os.mkdir("./model/save/"+args.train_name+"/")
                write_log("saving")
                state = {'model': my_net.state_dict(), 'optimizer': optimizer_my.state_dict(), 'epoch': e}
                torch.save(state, "./model/save/temp01/"+str(e)+".pth")
Example #15
@LastEditTime: 2019-11-19 14:02:36
'''
import torch

import utility
import data
import model
import loss

# import h5py
from option import args  # option.py defines how external arguments are specified
from trainer import Trainer

torch.set_num_threads(12)  # set the number of CPU threads for multi-core computation
torch.manual_seed(args.seed)  # fix the random seed so every initialization is identical
checkpoint = utility.checkpoint(args)  # process the external arguments passed to the program

if checkpoint.ok:

    # args.model = 'NL_EST'
    # model1 = model.Model(args, checkpoint)
    #
    # args.model = 'KERNEL_EST'
    # model2 = model.Model(args, checkpoint)
    # args.model = 'BSR'

    # build the network model
    model = model.Model(args, checkpoint)
    # load the corresponding training or test dataset
    loader = data.Data(args)
    # load the loss function
Example #16
    def load(self, args):

        if args.model_dir != "":
            loadedparams = torch.load(args.model_dir, map_location=self.device)
            self.agent = agent.Agent(args, chkpoint=loadedparams)
            #self.agent = agent.Agent(args)
        else:
            self.agent = agent.Agent(args)
        self.SRmodels = []
        self.SRoptimizers = []
        self.schedulers = []
        for i in range(args.action_space):

            #CREATE THE ARCH
            if args.model == 'basic':
                model = arch.RRDBNet(3,
                                     3,
                                     32,
                                     args.d,
                                     gc=8,
                                     upsize=args.upsize)
            elif args.model == 'ESRGAN':
                model = arch.RRDBNet(3, 3, 64, 23, gc=32, upsize=args.upsize)
            elif args.model == 'RCAN':
                torch.manual_seed(args.seed)
                checkpoint = utility.checkpoint(args)
                if checkpoint.ok:
                    module = import_module('model.rcan')
                    model = module.make_model(args).to(self.device)
                    kwargs = {}
                else:
                    print('error loading RCAN model. QUITING')
                    quit()

            #LOAD THE WEIGHTS
            if args.model_dir != "":
                model.load_state_dict(loadedparams["sisr" + str(i)])
                print('continuing training')
            elif args.random:
                print('random init')
                model.apply(init_weights)  # init_weights is assumed to be the Xavier initializer from Example #6
            elif args.model == 'ESRGAN':
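                # partial (non-strict) load: keep only pretrained weights whose
                # names exist in the current model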
                loaded_dict = torch.load(args.ESRGAN_PATH)
                model_dict = model.state_dict()
                loaded_dict = {
                    k: v
                    for k, v in loaded_dict.items() if k in model_dict
                }
                model_dict.update(loaded_dict)
                model.load_state_dict(model_dict)
            elif args.model == 'RCAN':
                print('RCAN loaded!')
                model.load_state_dict(torch.load(args.pre_train, **kwargs),
                                      strict=True)
            elif args.model == 'basic':
                if args.d == 1:
                    loaded_dict = torch.load(args.basicpath_d1)
                elif args.d == 2:
                    loaded_dict = torch.load(args.basicpath_d2)
                elif args.d == 4:
                    loaded_dict = torch.load(args.basicpath_d4)
                elif args.d == 8:
                    loaded_dict = torch.load(args.basicpath_d8)
                else:
                    print(
                        'no pretrained model available. Random initialization of basic block'
                    )
                    loaded_dict = {}  # avoid a NameError in the merge below
                model_dict = model.state_dict()
                loaded_dict = {
                    k: v
                    for k, v in loaded_dict.items() if k in model_dict
                }
                model_dict.update(loaded_dict)
                model.load_state_dict(model_dict)

            self.SRmodels.append(model)
            self.SRmodels[-1].to(self.device)

            self.SRoptimizers.append(
                torch.optim.Adam(model.parameters(), lr=1e-5))
            scheduler = torch.optim.lr_scheduler.StepLR(self.SRoptimizers[-1],
                                                        1000,
                                                        gamma=0.5)

            self.schedulers.append(scheduler)
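The filtered state-dict load used twice above is a common pattern for partially matching checkpoints. A self-contained sketch of the idea, with hypothetical layer and key names:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
pretrained = {'0.weight': torch.zeros(8, 3, 3, 3),  # matches the model
              'extra.bias': torch.zeros(8)}         # no counterpart; dropped
model_dict = model.state_dict()
filtered = {k: v for k, v in pretrained.items() if k in model_dict}
model_dict.update(filtered)
model.load_state_dict(model_dict)  # loads only the overlapping weights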
Example #17
def main(argv=None):
    # ============Dataset===============
    print('Loading dataset...')
    train_set = SRDataset(args.train_dataset,
                          'train',
                          patch_size=args.patch_size,
                          num_repeats=args.num_repeats,
                          is_aug=True,
                          crop_type='random')
    val_set = SRDataset(args.valid_dataset,
                        'valid',
                        patch_size=None,
                        num_repeats=1,
                        is_aug=False,
                        fixed_length=10)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # ============Model================
    n_GPUs = torch.cuda.device_count()
    print('Loading model using %d GPU(s)' % n_GPUs)

    opt = {
        'patch_size': args.patch_size,
        'num_channels': args.num_channels,
        'depth': args.num_blocks,
        'res_scale': args.res_scale,
        'spectral_norm': args.spectral_norm
    }
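    # opt bundles the hyper-parameters consumed by the generator/discriminator below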
    ################### MODIFIED #########################
    torch.manual_seed(args1.seed)
    checkpoint = utility.checkpoint(args1)
    G = model1.Model(args1, checkpoint)
    '''
    G = Generator(opt)
    if args.pretrained_model != '':
        print('Fetching pretrained model', args.pretrained_model)
        G.load_state_dict(torch.load(args.pretrained_model))
    '''
    ###################################################

    #G = nn.DataParallel(G).cuda()
    D = nn.DataParallel(Discriminator(opt)).cuda()

    vgg = nn.DataParallel(VGG()).cuda()

    cudnn.benchmark = True

    #========== Optimizer============
    trainable = filter(lambda x: x.requires_grad, G.parameters())
    optim_G = optim.Adam(trainable, betas=(0.9, 0.999), lr=args.learning_rate)
    optim_D = optim.Adam(D.parameters(),
                         betas=(0.9, 0.999),
                         lr=args.learning_rate)
    scheduler_G = lr_scheduler.StepLR(optim_G,
                                      step_size=args.lr_step,
                                      gamma=0.5)
    scheduler_D = lr_scheduler.StepLR(optim_D,
                                      step_size=args.lr_step,
                                      gamma=0.5)

    # ============Loss==============
    l1_loss_fn = nn.L1Loss()
    bce_loss_fn = nn.BCEWithLogitsLoss()
    f_loss_fn = FocalLoss(args.fl_gamma)

    def vgg_loss_fn(output, label):
        vgg_sr, vgg_hr = vgg(output, label)
        return F.mse_loss(vgg_sr, vgg_hr)

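    # total-variation loss: sum of absolute differences between neighbouring
    # pixels, horizontally and vertically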
    def tv_loss_fn(y):
        loss_var = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \
                   torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
        return loss_var

    # ==========Logging and book-keeping=======
    check_point = os.path.join(args.check_point, args.phase)
    tb = SummaryWriter(check_point)
    best_psnr = 0

    # ==========GAN vars======================
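    # fixed real/fake label tensors; this relies on drop_last=True keeping the batch size constant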
    target_real = Variable(torch.Tensor(args.batch_size, 1).fill_(1.0),
                           requires_grad=False).cuda()
    target_fake = Variable(torch.Tensor(args.batch_size, 1).fill_(0.0),
                           requires_grad=False).cuda()

    # Training and validating
    for epoch in range(1, args.num_epochs + 1):

        #===========Pretrain===================
        if args.phase == 'pretrain':
            scheduler_G.step()
            cur_lr = optim_G.param_groups[0]['lr']
            print('Model {}. Epoch [{}/{}]. Learning rate: {}'.format(
                check_point, epoch, args.num_epochs, cur_lr))

            num_batches = len(train_set) // args.batch_size
            running_loss = 0

            for i, (inputs, labels) in enumerate(tqdm(train_loader)):
                lr, hr = (Variable(inputs.cuda()), Variable(labels.cuda()))

                sr = G(lr)
                optim_G.zero_grad()

                loss = l1_loss_fn(sr, hr)
                loss.backward()
                optim_G.step()

                # update log
                running_loss += loss.item()

            avr_loss = running_loss / num_batches
            tb.add_scalar('Learning rate', cur_lr, epoch)
            tb.add_scalar('Pretrain Loss', avr_loss, epoch)
            print('Finished training [%d/%d]. Loss: %.2f' %
                  (epoch, args.num_epochs, avr_loss))

        #===============Train======================
        else:
            scheduler_G.step()
            scheduler_D.step()
            cur_lr = optim_G.param_groups[0]['lr']
            print('Model {}. Epoch [{}/{}]. Learning rate: {}'.format(
                check_point, epoch, args.num_epochs, cur_lr))

            num_batches = len(train_set) // args.batch_size
            running_loss = np.zeros(5)

            for i, (inputs, labels) in enumerate(tqdm(train_loader)):
                lr, hr = (Variable(inputs.cuda()), Variable(labels.cuda()))

                #################changed####################

                def input_matrix_wpn(inH, inW, scale, add_scale=True):
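                    """Build the position/scale matrix for the meta-upscale step.

                    For each output pixel this records the fractional offset to its
                    source LR pixel (plus 1/scale when add_scale is set); mask_mat
                    marks which entries of the scale_int-upsampled grid correspond
                    to real output pixels.
                    """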

                    outH, outW = int(scale * inH), int(scale * inW)

                    #### mask records which pixel is valid: 1 valid, 0 invalid

                    #### h_offset and w_offset calculate the offsets used to generate the input matrix
                    scale_int = int(math.ceil(scale))
                    h_offset = torch.ones(inH, scale_int, 1)
                    mask_h = torch.zeros(inH, scale_int, 1)
                    w_offset = torch.ones(1, inW, scale_int)
                    mask_w = torch.zeros(1, inW, scale_int)
                    if add_scale:
                        scale_mat = torch.zeros(1, 1)
                        scale_mat[0, 0] = 1.0 / scale
                        #res_scale = scale_int - scale
                        #scale_mat[0,scale_int-1]=1-res_scale
                        #scale_mat[0,scale_int-2]= res_scale
                        scale_mat = torch.cat([scale_mat] * (inH * inW *
                                                             (scale_int**2)),
                                              0)  ###(inH*inW*scale_int**2, 4)

                    #### project output coordinates onto the LR grid and calculate the offsets
                    h_project_coord = torch.arange(0, outH,
                                                   1).float().mul(1.0 / scale)
                    int_h_project_coord = torch.floor(h_project_coord)

                    offset_h_coord = h_project_coord - int_h_project_coord
                    int_h_project_coord = int_h_project_coord.int()

                    w_project_coord = torch.arange(0, outW,
                                                   1).float().mul(1.0 / scale)
                    int_w_project_coord = torch.floor(w_project_coord)

                    offset_w_coord = w_project_coord - int_w_project_coord
                    int_w_project_coord = int_w_project_coord.int()

                    #### flag counts output pixels mapped to the current LR coordinate
                    flag = 0
                    number = 0
                    for i in range(outH):
                        if int_h_project_coord[i] == number:
                            h_offset[int_h_project_coord[i], flag,
                                     0] = offset_h_coord[i]
                            mask_h[int_h_project_coord[i], flag, 0] = 1
                            flag += 1
                        else:
                            h_offset[int_h_project_coord[i], 0,
                                     0] = offset_h_coord[i]
                            mask_h[int_h_project_coord[i], 0, 0] = 1
                            number += 1
                            flag = 1

                    flag = 0
                    number = 0
                    for i in range(outW):
                        if int_w_project_coord[i] == number:
                            w_offset[0, int_w_project_coord[i],
                                     flag] = offset_w_coord[i]
                            mask_w[0, int_w_project_coord[i], flag] = 1
                            flag += 1
                        else:
                            w_offset[0, int_w_project_coord[i],
                                     0] = offset_w_coord[i]
                            mask_w[0, int_w_project_coord[i], 0] = 1
                            number += 1
                            flag = 1

                    ## the size is (scale_int*inH) x (scale_int*inW)
                    h_offset_coord = torch.cat([h_offset] * (scale_int * inW),
                                               2).view(-1, scale_int * inW, 1)
                    w_offset_coord = torch.cat([w_offset] * (scale_int * inH),
                                               0).view(-1, scale_int * inW, 1)
                    ####
                    mask_h = torch.cat([mask_h] * (scale_int * inW),
                                       2).view(-1, scale_int * inW, 1)
                    mask_w = torch.cat([mask_w] * (scale_int * inH),
                                       0).view(-1, scale_int * inW, 1)

                    pos_mat = torch.cat((h_offset_coord, w_offset_coord), 2)
                    mask_mat = torch.sum(torch.cat((mask_h, mask_w), 2),
                                         2).view(scale_int * inH,
                                                 scale_int * inW)
                    mask_mat = mask_mat.eq(2)
                    pos_mat = pos_mat.contiguous().view(1, -1, 2)
                    if add_scale:
                        pos_mat = torch.cat(
                            (scale_mat.view(1, -1, 1), pos_mat), 2)

                    return pos_mat, mask_mat  ## pos_mat covers outH*outW positions; outH = scale_int*inH, outW = scale_int*inW

                N, C, H, W = lr.size()
                _, _, outH, outW = hr.size()
                scale_coord_map, mask = input_matrix_wpn(H, W, args1.scale[0])

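                # DataParallel scatters inputs along dim 0, so each GPU gets its
                # own copy of the coordinate map; single-GPU runs just move it to the device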
                if args1.n_GPUs > 1:
                    scale_coord_map = torch.cat([scale_coord_map] *
                                                args1.n_GPUs, 0)
                else:
                    scale_coord_map = scale_coord_map.cuda()

                #######################################
                # Discriminator
                # hr: real, sr: fake
                #######################################

                for p in D.parameters():
                    p.requires_grad = True
                optim_D.zero_grad()
                pred_real = D(hr)
                ###################For SR#####################

                init_sr = G(lr, 0, scale_coord_map)
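                # keep only the grid positions marked valid and reshape to the exact HR size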
                pa_sr = torch.masked_select(init_sr, mask.cuda())
                sr = pa_sr.contiguous().view(N, C, outH, outW)
                ##############################################
                pred_fake = D(sr.detach())

                if args.gan_type == 'SGAN':
                    total_D_loss = bce_loss_fn(pred_real,
                                               target_real) + bce_loss_fn(
                                                   pred_fake, target_fake)
                elif args.gan_type == 'RSGAN':
                    total_D_loss = bce_loss_fn(pred_real - pred_fake,
                                               target_real)

                # gradient penalty
                if args.GP:
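                    # WGAN-GP style penalty: evaluate D on random interpolates of
                    # real and fake images and push the gradient norm toward 1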
                    grad_outputs = torch.ones(args.batch_size, 1).cuda()
                    u = torch.FloatTensor(args.batch_size, 1, 1, 1).cuda()
                    u.uniform_(0, 1)
                    x_both = (hr * u + sr * (1 - u)).cuda()
                    x_both = Variable(x_both, requires_grad=True)
                    grad = torch.autograd.grad(outputs=D(x_both),
                                               inputs=x_both,
                                               grad_outputs=grad_outputs,
                                               retain_graph=True,
                                               create_graph=True,
                                               only_inputs=True)[0]
                    grad_penalty = 10 * (
                        (grad.norm(2, 1).norm(2, 1).norm(2, 1) - 1)**2).mean()
                    total_D_loss = total_D_loss + grad_penalty

                total_D_loss.backward()
                optim_D.step()

                ######################################
                # Generator
                ######################################
                for p in D.parameters():
                    p.requires_grad = False
                optim_G.zero_grad()
                pred_fake = D(sr)
                pred_real = D(hr)

                l1_loss = l1_loss_fn(sr, hr) * args.alpha_l1
                vgg_loss = vgg_loss_fn(sr, hr) * args.alpha_vgg
                tv_loss = tv_loss_fn(sr) * args.alpha_tv

                if args.gan_type == 'SGAN':
                    if args.focal_loss:
                        G_loss = f_loss_fn(pred_fake, target_real)
                    else:
                        G_loss = bce_loss_fn(pred_fake, target_real)
                elif args.gan_type == 'RSGAN':
                    if args.focal_loss:
                        G_loss = f_loss_fn(pred_fake - pred_real,
                                           target_real)  #Focal loss
                    else:
                        G_loss = bce_loss_fn(pred_fake - pred_real,
                                             target_real)
                G_loss = G_loss * args.alpha_gan

                total_G_loss = l1_loss + vgg_loss + G_loss + tv_loss

                total_G_loss.backward()
                optim_G.step()

                # update log
                running_loss += [
                    l1_loss.item(),
                    vgg_loss.item(),
                    G_loss.item(),
                    tv_loss.item(),
                    total_D_loss.item()
                ]

            avr_loss = running_loss / num_batches
            tb.add_scalar('Learning rate', cur_lr, epoch)
            tb.add_scalar('L1 Loss', avr_loss[0], epoch)
            tb.add_scalar('VGG Loss', avr_loss[1], epoch)
            tb.add_scalar('G Loss', avr_loss[2], epoch)
            tb.add_scalar('TV Loss', avr_loss[3], epoch)
            tb.add_scalar('D Loss', avr_loss[4], epoch)
            tb.add_scalar('Total G Loss', avr_loss[0:4].sum(), epoch)
            print('Finished training [%d/%d]. L1: %.2f. VGG: %.2f. G: %.2f. TV: %.2f. Total G: %.2f. D: %.2f'\
                  %(epoch, args.num_epochs, avr_loss[0], avr_loss[1], avr_loss[2],
                    avr_loss[3], avr_loss[0:4].sum(), avr_loss[4]))
            if epoch % args.snapshot_every == 0:
                model_path = os.path.join(check_point,
                                          'model_{}.pt'.format(epoch))
                torch.save(G.state_dict(), model_path)
                print('Saved snapshot model.')
        #===============Validate================
        '''
Ejemplo n.º 18
0
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer
import time

start_time = time.time()
torch.manual_seed(args.seed)

checkpoint = utility.checkpoint(args, 'model1')
if args.nmodels == 2:
    checkpoint2 = utility.checkpoint(args, 'model2')

if checkpoint.ok:
    loader = data.Data(args)
    model1 = model.Model(args, checkpoint, 'model1')
    loss1 = loss.Loss(args, checkpoint,
                      'model1') if not args.test_only else None
    if args.nmodels == 2:
        model2 = model.Model(args, checkpoint2, 'model2')
        loss2 = loss.Loss(args, checkpoint2,
                          'model2') if not args.test_only else None
        print("calling 2 model trainer")
        t = Trainer(args, loader, model1, loss1, checkpoint, model2, loss2,
                    checkpoint2)
    else:
        t = Trainer(args, loader, model1, loss1, checkpoint)
Ejemplo n.º 19
0
import torch
import os

import utility
import data
import model
import loss
from option import args
from torch.nn import DataParallel
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

print(os.getcwd())

if checkpoint.ok:
    loader = data.Data(args)
    model = model.Model(args, checkpoint)
    model = DataParallel(model)
    loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, model, loss, checkpoint)
    while not t.terminate():
        t.train()
    #     t.test()

    checkpoint.done()

Ejemplo n.º 20
0
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8

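        # models that upsample inside the network (rather than taking a
        # pre-upsampled input), and models that need input padding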
        up_model_list = ('mwcnn', 'vdsr', 'docnn', 'mwcnn_caa', 'mwcnn_cab',
                         'mwcnn_caab', 'docnn_cab')
        self.is_upsample = self.args.model in up_model_list

        pad_model_list = ('mwcnn', 'docnn', 'mwcnn_caa', 'mwcnn_cab',
                          'mwcnn_caab', 'docnn_cab')
        self.is_pad = self.args.model in pad_model_list
        #args.save2 = args.save
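        # note: args2 = args creates an alias, not a copy; the assignments below also mutate args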
        args2 = args
        args2.resume = -2
        args2.mid_channels = 4
        args2.model = args.model_init

        if args2.resume != -2:
            #args2.model = args.model_init
            args2.resume = -2
            args2.mid_channels = 4
            #args2.batch_size = 32
            args2.sigma = 10
            #args.loss = '1*L1'
            args2.save = args2.model + '_mid' + str(
                args2.mid_channels) + '_sb' + str(
                    args2.batch_size) + '_sig' + str(args2.sigma)
            if args2.is_act:
                args2.save = args2.save + '_PreLU'
            else:
                args2.save = args2.save + '_Linear'

            note = ''
            for loss in args2.loss.split('+'):
                weight, loss_type = loss.split('*')
                note = note + '_' + str(weight) + loss_type

            args2.save = args2.save + note
            args2.pre_train = '../experiment/' + args2.save + '/model/model_best.pt'
        else:
            args2.pre_train = '../experiment/' + args2.save + '/model_init/model_best.pt'

        checkpoint = utility.checkpoint(args2)
        self.model_init = M.Model(args2, checkpoint)
        self.optimizer_init = utility.make_optimizer(args2, self.model_init)
        self.init_psnr = 0