Exemplo n.º 1
0
# --- Data pipeline -----------------------------------------------------------
# Training pairs come from the project helper `get_training_set`; the loader
# reshuffles them every epoch and reads them with `opt.threads` workers.
train_set = get_training_set(
    opt.data_dir,
    opt.upscale_factor,
    opt.patch_size,
    opt.data_augmentation,
)
training_data_loader = DataLoader(
    dataset=train_set,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=opt.threads,
)

print('===> Building model')

# --- Networks ----------------------------------------------------------------
# Denoising VAE; its hyper-parameters are fixed here rather than read from opt.
denoiser = VAE_denoise_vali(
    input_dim=3,
    dim=32,
    feat_size=8,
    z_dim=512,
    prior='standard',
    number_component=512,
)
# Super-resolution generator; only the upscale factor is configurable.
G = VAE_SR(input_dim=3, dim=64, scale_factor=opt.upscale_factor)
# Discriminator input size is opt.patch_size scaled by the upscale factor
# (presumably the high-resolution patch size -- confirm against the dataset).
D = discriminator(
    num_channels=3,
    base_filter=64,
    image_size=opt.patch_size * opt.upscale_factor,
)
# VGG feature extractor (layer 34, no batch norm, input normalisation on).
# NOTE(review): device is hard-coded to 'cuda'; presumably used for a
# perceptual loss -- confirm in the training loop.
feat_extractor = VGGFeatureExtractor(
    feature_layer=34,
    use_bn=False,
    use_input_norm=True,
    device='cuda',
)

# Wrap each network in DataParallel over `gpus_list`. The source tuple is
# evaluated before rebinding, and the wrappers are built in the same order
# as the networks above.
denoiser, G, D, feat_extractor = (
    torch.nn.DataParallel(net, device_ids=gpus_list)
    for net in (denoiser, G, D, feat_extractor)
)

# --- Losses ------------------------------------------------------------------
L1_loss, BCE_loss = nn.L1Loss(), nn.BCEWithLogitsLoss()
Exemplo n.º 2
0
# Decide whether to run on the GPU, and fail fast if GPU mode was requested
# but no CUDA device is visible.
cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    # RuntimeError is a subclass of Exception, so existing handlers still
    # catch it. The message names the option this script actually reads
    # (opt.gpu_mode) instead of a non-existent --cuda flag.
    raise RuntimeError("No GPU found, please run with --gpu_mode disabled")

# Seed the CPU RNG and, in GPU mode, every visible CUDA device.
# torch.cuda.manual_seed only seeds the *current* device, but the models are
# later wrapped in DataParallel over `gpus_list`, so seed them all explicitly.
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed_all(opt.seed)

print('===> Building model ', opt.model_type)

# Denoising VAE and super-resolution network. The hyper-parameters are fixed
# here; only the upscale factor comes from `opt`.
denoiser = VAE_denoise_vali(
    input_dim=3,
    dim=32,
    feat_size=8,
    z_dim=512,
    prior='standard',
)
model = VAE_SR(input_dim=3, dim=64, scale_factor=opt.upscale_factor)

# Wrap both networks in DataParallel over `gpus_list` (denoiser first, then
# model, as before).
denoiser, model = (
    torch.nn.DataParallel(net, device_ids=gpus_list)
    for net in (denoiser, model)
)
# In GPU mode, move the wrapped networks onto the first listed device.
# DataParallel expects the module's parameters on device_ids[0], and
# Module.cuda() returns the module itself, so rebinding is equivalent.
if cuda:
    denoiser, model = (net.cuda(gpus_list[0]) for net in (denoiser, model))

# NOTE(review): the code right after this message loads the pretrained
# denoiser checkpoint, not a dataset -- confirm the log text is intended.
print('===> Loading datasets')

if os.path.exists(opt.model_denoiser):
    # denoiser.load_state_dict(torch.load(opt.model_denoiser, map_location=lambda storage, loc: storage))
    pretrained_dict = torch.load(opt.model_denoiser,
                                 map_location=lambda storage, loc: storage)
    model_dict = denoiser.state_dict()
    pretrained_dict = {