Пример #1
0
def PGAN(pretrained=False, *args, **kwargs):
    """
    Build a ProgressiveGAN model, optionally loading published weights.

    pretrained (bool): if True, download a checkpoint (mapped to CPU) and
        load it into the model.
    model_name (string, keyword): checkpoint to load when ``pretrained`` is
        True; one of celebAHQ-256, celebAHQ-512, DTD, celeba.
        Default is celebAHQ-256.
    useGPU (bool, keyword): forwarded to ProgressiveGAN (default True).
    config (dict, keyword): extra keyword arguments for the ProgressiveGAN
        constructor; absent or None means no extra arguments.

    Extra positional arguments are accepted but ignored.

    Returns the ProgressiveGAN instance.
    Raises ValueError if ``model_name`` is given but is not a known key.
    """
    # Local imports: `model_zoo` was otherwise an unresolved free name, and a
    # distinct alias avoids shadowing this function's own name `PGAN`.
    import torch.utils.model_zoo as model_zoo
    from models.progressive_gan import ProgressiveGAN

    checkpoint = {"celebAHQ-256": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ_s6_i80000-6196db68.pth',
                  "celebAHQ-512": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ16_december_s7_i96000-9c72988c.pth',
                  "DTD": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/testDTD_s5_i96000-04efa39f.pth',
                  "celeba": "https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaCropped_s5_i83000-2b0acc76.pth"}

    # Resolve and validate the checkpoint name before building the model, so
    # a bad name fails fast without constructing (and possibly allocating) it.
    model_name = None
    if pretrained:
        if "model_name" in kwargs:
            model_name = kwargs["model_name"]
            if model_name not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            print("Loading default model : celebaHQ-256")
            model_name = "celebAHQ-256"

    config = kwargs.get('config')
    if config is None:
        config = {}

    model = ProgressiveGAN(useGPU=kwargs.get('useGPU', True),
                           storeAVG=True,
                           **config)

    if pretrained:
        state_dict = model_zoo.load_url(checkpoint[model_name],
                                        map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_pretrained_PGAN(dataset, project_path):
    """
    Load a locally stored pretrained PGAN and return it with its generator.

    dataset (string): checkpoint key, forwarded to ``PGAN`` as ``model_name``.
    project_path (string): forwarded to ``PGAN`` as ``current_path``.

    Returns ``(model, netG)`` where ``netG`` is ``model.getOriginalG()`` after
    the checkpoint's ``'netG'`` entry has been loaded into it.
    Raises ValueError if CUDA is not available.
    """
    import torch

    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        raise ValueError("You should use GPU.")

    model, state_dict = PGAN(model_name=dataset, pretrained=True,
                             useGPU=use_gpu, current_path=project_path)

    netG = model.getOriginalG()
    # NOTE(review): PGAN already calls model.load_state_dict(state_dict);
    # whether that also covers the generator is not visible here, so this
    # explicit reload of 'netG' is kept — confirm before removing it.
    utils.loadStateDictCompatible(netG, state_dict['netG'])

    return model, netG
def PGAN(pretrained=False, *args, **kwargs):
    """
    Build a ProgressiveGAN model, optionally loading weights from local files.

    pretrained (bool): if True, load a checkpoint from
        ``<current_path>/weight/`` into the model.
    model_name (string, keyword): checkpoint to load when ``pretrained`` is
        True; one of celebAHQ_256, celebAHQ_512, DTD, celeba_cropped.
        Default is celebAHQ_256.
    current_path (string, keyword): directory containing the ``weight``
        folder; required only when ``pretrained`` is True.
    useGPU (bool, keyword): forwarded to ProgressiveGAN (default True); also
        selects whether checkpoint tensors are mapped to 'cuda' or 'cpu'.
    config (dict, keyword): extra keyword arguments for the ProgressiveGAN
        constructor; absent or None means no extra arguments.

    Extra positional arguments are accepted but ignored.

    Returns ``(model, state_dict)``: ``state_dict`` is the object returned by
    ``torch.load`` for the chosen checkpoint, or None when ``pretrained`` is
    False.
    Raises ValueError if ``model_name`` is given but is not a known key, and
    KeyError if ``current_path`` is missing while ``pretrained`` is True.
    """
    import os

    import torch
    from models.progressive_gan import ProgressiveGAN

    use_gpu = kwargs.get('useGPU', True)

    config = kwargs.get('config')
    if config is None:
        config = {}

    model = ProgressiveGAN(useGPU=use_gpu,
                           storeAVG=True,
                           **config)

    # Initialised up front so the non-pretrained path can still return it.
    state_dict = None

    if pretrained:
        current_path = kwargs["current_path"]
        weight_dir = os.path.join(current_path, 'weight')
        checkpoint = {"celebAHQ_256": os.path.join(weight_dir, 'celebaHQ_256.pth'),
                      "celebAHQ_512": os.path.join(weight_dir, 'celebaHQ_512.pth'),
                      "DTD": os.path.join(weight_dir, 'DTD.pth'),
                      # generator.pth holds the celeba-cropped weights.
                      "celeba_cropped": os.path.join(weight_dir, 'generator.pth')}

        if "model_name" in kwargs:
            model_name = kwargs["model_name"]
            if model_name not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            print("Loading default model : celebaHQ-256")
            # Keys here use underscores; "celebAHQ-256" would raise KeyError.
            model_name = "celebAHQ_256"

        # torch.load unpickles the file: only point this at trusted weights.
        state_dict = torch.load(checkpoint[model_name],
                                map_location='cuda' if use_gpu else 'cpu')
        model.load_state_dict(state_dict)
    return model, state_dict
Пример #4
0
# -*- coding: utf-8 -*-
# Build a ProgressiveGAN and load the celeba-cropped checkpoint onto GPU 0.
import torch

from models.progressive_gan import ProgressiveGAN as PGAN

# No extra constructor config: a positional `{}` after keyword arguments is a
# SyntaxError, and an empty config contributes no arguments anyway.
model = PGAN(useGPU=True, storeAVG=True)

state_dict = torch.load('celebaCropped_s5_i83000-2b0acc76.pth',
                        map_location='cuda:0')
model.load_state_dict(state_dict)