def main():
    """Run the full training session for the ``_NetG`` model.

    Steps, in order: parse the command-line options, pin the run to GPU "0"
    with CUDA forced on, seed the torch RNGs with a random seed, build the
    HDF5-backed data loader, create the model and an L1 criterion, optionally
    restore weights from ``opt.resume``, then train with Adam for epochs
    ``opt.start_epoch`` through ``opt.nEpochs`` inclusive.

    ``opt`` and ``model`` are published as module-level globals.

    Raises:
        Exception: if CUDA is requested but no GPU is available.
    """
    global opt, model

    opt = parser.parse_args()
    opt.gpus = '0'  # the device list is overridden regardless of the CLI value
    print(opt)

    # CUDA is forced on; the visible-device mask is set before the first
    # CUDA query below.
    opt.cuda = True
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    use_cuda = opt.cuda
    if use_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    # A fresh random seed is drawn on every run and stored on the options.
    opt.seed = random.randint(1, 10000)
    torch.manual_seed(opt.seed)
    if use_cuda:
        torch.cuda.manual_seed(opt.seed)
    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5("./data/training_RGB_5to50_uint8_samples.h5")
    training_data_loader = DataLoader(
        dataset=train_set,
        num_workers=opt.threads,
        batch_size=opt.batchSize,
        shuffle=True,
        pin_memory=True,
    )

    print("===> Building model")
    model = _NetG()
    criterion = nn.L1Loss()

    print("===> Setting GPU")
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    summary(model, (3, 64, 64))

    # optionally resume from a checkpoint
    if opt.resume:
        if not os.path.isfile(opt.resume):
            print("=> no checkpoint found at '{}'".format(opt.resume))
        else:
            print("=> loading checkpoint '{}'".format(opt.resume))
            # Storages are mapped to CPU while loading; the weights are then
            # copied into the already-placed model.
            checkpoint = torch.load(opt.resume, map_location=lambda storage, loc: storage)
            opt.start_epoch = checkpoint["epoch"]
            # The checkpoint stores a whole model object under 'model'.
            saved_model = checkpoint['model']
            model.load_state_dict(saved_model.state_dict())
            # Drop every reference to the loaded objects before clearing the
            # CUDA cache.
            del saved_model
            del checkpoint
            torch.cuda.empty_cache()

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training")
    max_psnr = 0
    first_epoch = opt.start_epoch
    last_epoch = opt.nEpochs
    for epoch in range(first_epoch, last_epoch + 1):
        # train() returns the updated max_psnr, which is fed into the next epoch.
        max_psnr = train(training_data_loader, optimizer, model, criterion, epoch, max_psnr)
        # NOTE(review): the source formatting was collapsed; this save is
        # assumed to run at the end of every epoch — confirm.
        save_checkpoint(model, epoch, 'end', 'end_ep')
# Blend the parameters of two pretrained colour models into a single
# checkpoint: every parameter present in both models is replaced by
#     beta * model1_param + (1 - beta) * model2_param
# and the result is saved as a fresh ``_NetG`` under "checkpoint/".
import torch

from color_model import _NetG

model1_path = "./checkpoint/pretrained_color/model_31.82db_11ep_32000it_.pth"
model2_path = "./checkpoint/pretrained_color/model_31.85db_17ep_20000it_.pth"

# NOTE: these checkpoints hold whole pickled model objects under "model";
# torch.load unpickles them, so only load files from a trusted source.
model1 = torch.load(model1_path, map_location=lambda storage, loc: storage)["model"]
model2 = torch.load(model2_path, map_location=lambda storage, loc: storage)["model"]

# The interpolation parameter: weight given to model1; model2 gets (1 - beta).
beta = 0.5

params1 = model1.named_parameters()
params2 = model2.named_parameters()

dict_params2 = dict(params2)

for name1, param1 in params1:
    if name1 in dict_params2:
        target = dict_params2[name1]
        # The two weights must sum to 1, otherwise the parameters are scaled
        # rather than interpolated whenever beta != 0.5.
        target.data.copy_(beta * param1.data + (1 - beta) * target.data)

# Load the blended parameters into a fresh network. dict_params2 only holds
# named parameters, so this relies on the model having no extra state-dict
# entries.
model = _NetG()
model.load_state_dict(dict_params2)

model_out_path = "checkpoint/" + "color_model.pth"
# "epoch" is stored as 0 alongside the whole model object.
state = {"epoch": 0, "model": model}
torch.save(state, model_out_path)