Exemplo n.º 1
0
def calculate_data_statics(path, save_path, batch_size, cuda, dims):
    """Compute activation statistics for ``path`` and store them on disk.

    An ``InceptionV3`` limited to the block registered for ``dims`` is built
    (and moved to the GPU when ``cuda`` is truthy), then handed to
    ``_compute_statistics_of_path``. The two values it returns are written to
    ``<save_path>/celeba_test.npz`` under the keys ``mu`` and ``sigma``.

    Nothing is returned.
    """
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])
    if cuda:
        model.cuda()

    mu, sigma = _compute_statistics_of_path(path, model, batch_size, dims, cuda)

    out_file = os.path.join(save_path, 'celeba_test.npz')
    np.savez(out_file, mu=mu, sigma=sigma)
Exemplo n.º 2
0
def calculate_kid_given_paths(paths, batch_size, cuda, dims, model_type='inception'):
    """Calculate the KID of every input in ``paths[1:]`` against ``paths[0]``.

    Args:
        paths: Sequence of filesystem paths. A directory is passed through
            unchanged; a ``.npy`` file is loaded and, if it holds more than
            50000 rows, randomly subsampled to 50000. Existing paths of any
            other kind are ignored.
        batch_size: Forwarded to ``_compute_activations``.
        cuda: When truthy the feature extractor is moved to the GPU.
        dims: Feature dimensionality; selects the InceptionV3 block when
            ``model_type`` is ``'inception'``.
        model_type: ``'inception'`` or ``'lenet'``.

    Returns:
        A list of ``(path, kid_mean, kid_std)`` tuples, one per compared input.

    Raises:
        ValueError: If ``model_type`` is not a supported feature extractor.
        RuntimeError: If a path does not exist, or no usable input was given.
    """
    # Fail fast: previously an unknown model_type left `model` unbound and
    # only surfaced later as a confusing NameError.
    if model_type not in ('inception', 'lenet'):
        raise ValueError('Unknown model_type: %s' % model_type)

    # Labels are collected alongside the loaded inputs so that an ignored
    # path can never shift the reported path out of sync with its data.
    labels = []
    inputs = []
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)
        if os.path.isdir(p):
            data = p
        elif p.endswith('.npy'):
            data = np.load(p)
            if data.shape[0] > 50000:
                # Same rows as permuting the whole array and slicing, but
                # without materialising the full permuted copy first.
                keep = np.random.permutation(data.shape[0])[:50000]
                data = data[keep]
        else:
            continue
        labels.append(p)
        inputs.append(data)

    if not inputs:
        raise RuntimeError('No usable reference path was given')

    if model_type == 'inception':
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx])
    else:  # 'lenet'
        model = LeNet5()
        model.load_state_dict(torch.load('./models/lenet.pth'))
    if cuda:
        model.cuda()

    # The first usable input is the reference every other input is scored against.
    act_true = _compute_activations(inputs[0], model, batch_size, dims, cuda, model_type)

    results = []
    for label, data in zip(labels[1:], inputs[1:]):
        print(label)
        actj = _compute_activations(data, model, batch_size, dims, cuda, model_type)
        kid_values = polynomial_mmd_averages(act_true, actj, n_subsets=100)
        results.append((label, kid_values[0].mean(), kid_values[0].std()))
    return results
Exemplo n.º 3
0
    def __init__(self, hparams):
        """Set up the HandGAN networks, auxiliary models, pools and losses.

        Args:
            hparams: Hyper-parameters, either an attribute-style object or a
                plain ``dict`` (which is converted to a ``Namespace`` here).
        """
        super().__init__()

        print("Initializing our HandGAN model...")

        # Loading from a checkpoint can hand the hyper-parameters over as a
        # dict, and save_hyperparameters() does not work in that case, so they
        # are normalised and stored by hand. See
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3998
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams = hparams

        # Weight initializer applied to every generator and discriminator.
        weight_init = mutils.Initializer(init_type=hparams.init_type,
                                         init_gain=hparams.init_gain)

        # One generator per translation direction.
        self.g_ab = weight_init(generator(hparams))  # domain A -> domain B
        self.g_ba = weight_init(generator(hparams))  # domain B -> domain A

        # One discriminator per domain.
        self.d_a = weight_init(discriminator(hparams))  # domain A
        self.d_b = weight_init(discriminator(hparams))  # domain B

        # The perceptual discriminator needs a feature extractor.
        if hparams.netD == 'perceptual':
            self.vgg_net = Vgg16().eval()

        # Validation computes the FID metric, which needs an Inception network.
        if hparams.valid_interval > 0.0:
            inception_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception_net = InceptionV3([inception_block]).eval()

        # Pre-trained SilNet with gradients switched off through
        # mutils.set_requires_grad; only its `unet` sub-module is kept.
        if hparams.ganerated:
            silnet_ckpt = SilNet.load_from_checkpoint(
                "./weights/silnet_pretrained.ckpt")
            mutils.set_requires_grad(silnet_ckpt, requires_grad=False)
            self.silnet = silnet_ckpt.unet.eval()

        # Image pools from which generated images are randomly drawn,
        # one per domain.
        pool_size = hparams.pool_size
        self.fake_a_pool = mutils.ImagePool(pool_size)
        self.fake_b_pool = mutils.ImagePool(pool_size)

        # Loss criteria: cycle and discriminator losses are always present,
        # the identity and geometry losses only when enabled.
        self.crit_cycle = torch.nn.L1Loss()
        self.crit_discr = DiscriminatorLoss('lsgan')

        if hparams.lambda_idt > 0.0:
            self.crit_idt = torch.nn.L1Loss()

        if hparams.ganerated:
            self.crit_geom = torch.nn.BCEWithLogitsLoss()
Exemplo n.º 4
0
def calculate_fid_given_paths(paths,
                              batch_size,
                              cuda,
                              dims,
                              bootstrap=True,
                              n_bootstraps=10,
                              model_type='inception'):
    """Calculate the FID of every input in ``paths[1:]`` against ``paths[0]``.

    Args:
        paths: Sequence of filesystem paths. A directory is passed through
            unchanged; a ``.npy`` file is loaded and truncated to its first
            50000 rows. Existing paths of any other kind are ignored.
        batch_size: Forwarded to ``_compute_activations``.
        cuda: When truthy the feature extractor is moved to the GPU.
        dims: Feature dimensionality; selects the InceptionV3 block when
            ``model_type`` is ``'inception'``.
        bootstrap: When falsy only a single resampling round is run.
        n_bootstraps: Number of resampling rounds when ``bootstrap`` is truthy.
        model_type: ``'inception'`` or ``'lenet'``.

    Returns:
        A list of ``(path, fid_mean, fid_std)`` tuples, one per compared
        input, where mean and std are taken over the resampling rounds.

    Raises:
        ValueError: If ``model_type`` is not a supported feature extractor.
        RuntimeError: If a path does not exist, or no usable input was given.
    """
    # Fail fast: previously an unknown model_type left `model` unbound and
    # only surfaced later as a confusing NameError.
    if model_type not in ('inception', 'lenet'):
        raise ValueError('Unknown model_type: %s' % model_type)

    # Labels are collected alongside the loaded inputs so that an ignored
    # path can never shift the reported path out of sync with its data.
    labels = []
    inputs = []
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)
        if os.path.isdir(p):
            data = p
        elif p.endswith('.npy'):
            # Slicing is a no-op for shorter arrays, so the former
            # `shape[0] > 25000` guard added nothing.
            # NOTE(review): the 25000/50000 mismatch suggests one of the two
            # numbers was a typo; the effective 50000 cap is kept.
            data = np.load(p)[:50000]
        else:
            continue
        labels.append(p)
        inputs.append(data)

    if not inputs:
        raise RuntimeError('No usable reference path was given')

    if model_type == 'inception':
        # The block index is only looked up for Inception, so a LeNet `dims`
        # that is absent from the Inception table no longer fails here.
        model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])
    else:  # 'lenet'
        model = LeNet5()
        # NOTE(review): absolute, environment-specific weights location.
        model.load_state_dict(
            torch.load('/content/gan-metrics-pytorch/models/lenet.pth'))
    if cuda:
        model.cuda()

    # The first usable input is the reference every other input is scored against.
    act_true = _compute_activations(inputs[0], model, batch_size, dims, cuda,
                                    model_type)

    # NOTE(review): with bootstrap=False one round of resampling with
    # replacement is still performed, not a run on the full sets.
    n_bootstraps = n_bootstraps if bootstrap else 1

    results = []
    for label, data in zip(labels[1:], inputs[1:]):
        print(label)
        actj = _compute_activations(data, model, batch_size, dims, cuda,
                                    model_type)
        fid_values = np.zeros((n_bootstraps))
        with tqdm(range(n_bootstraps), desc='FID') as bar:
            for i in bar:
                # Resample both activation sets with replacement.
                act1_bs = act_true[np.random.choice(act_true.shape[0],
                                                    act_true.shape[0],
                                                    replace=True)]
                act2_bs = actj[np.random.choice(actj.shape[0],
                                                actj.shape[0],
                                                replace=True)]
                m1, s1 = calculate_activation_statistics(act1_bs)
                m2, s2 = calculate_activation_statistics(act2_bs)
                fid_values[i] = calculate_frechet_distance(m1, s1, m2, s2)
                # Running mean over the rounds completed so far.
                bar.set_postfix({'mean': fid_values[:i + 1].mean()})
        results.append((label, fid_values.mean(), fid_values.std()))
    return results
Exemplo n.º 5
0
def calculate_fid_given_lists(paths, batch_size, cuda, dims):
    """Return the FID between the inputs ``paths[0]`` and ``paths[1]``.

    An ``InceptionV3`` limited to the block registered for ``dims`` is put
    into eval mode (and moved to the GPU when ``cuda`` is truthy, after
    printing ``cuda``). Statistics for both inputs come from
    ``_compute_statistics`` and are compared with
    ``calculate_frechet_distance``.
    """
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])
    model.eval()
    if cuda:
        print('cuda')
        model.cuda()

    def stats_of(source):
        # Both inputs share the same model and settings.
        return _compute_statistics(source, model, batch_size, dims, cuda)

    mu_a, sigma_a = stats_of(paths[0])
    mu_b, sigma_b = stats_of(paths[1])
    return calculate_frechet_distance(mu_a, sigma_a, mu_b, sigma_b)
Exemplo n.º 6
0
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Return the FID between the data at ``paths[0]`` and ``paths[1]``.

    Every entry of ``paths`` must exist on disk; the first one that does not
    triggers ``RuntimeError('Invalid path: ...')`` before any model is built.
    Statistics for both paths come from ``_compute_statistics_of_path`` using
    an ``InceptionV3`` limited to the block registered for ``dims`` (moved to
    the GPU when ``cuda`` is truthy).
    """
    first_missing = next((p for p in paths if not os.path.exists(p)), None)
    if first_missing is not None:
        raise RuntimeError('Invalid path: %s' % first_missing)

    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])
    if cuda:
        model.cuda()

    mu_a, sigma_a = _compute_statistics_of_path(
        paths[0], model, batch_size, dims, cuda)
    mu_b, sigma_b = _compute_statistics_of_path(
        paths[1], model, batch_size, dims, cuda)

    return calculate_frechet_distance(mu_a, sigma_a, mu_b, sigma_b)
Exemplo n.º 7
0
def build_model(args):
    """Create the networks and optimizers and place the networks on devices.

    ``args.to_train`` is overwritten with ``'CDGI'``, so all of the guiding
    nets (``C``/``C_EMA``), the discriminator (``D``), the generators
    (``G``/``G_EMA``) and the Inception net (``inceptionNet``) are built.

    Device placement depends on ``args.distributed`` and ``args.gpu``:

    * distributed with a GPU id: every network except ``inceptionNet`` is
      moved to that GPU and wrapped in ``DistributedDataParallel``;
      ``args.batch_size`` and ``args.workers`` are divided by
      ``args.ngpus_per_node``.
    * distributed without a GPU id: every network is wrapped in
      ``DistributedDataParallel``.
    * single GPU id: every network is moved to that GPU.
    * otherwise: every network is wrapped in ``DataParallel``.

    Returns:
        ``(networks, opts)``: a dict of networks keyed by name and a dict of
        optimizers keyed by ``'C'``, ``'D'`` and ``'G'``.
    """
    args.to_train = 'CDGI'

    networks, opts = {}, {}

    if 0.0 < args.p_semi < 1.0:
        assert 'SEMI' in args.train_mode

    # ---- network construction -------------------------------------------
    if 'C' in args.to_train:
        for key in ('C', 'C_EMA'):
            networks[key] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
    if 'D' in args.to_train:
        networks['D'] = Discriminator(args.img_size, num_domains=args.output_k)
    if 'G' in args.to_train:
        for key in ('G', 'G_EMA'):
            networks[key] = Generator(args.img_size, args.sty_dim, use_sn=False)
    if 'I' in args.to_train:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[args.dims]
        networks['inceptionNet'] = InceptionV3([block_idx])

    # ---- device placement -----------------------------------------------
    distributed = args.distributed
    gpu = args.gpu
    if distributed and gpu is not None:
        print('Distributed to', gpu)
        torch.cuda.set_device(gpu)
        # Split the per-node budget across the processes of this node.
        args.batch_size = int(args.batch_size / args.ngpus_per_node)
        args.workers = int(args.workers / args.ngpus_per_node)
        for name in list(networks):
            if name == 'inceptionNet':
                # Left untouched: neither moved to the GPU nor wrapped.
                continue
            on_gpu = networks[name].cuda(gpu)
            networks[name] = torch.nn.parallel.DistributedDataParallel(
                on_gpu, device_ids=[gpu], output_device=gpu)
    elif distributed:
        for name in list(networks):
            networks[name] = torch.nn.parallel.DistributedDataParallel(networks[name].cuda())
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        for name in list(networks):
            networks[name] = networks[name].cuda(gpu)
    else:
        for name in list(networks):
            networks[name] = torch.nn.DataParallel(networks[name]).cuda()

    # ---- optimizers -------------------------------------------------------
    def params_of(key):
        # In distributed mode the parameters are taken from the wrapped module.
        net = networks[key]
        return net.module.parameters() if distributed else net.parameters()

    if 'C' in args.to_train:
        opts['C'] = torch.optim.Adam(params_of('C'), 1e-4, weight_decay=0.001)
        # Start the EMA copy from the same weights as the trained network.
        source, target = networks['C'], networks['C_EMA']
        if distributed:
            source, target = source.module, target.module
        target.load_state_dict(source.state_dict())

    for key in ('D', 'G'):
        if key in args.to_train:
            opts[key] = torch.optim.RMSprop(params_of(key), 1e-4, weight_decay=0.0001)

    return networks, opts
Exemplo n.º 8
0
def get_network(name: str, num_classes: int) -> "object | None":
    """Instantiate the network registered under ``name``.

    Args:
        name: Registry key of the architecture. Keys are case-sensitive and
            matched exactly. Note that the two ResNeXt ImageNet variants are
            registered as ``'ResNext50'`` / ``'ResNext101'`` (lower-case
            ``x``), unlike ``'ResNeXtCIFAR'``.
        num_classes: Passed to the network constructor as ``num_classes``.

    Returns:
        A new network instance, or ``None`` when ``name`` is not registered.
    """
    # Factories are lambdas so that only the selected class is ever looked
    # up and constructed, exactly like the branches of a conditional chain.
    factories = {
        'AlexNet': lambda: AlexNet(num_classes=num_classes),
        'DenseNet201': lambda: DenseNet201(num_classes=num_classes),
        'DenseNet169': lambda: DenseNet169(num_classes=num_classes),
        'DenseNet161': lambda: DenseNet161(num_classes=num_classes),
        'DenseNet121': lambda: DenseNet121(num_classes=num_classes),
        'DenseNet121CIFAR': lambda: DenseNet121CIFAR(num_classes=num_classes),
        'GoogLeNet': lambda: GoogLeNet(num_classes=num_classes),
        'InceptionV3': lambda: InceptionV3(num_classes=num_classes),
        'MNASNet_0_5': lambda: MNASNet_0_5(num_classes=num_classes),
        'MNASNet_0_75': lambda: MNASNet_0_75(num_classes=num_classes),
        'MNASNet_1': lambda: MNASNet_1(num_classes=num_classes),
        'MNASNet_1_3': lambda: MNASNet_1_3(num_classes=num_classes),
        'MobileNetV2': lambda: MobileNetV2(num_classes=num_classes),
        'ResNet18': lambda: ResNet18(num_classes=num_classes),
        'ResNet34': lambda: ResNet34(num_classes=num_classes),
        'ResNet34CIFAR': lambda: ResNet34CIFAR(num_classes=num_classes),
        'ResNet50CIFAR': lambda: ResNet50CIFAR(num_classes=num_classes),
        'ResNet101CIFAR': lambda: ResNet101CIFAR(num_classes=num_classes),
        'ResNet18CIFAR': lambda: ResNet18CIFAR(num_classes=num_classes),
        'ResNet50': lambda: ResNet50(num_classes=num_classes),
        'ResNet101': lambda: ResNet101(num_classes=num_classes),
        'ResNet152': lambda: ResNet152(num_classes=num_classes),
        'ResNext50': lambda: ResNeXt50(num_classes=num_classes),
        'ResNeXtCIFAR': lambda: ResNeXtCIFAR(num_classes=num_classes),
        'ResNext101': lambda: ResNeXt101(num_classes=num_classes),
        'WideResNet50': lambda: WideResNet50(num_classes=num_classes),
        'WideResNet101': lambda: WideResNet101(num_classes=num_classes),
        'ShuffleNetV2_0_5': lambda: ShuffleNetV2_0_5(num_classes=num_classes),
        'ShuffleNetV2_1': lambda: ShuffleNetV2_1(num_classes=num_classes),
        'ShuffleNetV2_1_5': lambda: ShuffleNetV2_1_5(num_classes=num_classes),
        'ShuffleNetV2_2': lambda: ShuffleNetV2_2(num_classes=num_classes),
        'SqueezeNet_1': lambda: SqueezeNet_1(num_classes=num_classes),
        'SqueezeNet_1_1': lambda: SqueezeNet_1_1(num_classes=num_classes),
        'VGG11': lambda: VGG11(num_classes=num_classes),
        'VGG11_BN': lambda: VGG11_BN(num_classes=num_classes),
        'VGG13': lambda: VGG13(num_classes=num_classes),
        'VGG13_BN': lambda: VGG13_BN(num_classes=num_classes),
        'VGG16': lambda: VGG16(num_classes=num_classes),
        'VGG16_BN': lambda: VGG16_BN(num_classes=num_classes),
        'VGG19': lambda: VGG19(num_classes=num_classes),
        'VGG19_BN': lambda: VGG19_BN(num_classes=num_classes),
        'VGG16CIFAR': lambda: VGGCIFAR('VGG16', num_classes=num_classes),
        'EfficientNetB4': lambda: EfficientNetB4(num_classes=num_classes),
        'EfficientNetB0CIFAR': lambda: EfficientNetB0CIFAR(num_classes=num_classes),
    }

    factory = factories.get(name)
    if factory is None:
        return None
    return factory()
Exemplo n.º 9
0
    #================================
    # Build the generator and the discriminator and move them to `device`.
    model_G = Pix2PixHDGenerator().to(device)
    model_D = PatchGANDiscriminator(in_dim=3 + 3, n_fmaps=64).to(device)

    # Load the model.
    # Only the generator is restored, and only when a non-empty checkpoint
    # path was given and that path exists.
    if not args.load_checkpoints_path == '' and os.path.exists(
            args.load_checkpoints_path):
        load_checkpoint(model_G, device, args.load_checkpoints_path)

    if (args.debug):
        print("model_G\n", model_G)
        print("model_D\n", model_D)

    # Inception model / used for computing the FID score.
    # NOTE(review): `diaplay_scores` looks like a misspelling of
    # `display_scores`; it has to match the argument parser, so it is kept.
    if (args.diaplay_scores):
        inception = InceptionV3().to(device)

    #================================
    # Optimizer setup (generator and discriminator share lr and betas)
    #================================
    optimizer_G = optim.Adam(params=model_G.parameters(),
                             lr=args.lr,
                             betas=(args.beta1, args.beta2))
    optimizer_D = optim.Adam(params=model_D.parameters(),
                             lr=args.lr,
                             betas=(args.beta1, args.beta2))

    #================================
    # Loss function setup
    #================================
    loss_l1_fn = nn.L1Loss()