コード例 #1
0
                        help='random seed to use. Default=123')
    parser.add_argument('--snapshot',
                        type=str,
                        default=None,
                        help='filename of model snapshot [default: None]')
    opt = parser.parse_args()
    print(opt)

    device_ids = [0, 1]
    device = torch.device("cuda")

    torch.manual_seed(opt.seed)

    print('===> Loading datasets')
    train_set = get_training_set()
    val_set = get_val_set()
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batchSize,
                                      shuffle=True)
    validating_data_loader = DataLoader(dataset=val_set,
                                        num_workers=opt.threads,
                                        batch_size=opt.valBatchSize,
                                        shuffle=False)

    print('===> Building model')

    model = Net()

    if opt.snapshot is not None:
        print('Loading model from {}...'.format(opt.snapshot))
コード例 #2
0
                    help='continue train model : load g model')
opt = parser.parse_args()
print(opt)

# Resolve the --cuda flag once and use it consistently below.
cuda = opt.cuda
if cuda and not torch.cuda.is_available():
    # RuntimeError rather than a bare Exception; it is a subclass of
    # Exception, so existing `except Exception` handlers still catch it.
    raise RuntimeError("No GPU found, please run without --cuda")

# Seed the CPU RNG always, and the CUDA RNG only when running on GPU.
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)
device = torch.device("cuda" if cuda else "cpu")

print('===> Loading datasets')
train_set = get_training_set(opt.dataset)
val_set = get_val_set(opt.dataset)
training_data_loader = DataLoader(dataset=train_set,
                                  batch_size=opt.batchSize,
                                  shuffle=True)
val_data_loader = DataLoader(dataset=val_set,
                             batch_size=opt.valBatchSize,
                             shuffle=False)

print('===> Loading pre_train model and Building model')
model_r = LapSRN_r().to(device)
model_g = LapSRN_g().to(device)
# Move the losses with .to(device), like the models above, instead of a
# separate `if cuda: ... .cuda()` branch.
# NOTE(review): assumes `Loss` is a torch.nn.Module (the original called
# .cuda() on it) so that .to() is available -- confirm.
# NOTE(review): this rebinds `Loss` from the class to an instance, so the
# class cannot be instantiated again later in this script; the name is kept
# because code after this point may refer to the instance as `Loss`.
Loss = Loss().to(device)
criterion = nn.MSELoss().to(device)
コード例 #3
0
    os.makedirs(os.path.join(config['training']['checkpoint_folder'],
                             os.path.basename(args.config)[:-5]),
                exist_ok=True)

    print('===> Loading datasets')
    sys.stdout.flush()
    train_set = get_training_set(
        img_dir=config['data']['train_root'],
        upscale_factor=config['model']['upscale_factor'],
        crop_size=config['data']['lr_crop_size'] *
        config['model']['upscale_factor'])
    train_dataloader = DataLoader(dataset=train_set,
                                  batch_size=config['training']['batch_size'],
                                  shuffle=True)

    val_set = get_val_set(img_dir=config['data']['test_root'],
                          upscale_factor=config['model']['upscale_factor'])
    val_dataloader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

    print('===> Building model')
    sys.stdout.flush()
    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    optimizer = setup_optimizer(model, config)
    scheduler = setup_scheduler(optimizer, config)

    start_iter = 0
    best_val_psnr = -1

    if config['training']['resume'] != 'None':
        print('===> Reloading model')
        sys.stdout.flush()
コード例 #4
0
print(opt)

# Fail fast when --cuda was requested but no GPU is visible. RuntimeError is
# a subclass of Exception, so existing `except Exception` handlers still work.
if opt.cuda and not torch.cuda.is_available():
    raise RuntimeError("No GPU found, please run without --cuda")

# Let cuDNN benchmark candidate convolution algorithms and keep the fastest.
# NOTE(review): assumes `cudnn` is torch.backends.cudnn -- confirm import.
cudnn.benchmark = True

# Seed the CPU RNG always, and the CUDA RNG only when running on GPU.
torch.manual_seed(opt.seed)
if opt.cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
# Dataset directories are resolved relative to "dataset/".
root_path = "dataset/"
train_set = get_training_set(root_path + opt.dataset)
val_set = get_val_set(root_path + opt.dataset)

# Training batches are shuffled; validation order is kept fixed.
training_data_loader = DataLoader(dataset=train_set,
                                  num_workers=opt.threads,
                                  batch_size=opt.batch_size,
                                  shuffle=True)
val_data_loader = DataLoader(dataset=val_set,
                             num_workers=opt.threads,
                             batch_size=opt.test_batch_size,
                             shuffle=False)

device = torch.device("cuda:0" if opt.cuda else "cpu")

print('===> Building models')
net_g = define_G(opt.input_nc,
                 opt.output_nc,