Exemplo n.º 1
0
def main():
    """Build the WGAN-GP model and its data loaders, then train if requested.

    Configuration is read from the module-level ``opt`` namespace
    (``img_path``, ``data_path``, ``batch_size``, ``workers``, ``is_train``)
    and the module-level ``device``; both are defined outside this function.
    Training runs only when ``opt.is_train`` is the string ``'True'`` or the
    boolean ``True``; otherwise "Done!" is printed.
    """
    model = WGAN_GP(opt, device)
    print("using WGAN_GP model")

    # ImageNet channel mean/std, applied after ToTensor().
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # cudnn.benchmark = True

    def _make_loader(partition, augment, shuffle):
        """Return a DataLoader over the given ``partition`` of ``opt.img_path``.

        ``augment`` is a list of extra transforms inserted after the
        resize/center-crop step and before conversion to a tensor.
        """
        pipeline = transforms.Compose([
            # Rescale so the shorter side is 128, keeping the aspect ratio.
            # ``Resize`` replaces the deprecated ``Scale`` alias, which later
            # torchvision releases no longer provide.
            transforms.Resize(128),
            # Keep only the 128x128 center of the rescaled image.
            transforms.CenterCrop(128),
            *augment,
            transforms.ToTensor(),
            normalize,
        ])
        dataset = ImageLoader(
            opt.img_path,
            pipeline,
            data_path=opt.data_path,
            partition=partition)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            num_workers=opt.workers,
            pin_memory=True)

    # preparing the training loader
    train_loader = _make_loader(
        'train',
        augment=[
            # NOTE(review): the image is already exactly 128x128 at this
            # point, so RandomCrop(128) is effectively a no-op; use a smaller
            # crop (or a larger Resize/CenterCrop) if random cropping is
            # actually intended.
            transforms.RandomCrop(128),
            transforms.RandomHorizontalFlip(),
        ],
        shuffle=True)
    print('Training loader prepared.')

    # preparing validation loader
    # NOTE(review): val_loader is built (and its dataset constructed) but is
    # not used anywhere in this function.
    val_loader = _make_loader('val', augment=[], shuffle=False)
    print('Validation loader prepared.')

    # Start model training; accept either the string 'True' or a real bool.
    if str(opt.is_train) == 'True':
        model.train(train_loader)
    else:
        print("Done!")
Exemplo n.º 2
0
def main(args):
    """Build the GAN variant named by ``args.model``, then train or evaluate it.

    Args:
        args: configuration namespace. This function reads ``model``,
            ``is_train``, ``load_D`` and ``load_G`` directly, and passes the
            whole object to the model constructor and to ``get_data_loader``.

    Exits the process with status -1 (after printing a message to stderr)
    when ``args.model`` is not one of 'GAN', 'DCGAN', 'WGAN-CP', 'WGAN-GP'.
    """
    import sys

    # Lambdas keep construction lazy: only the selected model class is
    # looked up and instantiated, exactly as with an if/elif chain.
    factories = {
        'GAN': lambda: GAN(args),
        'DCGAN': lambda: DCGAN_MODEL(args),
        'WGAN-CP': lambda: WGAN_CP(args),
        'WGAN-GP': lambda: WGAN_GP(args),
    }
    factory = factories.get(args.model)
    if factory is None:
        print("Model type non-existing. Try again.", file=sys.stderr)
        # sys.exit, unlike the site-injected exit(), is always available.
        sys.exit(-1)
    model = factory()

    # Load datasets to train and test loaders
    train_loader, test_loader = get_data_loader(args)

    # Accept either the string 'True' or a real bool for is_train.
    if str(args.is_train) == 'True':
        # Start model training
        model.train(train_loader)
    else:
        # Start evaluating on test data with the given D/G checkpoints.
        model.evaluate(test_loader, args.load_D, args.load_G)