示例#1
0
def get_loaders(batch_size, device, data_root='ceng483-s19-hw3-dataset'):
    """Build DataLoaders for the train, val and test splits.

    Each split is read with ``hw3utils.HW3ImageFolder`` from the ``train``,
    ``val`` and ``test`` sub-directories of *data_root*. Only the training
    loader shuffles; every loader uses ``num_workers=0``.

    Args:
        batch_size: Number of samples per batch, shared by all three loaders.
        device: Forwarded unchanged to ``hw3utils.HW3ImageFolder``.
        data_root: Dataset root directory. Defaults to the path that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    def make_loader(split, shuffle):
        # The three splits differ only in sub-directory and shuffle flag.
        dataset = hw3utils.HW3ImageFolder(root=os.path.join(data_root, split),
                                          device=device)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           num_workers=0)

    # Built in the same order as before: train, then val, then test.
    train_loader = make_loader('train', shuffle=True)
    val_loader = make_loader('val', shuffle=False)
    test_loader = make_loader('test', shuffle=False)
    return train_loader, val_loader, test_loader


def test_loader(batch_size, device, data_root='ceng483-s19-hw3-dataset'):
    """Build an unshuffled DataLoader over the ``test`` split only.

    Args:
        batch_size: Number of samples per batch.
        device: Forwarded unchanged to ``hw3utils.HW3ImageFolder``.
        data_root: Dataset root directory containing a ``test``
            sub-directory. Defaults to the path that was previously
            hard-coded, so existing callers are unaffected.

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split.
    """
    dataset = hw3utils.HW3ImageFolder(root=os.path.join(data_root, 'test'),
                                      device=device)
    # shuffle=False keeps batches in the dataset's index order on every run.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=0)
    return loader
示例#3
0
    f.write("Average Loss = " + str(avg_loss) + "\n")
    f.write("Average ACC = " + str(avg_acc) + "\n")
    f.write("+--------------------+")
if i % 2 == 1:
    batch_val += 1
if i % 6 == 5:
    batch_val = 0
    lr_val += 1'''

print('Finished Training')

print("Estimations.npy Creating")
# To evaluate saved weights instead of the freshly trained ones, restore them
# first. load_state_dict expects a state dict, not a file path:
# net.load_state_dict(torch.load(os.path.join(LOG_DIR, 'checkpoint.pt')))

data_root = 'ceng483-s19-hw3-dataset'
test_set = hw3utils.HW3ImageFolder(root=os.path.join(data_root, 'test'),
                                   device=device)
test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=0)

# NOTE(review): if `net` contains dropout or batch-norm layers, it should
# probably be switched to eval mode (net.eval()) before this loop. Confirm.
estimations2 = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for inputs, _targets in test_loader:  # inputs: low-res, targets: high-res
        preds = net(inputs)
        # Move the axis after the batch axis to the end. This is presumably
        # (C, H, W) -> (H, W, C) per image; confirm against the dataset.
        batch = preds.permute(0, 2, 3, 1).cpu().detach().numpy()
        # Affine map x -> (x / 2 + 0.5) * 255 sends [-1, 1] to [0, 255].
        # This assumes the network outputs values in [-1, 1]; confirm.
        # extend() adds one array per sample, in loader order.
        estimations2.extend((batch / 2 + 0.5) * 255)

# Previously this read `estimations`, which ignored the predictions
# collected above (or raised NameError if that name was undefined).
estimations2 = np.array(estimations2)