Example #1
0
# Pick a fresh random seed for this run and apply it to the CPU and CUDA RNGs.
opt.seed = random.randint(1, 1200)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

if opt.resume:
    # Resuming needs both the generator checkpoint (opt.resume) and the
    # discriminator checkpoint (opt.resumeD). Fail loudly here instead of
    # leaving `model` / `netD` unbound and crashing later with a NameError.
    if not os.path.isfile(opt.resume):
        raise FileNotFoundError(
            "No generator checkpoint found at {}".format(opt.resume))
    if not opt.resumeD or not os.path.isfile(opt.resumeD):
        raise FileNotFoundError(
            "No discriminator checkpoint found at {}".format(opt.resumeD))
    print("Loading from checkpoint {}".format(opt.resume))
    # torch.load returns the stored modules directly, so no extra
    # load_state_dict round-trip is needed. map_location lets checkpoints
    # saved on another device (e.g. a GPU) load onto the current one.
    # NOTE(review): torch.load unpickles arbitrary objects - only resume
    # from trusted checkpoint files.
    model = torch.load(opt.resume, map_location=device)
    netD = torch.load(opt.resumeD, map_location=device)
    # Recover where training stopped; presumably parsed from the checkpoint
    # path - confirm against which_trainingstep_epoch.
    opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(
        opt.resume)
else:
    # Fresh run: build new generator / discriminator and output folders.
    model = Net()
    netD = Discriminator()
    mkdir_steptraing()

model = model.to(device)
netD = netD.to(device)
# reduction='mean' is the non-deprecated spelling of size_average=True.
criterion = torch.nn.MSELoss(reduction='mean').to(device)
cri_perception = VGGFeatureExtractor().to(device)
cri_gan = GANLoss('vanilla', 1.0, 0.0).to(device)
Example #2
0
# Seed the CPU and CUDA RNGs with the run's configured seed.
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

if opt.resume:
    # Both generator (opt.resume) and discriminator (opt.resumeD) checkpoints
    # are required; fail loudly rather than leave `model` / `netD` unbound.
    if not os.path.isfile(opt.resume):
        raise FileNotFoundError(
            "No generator checkpoint found at {}".format(opt.resume))
    if not opt.resumeD or not os.path.isfile(opt.resumeD):
        raise FileNotFoundError(
            "No discriminator checkpoint found at {}".format(opt.resumeD))
    print("Loading from checkpoint {}".format(opt.resume))
    # The checkpoints hold whole modules; map_location places them on the
    # current device. NOTE(review): torch.load unpickles arbitrary objects -
    # only resume from trusted files.
    model = torch.load(opt.resume, map_location=device)
    netD = torch.load(opt.resumeD, map_location=device)
    opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(opt.resume)
else:
    # Fresh run: build new networks and output folders.
    model = Net()
    netD = Discriminator()
    mkdir_steptraing()

model = model.to(device)
netD = netD.to(device)
# reduction='mean' is the non-deprecated spelling of size_average=True.
criterion = torch.nn.L1Loss(reduction='mean').to(device)
cri_perception = VGGFeatureExtractor().to(device)
# Only parameters that require gradients are handed to the optimizers.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, netD.parameters()), lr=0.0004)
print()


# Per-stage settings; training_settings is indexed 0-based while stage
# numbers (i) are used 1-based here.
for i in range(opt.start_training_step, 4):
    opt.nEpochs = training_settings[i - 1]['nEpochs']
opt = parser.parse_args()
# Pick a fresh random seed for this run and apply it to the CPU and CUDA RNGs.
opt.seed = random.randint(1, 1200)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

if opt.resume:
    # A missing checkpoint used to leave `model` unbound (NameError later);
    # report the real problem instead.
    if not os.path.isfile(opt.resume):
        raise FileNotFoundError(
            "No checkpoint found at {}".format(opt.resume))
    print("Loading from checkpoint {}".format(opt.resume))
    # The checkpoint holds the whole module; map_location places it on the
    # current device. NOTE(review): torch.load unpickles arbitrary objects -
    # only resume from trusted files.
    model = torch.load(opt.resume, map_location=device)
    opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(
        opt.resume)
else:
    # Fresh run: build a new network and output folders.
    model = Net()
    mkdir_steptraing()

model = model.to(device)
# reduction='mean' is the non-deprecated spelling of size_average=True.
criterion = torch.nn.MSELoss(reduction='mean').to(device)
cri_perception = VGGFeatureExtractor().to(device)
# Only trainable parameters are optimized.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.0001,
    betas=(0.9, 0.999))
print()

# Copy this stage's hyper-parameters into opt; stage numbers are 1-based.
for i in range(opt.start_training_step, 4):
    opt.nEpochs = training_settings[i - 1]['nEpochs']
    opt.lr = training_settings[i - 1]['lr']
    opt.step = training_settings[i - 1]['step']
Example #4
0
# Seed the CPU and CUDA RNGs with the run's configured seed.
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)

train_dir = opt.dataset

# Keep only the files that is_hdf5_file accepts, in sorted order.
train_sets = [x for x in sorted(os.listdir(train_dir)) if is_hdf5_file(x)]
print("===> Loading model and criterion")

if opt.resume:
    # A missing checkpoint used to leave `model` unbound (NameError later);
    # report the real problem instead.
    if not os.path.isfile(opt.resume):
        raise FileNotFoundError(
            "No checkpoint found at {}".format(opt.resume))
    print("Loading from checkpoint {}".format(opt.resume))
    # The checkpoint holds the whole module; map_location places it on the
    # current device. NOTE(review): torch.load unpickles arbitrary objects -
    # only resume from trusted files.
    model = torch.load(opt.resume, map_location=device)
    opt.start_training_step, opt.start_epoch = which_trainingstep_epoch(opt.resume)
else:
    # Fresh run: build a new network and output folders.
    model = Net()
    mkdir_steptraing()

model = model.to(device)
# reduction='mean' is the non-deprecated spelling of size_average=True.
criterion = torch.nn.MSELoss(reduction='mean').to(device)
# NOTE(review): lr is taken from opt.lr once here, while the stage loop below
# later overwrites opt.lr - confirm the optimizer's lr is updated elsewhere.
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
print()

# Copy each remaining training stage's hyper-parameters from
# training_settings into opt before running that stage.
# NOTE(review): i - 1 indexing suggests stage numbers are 1-based, and the
# upper bound 4 means at most three stages - confirm against training_settings.
for i in range(opt.start_training_step, 4):
    opt.nEpochs   = training_settings[i-1]['nEpochs']
    opt.lr        = training_settings[i-1]['lr']
    opt.step      = training_settings[i-1]['step']
    opt.lr_decay  = training_settings[i-1]['lr_decay']
    opt.lambda_db = training_settings[i-1]['lambda_db']
    opt.gated     = training_settings[i-1]['gated']