Example #1
File: train.py  Project: xuexiy1ge/AI4K
def hand_png_file(args, input_trans, target_trans):
    # Sort both listings so input/target files pair up deterministically
    # (glob order is OS-dependent, which would silently mismatch the pairs).
    input_files = sorted(glob.glob(os.path.join(args.input, "*")))
    target_files = sorted(glob.glob(os.path.join(args.target, "*")))

    if len(input_files) != len(target_files):
        raise Exception('Input and target file counts differ',
                        len(input_files), len(target_files))
    # Split the paired file lists into train/val subsets by ratio.
    input_files = np.array(input_files)
    target_files = np.array(target_files)
    split = int(len(input_files) * args.train_val_ratio)
    train_input = input_files[:split]
    train_target = target_files[:split]
    val_input = input_files[split:]
    val_target = target_files[split:]

    print('train/val sizes:', len(train_input), len(train_target),
          len(val_input), len(val_target))
    train_set = DatasetFromFolder(train_input,
                                  train_target,
                                  input_transform=input_trans,
                                  target_transform=target_trans)
    val_set = DatasetFromFolder(val_input,
                                val_target,
                                input_transform=input_trans,
                                target_transform=target_trans)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=args.num_workers,
                              batch_size=args.batch_size,
                              drop_last=True,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=args.num_workers,
                            batch_size=args.batch_size,
                            drop_last=True,
                            shuffle=True)
    return train_set, val_set, train_loader, val_loader
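A minimal sketch of how hand_png_file might be driven from the command line; the flag names below are assumptions inferred from the attributes the function reads (args.input, args.target, args.train_val_ratio, args.num_workers, args.batch_size), not the original project's parser.

# Hypothetical driver for hand_png_file; flag names and defaults are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input', default='data/lr')    # low-res frames
parser.add_argument('--target', default='data/hr')   # high-res frames
parser.add_argument('--train_val_ratio', type=float, default=0.9)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=16)
args = parser.parse_args()

train_set, val_set, train_loader, val_loader = hand_png_file(
    args, input_trans=None, target_trans=None)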
Example #2
def get_test_set(dataset, crop_size, upscale_factor):
    test_dir = join("dataset", dataset)
    cropsize = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(
                                 cropsize, upscale_factor),
                             target_transform=target_transform(cropsize))
Example #3
def get_test_set(upscale_factor):
    test_dir = "dataset/Urban100"
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(
                                 crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))
Example #4
def get_validation_set(dataset, crop_size, upscale_factor):
    root_dir = join("dataset", dataset)
    validation_dir = join(root_dir, "valid")
    cropsize = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(validation_dir,
                             input_transform=input_transform(
                                 cropsize, upscale_factor),
                             target_transform=target_transform(cropsize))
Example #5
def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(
                                 crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))
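Examples #2 through #5 (and several below) call calculate_valid_crop_size, input_transform, and target_transform without showing them. In the pytorch/examples super_resolution script these snippets are patterned on, the helpers look roughly like the sketch below; treat it as an assumption about each project's local definitions rather than their exact code.

# Sketch of the helpers assumed by Examples #2-#5 (patterned on the
# pytorch/examples super_resolution script; individual projects may differ).
from torchvision.transforms import Compose, CenterCrop, Resize, ToTensor

def calculate_valid_crop_size(crop_size, upscale_factor):
    # Largest size <= crop_size that is divisible by the upscale factor.
    return crop_size - (crop_size % upscale_factor)

def input_transform(crop_size, upscale_factor):
    # Crop the HR image, then downscale it to produce the LR input.
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor(),
    ])

def target_transform(crop_size):
    # The target is simply the center-cropped HR image.
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])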
Example #6
File: data.py  Project: curlyqian/SVLRM
def get_training_set(dataset, upscale_factor=4, crop=None):
    root_dir = join("dataset", dataset)
    train_dir = join(root_dir, "RGBD_data")

    return DatasetFromFolder(
        train_dir,
        upscale_factor=upscale_factor,
        crop=crop,
    )
Example #7
def get_training_set(dataset, crop_size, upscale_factor, add_noise=None, noise_std=3.0):
    cropsize = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(dataset,
                             input_transform=input_transform(
                                 cropsize, upscale_factor),
                             target_transform=target_transform(cropsize),
                             add_noise=add_noise,
                             noise_std=noise_std)
Example #8
def get_training_set(upscale_factor, add_noise=None, noise_std=3.0):
    root_dir = download_bsd300()

    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(
                                 crop_size, upscale_factor),
                             target_transform=target_transform(crop_size),
                             add_noise=add_noise,
                             noise_std=noise_std)
Example #9
def get_training_set(dataset,
                     crop_size,
                     upscale_factor,
                     add_noise=None,
                     noise_std=3.0):
    root_dir = join("/data/zihaosh", dataset)
    train_dir = join(root_dir, "train")
    cropsize = calculate_valid_crop_size(crop_size, upscale_factor)

    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(
                                 cropsize, upscale_factor),
                             target_transform=target_transform(cropsize),
                             add_noise=add_noise,
                             noise_std=noise_std)
Example #10
                        default=100,
                        type=int,
                        help='super resolution epochs number')
    parser.add_argument('--dataset_name',
                        default="VOC2012",
                        type=str,
                        help='data set name')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    DATASET_NAME = opt.dataset_name

    train_set = DatasetFromFolder('data/train',
                                  upscale_factor=UPSCALE_FACTOR,
                                  dataset_name=DATASET_NAME,
                                  input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val',
                                upscale_factor=UPSCALE_FACTOR,
                                dataset_name=DATASET_NAME,
                                input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_set,
                              num_workers=8,
                              batch_size=128,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=8,
                            batch_size=128,
                            shuffle=False)
Example #11
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

opt = parser.parse_args()

crop_size = opt.crop_size
upscale_factor = opt.upscale_factor
batch_size = 8
num_epochs = 100
alpha = opt.alpha
train_data_path = 'data/train'

print("Training SR GAN with alpha: {}, batch size: {}, crop size: {}, train data: {} for epochs: {}".format(\
    alpha, batch_size, crop_size, train_data_path, num_epochs))

train_set = DatasetFromFolder(train_data_path, crop_size, alpha=alpha)
train_loader = DataLoader(train_set,
                          num_workers=0,
                          batch_size=batch_size,
                          shuffle=True)

netG = Generator(upscale_factor)
netD = Discriminator()

gen_criterion = GeneratorLoss()

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

if torch.cuda.is_available():
    netG.cuda()
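    # The example is truncated here; the usual SRGAN setup also moves the
    # discriminator and generator loss to the GPU inside this same branch
    # (assumed continuation, not the original file):
    netD.cuda()
    gen_criterion.cuda()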
Example #12
File: train.py  Project: bonbert81/RTSR4k
    net = Net(upscale_factor=UPSCALE_FACTOR)
    print(net)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Running on', device)
    # Compare device.type; a torch.device object is not reliably equal to a plain string.
    if device.type == 'cuda':
        net.cuda()
    transform = transforms.Compose([
        # you can add other transformations in this list
        transforms.ToTensor()
    ])

    # trainset = torchvision.datasets.ImageFolder(root = './data/train/SRF_3', transform=transforms.ToTensor(),
    #                                  target_transform=None)

    trainset = DatasetFromFolder('data/train',
                                 upscale_factor=UPSCALE_FACTOR,
                                 input_transform=transforms.ToTensor(),
                                 target_transform=transforms.ToTensor())

    testset = DatasetFromFolder('data/val',
                                upscale_factor=UPSCALE_FACTOR,
                                input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=2)

    # testset = torchvision.datasets.ImageFolder(root = './data/val/SRF_3', transform=transform,
    #                                  target_transform=None)
Example #13
File: data.py  Project: curlyqian/SVLRM
def get_test_set(dataset, upscale_factor=4):
    root_dir = join("dataset", dataset)
    test_dir = join(root_dir, "RGBD_testdata")

    return DatasetFromFolder(test_dir, upscale_factor=upscale_factor)
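Every example above depends on a project-specific DatasetFromFolder class. For orientation, a minimal version matching the input_transform/target_transform signature used in most of these snippets might look like the sketch below, modeled on the pytorch/examples super_resolution dataset; variants such as the SVLRM one (Examples #6 and #13) take different constructor arguments.

# Minimal DatasetFromFolder sketch (assumption: the input_transform /
# target_transform variant; other projects use different signatures).
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data import Dataset

def is_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))

class DatasetFromFolder(Dataset):
    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super().__init__()
        self.image_filenames = [join(image_dir, x)
                                for x in listdir(image_dir) if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # Work on the luminance (Y) channel, as the reference script does.
        input_image = Image.open(self.image_filenames[index]).convert('YCbCr').split()[0]
        target = input_image.copy()
        if self.input_transform:
            input_image = self.input_transform(input_image)
        if self.target_transform:
            target = self.target_transform(target)
        return input_image, target

    def __len__(self):
        return len(self.image_filenames)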
Example #14
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from data_utils import DatasetFromFolder
from tensorboardX import SummaryWriter
from model import VDSR

device = torch.device('cuda:0')
writer = SummaryWriter('D:/VDSR')

transform = T.ToTensor()

trainset = DatasetFromFolder('D:/train_data/291', transform=transform)
trainLoader = DataLoader(trainset, batch_size=128, shuffle=True)

net = VDSR()
net = net.to(device)

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.MSELoss()
criterion = criterion.to(device)

net.train()
for epoch in range(20):
    running_cost = 0.0
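    # Example #14 stops here mid-epoch; a typical VDSR epoch body continues
    # roughly as below (an assumed sketch; the gradient clipping follows the
    # VDSR paper and is not from the original file).
    for lr_img, hr_img in trainLoader:
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)
        optimizer.zero_grad()
        loss = criterion(net(lr_img), hr_img)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), 0.4)  # stabilize the high learning rate
        optimizer.step()
        running_cost += loss.item()
    scheduler.step()
    writer.add_scalar('train_loss', running_cost / len(trainLoader), epoch)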