# Make sure the checkpoint directory exists before anything is saved to it.
# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the race between a separate exists() check and mkdir().
os.makedirs(save_dir, exist_ok=True)

if __name__ == '__main__':
    # Model selection: the decomposition network is a DnCNN built with
    # image_channels=5.
    print('===> Building model')
    decompose_model = DnCNN(image_channels=5)

    # Resume from the most recent checkpoint in save_dir, if one was found
    # (findLastCheckpoint returning a value > 0 means "resume from that epoch").
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        checkpoint_path = os.path.join(
            save_dir, 'decom_model_%03d.pth' % initial_epoch)
        # map_location='cpu': the module's parameters are still on the CPU at
        # this point (it is moved to the GPU below), so this changes nothing on
        # GPU hosts but lets GPU-saved checkpoints load on CPU-only machines.
        decompose_model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))

    criterion = nn.MSELoss()

    if cuda:
        decompose_model = decompose_model.cuda()

    # Build the optimizer only after the model is on its final device so it
    # holds references to the parameters that will actually be trained.
    optimizer_decompose = optim.Adam(decompose_model.parameters(), lr=args.lr)
# NOTE(review): this line is a whitespace-collapsed paste. Its original line
# breaks and indentation were lost, so it is not runnable as written: the
# interpreter treats everything after the first '#' below as a comment.
# Restore the original multi-line layout before changing any logic.
#
# Contents, in order:
#   1. The tail of an expression/function whose start is not visible here.
#      The real and imaginary parts of np.fft.ifft2(Low_freq) are each turned
#      into float tensors, given a new axis with unsqueeze(1), and joined
#      along dim=1. This is presumably the end of a torch.cat(...) assigned to
#      Low_output (TODO confirm). The function then returns
#      (High_output, Low_output).
#   2. Module level: create save_dir with os.mkdir if it does not exist.
#      os.mkdir creates only the last path component.
#   3. Script entry point (__main__):
#      - builds model = DnCNN(image_channels=2) and
#        u_model = UNet(input_channels=1, image_channels=1);
#      - loads pretrained weights into `model` from
#        args.load_model_dir / args.load_model_name;
#      - if findLastCheckpoint(save_dir) returns an epoch > 0, resumes
#        u_model from save_dir/model_<epoch as %03d>.pth;
#      - puts `model` in eval mode and `u_model` in train mode, presumably so
#        that only u_model is trained here (TODO confirm). eval() only changes
#        layer behaviour (e.g. BatchNorm/Dropout) and does not stop
#        gradients; freezing `model` would also need requires_grad_(False)
#        or torch.no_grad();
#      - uses nn.MSELoss() as the criterion.
# NOTE(review): torch.load is called without map_location. Loading
# GPU-saved checkpoints on a CPU-only machine needs map_location='cpu'.
(torch.tensor(np.real(np.fft.ifft2(Low_freq))).unsqueeze(1).float(), torch.tensor(np.imag(np.fft.ifft2(Low_freq))).unsqueeze(1).float()), dim=1) return High_output, Low_output if not os.path.exists(save_dir): os.mkdir(save_dir) if __name__ == '__main__': # model selection print('===> Building model') model = DnCNN(image_channels=2) u_model = UNet(input_channels=1, image_channels=1) model.load_state_dict( torch.load(os.path.join(args.load_model_dir, args.load_model_name))) initial_epoch = findLastCheckpoint( save_dir=save_dir) # load the last model in matconvnet style # initial_epoch = 150 if initial_epoch > 0: print('resuming by loading epoch %03d' % initial_epoch) u_model.load_state_dict( torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))) # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)) model.eval() u_model.train() criterion = nn.MSELoss()
if __name__ == '__main__':
    args = parse_args()

    high_model = DnCNN()
    low_model = DnCNN()

    # The low-frequency model's weights are loaded regardless of how
    # high_model is obtained, so it is never left with random weights.
    # map_location='cpu': the freshly built module lives on the CPU, so the
    # state dict ends up there anyway; this also works on CPU-only machines.
    low_model.load_state_dict(
        torch.load(os.path.join(args.low_model_dir, args.low_model_name),
                   map_location='cpu'))

    high_model_path = os.path.join(args.high_model_dir, args.high_model_name)
    if not os.path.exists(high_model_path):
        # Fallback: 'model.pth' holds a whole pickled model object (not a
        # state dict), which replaces the DnCNN() instance built above.
        # NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True and
        # refuses full-model pickles; only unpickle files you trust.
        high_model = torch.load(os.path.join(args.high_model_dir, 'model.pth'))
        log('load trained model on Train400 dataset by kai')
    else:
        high_model.load_state_dict(
            torch.load(high_model_path, map_location='cpu'))
        log('load trained model')