Esempio n. 1
0
from utils import *

if __name__ == "__main__":
    # Command-line interface: three mandatory paths plus an integer scale.
    parser = argparse.ArgumentParser()
    for flag in ("--weights_file", "--image_file", "--realimage_file"):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--scale", type=int, default=1)
    args = parser.parse_args()

    cudnn.benchmark = True

    # Use the first CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    model = VDSR().to(device)

    # Copy every checkpoint tensor into the matching model tensor in place.
    # The map_location callback returns each storage unchanged; copy_() then
    # writes the values into the tensors owned by the model.
    state_dict = model.state_dict()
    checkpoint = torch.load(
        args.weights_file,
        map_location=lambda storage, loc: storage,
    )
    for name, tensor in checkpoint.items():
        # A checkpoint entry the model does not define is treated as an error.
        if name not in state_dict.keys():
            raise KeyError(name)
        state_dict[name].copy_(tensor)

    model.eval()

    # Both input images are decoded as 3-channel RGB.
    image = Image.open(args.realimage_file).convert("RGB")
    resample = Image.open(args.image_file).convert("RGB")

    # Round the width down to the nearest multiple of the scale factor.
    image_width = args.scale * (image.width // args.scale)
Esempio n. 2
0
# Move the loss module to the device that the batches are moved to below.
criterion = criterion.to(device)

net.train()
for epoch in range(20):

    # Sum of batch losses since the last report; reset every 10 batches.
    running_cost = 0.0
    for i, data in enumerate(trainLoader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        output = net(inputs)
        loss = criterion(output, targets)
        loss.backward()

        # Clip gradients only when training with SGD. The optimizer is an
        # object, so it must be type-checked: comparing it with the string
        # 'SGD' is always False and silently disabled clipping.
        if isinstance(optimizer, torch.optim.SGD):
            # clip_grad_norm_ is the in-place replacement for the deprecated
            # clip_grad_norm; 0.4 is the maximum total gradient norm.
            nn.utils.clip_grad_norm_(net.parameters(), 0.4)

        optimizer.step()
        running_cost += loss.item()

        # Report the mean loss of the last 10 batches.
        if i % 10 == 9:
            print('epoch:%d, loss:%.8f' % (epoch, running_cost / 10))
            writer.add_scalar('loss', running_cost / 10,
                              epoch * len(trainLoader) + i)
            running_cost = 0.0

    scheduler.step()

    # Checkpoint once per epoch rather than after every batch, so training is
    # not throttled by a disk write per step. The file is overwritten each time.
    torch.save(net.state_dict(), 'VDSR.pth')

writer.close()