Пример #1
0
def PGAN(pretrained=False, *args, **kwargs):
    """
    Build a Progressive Growing GAN, optionally loading pretrained weights.

    pretrained (bool): if True, download a checkpoint and load it into the
        model.
    model_name (string, keyword): when pretrained, one of "celebAHQ-256",
        "celebAHQ-512", "DTD" or "celeba". Default is "celebAHQ-256".
    useGPU (bool, keyword): forwarded to ProgressiveGAN (default True).
    config (dict, keyword): extra keyword arguments for ProgressiveGAN;
        None or missing means no extra arguments.

    Returns the ProgressiveGAN instance.

    Raises ValueError if model_name is given but is not a known checkpoint.
    """
    # Imported under its real name so it does not shadow this function.
    from models.progressive_gan import ProgressiveGAN

    config = kwargs.get('config')
    if config is None:
        config = {}

    model = ProgressiveGAN(useGPU=kwargs.get('useGPU', True),
                           storeAVG=True,
                           **config)

    checkpoint = {"celebAHQ-256": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ_s6_i80000-6196db68.pth',
                  "celebAHQ-512": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ16_december_s7_i96000-9c72988c.pth',
                  "DTD": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/testDTD_s5_i96000-04efa39f.pth',
                  "celeba": "https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaCropped_s5_i83000-2b0acc76.pth"}
    if pretrained:
        if "model_name" in kwargs:
            model_name = kwargs["model_name"]
            if model_name not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            print("Loading default model : celebaHQ-256")
            model_name = "celebAHQ-256"
        # Map to CPU so loading works regardless of available devices.
        state_dict = model_zoo.load_url(checkpoint[model_name],
                                        map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_pretrained_PGAN(dataset, project_path):
    """
    Load a pretrained PGAN from local weights and return it with its generator.

    dataset: checkpoint name, forwarded to PGAN as model_name.
    project_path: directory forwarded to PGAN as current_path.

    Returns (model, netG), where netG is model.getOriginalG() with the
    checkpoint's 'netG' entry loaded into it via utils.loadStateDictCompatible.

    Raises ValueError when CUDA is not available.
    """
    gpu_available = torch.cuda.is_available()
    if not gpu_available:
        raise ValueError("You should use GPU.")

    model, checkpoint_state = PGAN(model_name=dataset,
                                   pretrained=True,
                                   useGPU=gpu_available,
                                   current_path=project_path)

    generator = model.getOriginalG()
    utils.loadStateDictCompatible(generator, checkpoint_state['netG'])
    return model, generator
Пример #3
0
    def __init__(self, name, pretrained_netI=False):
        """
        Load the frozen generator for a model and build its netI network.

        :param name: name of the model: object, sis or mps. Used to locate
            files/<name>.json (keys 'means' and 'stds'), files/<name>.pt
            (loaded through ProGAN.load) and files/netI_<name>.pt.
        :param pretrained_netI: if True, load netI weights from
            files/netI_<name>.pt and switch netI to eval mode.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.name = name
        # Per-model constant vectors; their meaning is not visible here.
        self.maximums = {
            "sis": torch.tensor([1., 0.242726, 2.76592]).to(self.device),
            "object": torch.tensor([1., 0.22169, 2.41602]).to(self.device),
            "mps": torch.tensor([1., 0.261072, 3.07107]).to(self.device)
        }

        # Explicit encoding: JSON is UTF-8, independent of the platform locale.
        with open('files/{}.json'.format(self.name),
                  encoding='utf-8') as json_file:
            data = json.load(json_file)
        self.means = torch.tensor(data['means']).to(self.device)
        self.stds = torch.tensor(data['stds']).to(self.device)

        model = ProGAN()
        model.load(path='files/{}.pt'.format(self.name), loadD=False)
        self.netG = model.netG

        # The generator is used frozen: no gradients, eval mode.
        for p in self.netG.parameters():
            p.requires_grad_(False)

        self.netG.eval()

        # NOTE(review): nw, nz, hidden_layer_size and num_extra_layers are not
        # defined in this method; presumably module-level settings — confirm.
        self.netI = FC_selu_first(input_size=nw,
                                  output_size=nz,
                                  hidden_layer_size=hidden_layer_size,
                                  num_extra_layers=num_extra_layers).to(
                                      self.device)
        if pretrained_netI:
            # Load on CPU; load_state_dict copies into netI's existing tensors.
            self.netI.load_state_dict(
                torch.load('files/netI_{}.pt'.format(self.name),
                           map_location=torch.device('cpu')))
            self.netI.eval()
def PGAN(pretrained=False, *args, **kwargs):
    """
    Build a Progressive Growing GAN, optionally loading local pretrained weights.

    pretrained (bool): if True, load a checkpoint from <current_path>/weight/.
    model_name (string, keyword): when pretrained, one of "celebAHQ_256",
        "celebAHQ_512", "DTD" or "celeba_cropped". Default is "celebAHQ_256".
    current_path (string, keyword): directory containing the weight/ folder;
        required only when pretrained is True.
    useGPU (bool, keyword): forwarded to ProgressiveGAN (default True); also
        selects whether the checkpoint is mapped to 'cuda' or 'cpu'.
    config (dict, keyword): extra keyword arguments for ProgressiveGAN;
        None or missing means no extra arguments.

    Returns (model, state_dict). state_dict is None when pretrained is False.

    Raises KeyError if pretrained is True and current_path is not given.
    Raises ValueError if model_name is given but is not a known checkpoint.
    """
    import os

    # Imported under its real name so it does not shadow this function.
    from models.progressive_gan import ProgressiveGAN

    if pretrained and "current_path" not in kwargs:
        raise KeyError("current_path is required when pretrained=True")

    use_gpu = kwargs.get('useGPU', True)
    config = kwargs.get('config')
    if config is None:
        config = {}

    model = ProgressiveGAN(useGPU=use_gpu,
                           storeAVG=True,
                           **config)

    state_dict = None
    if pretrained:
        weight_dir = os.path.join(kwargs["current_path"], 'weight')
        checkpoint = {"celebAHQ_256": os.path.join(weight_dir, 'celebaHQ_256.pth'),
                      "celebAHQ_512": os.path.join(weight_dir, 'celebaHQ_512.pth'),
                      "DTD": os.path.join(weight_dir, 'DTD.pth'),
                      # The generic file name holds the cropped-CelebA model.
                      "celeba_cropped": os.path.join(weight_dir, 'generator.pth')}

        if "model_name" in kwargs:
            model_name = kwargs["model_name"]
            if model_name not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            # Must match a checkpoint key (underscore, not hyphen).
            print("Loading default model : celebAHQ_256")
            model_name = "celebAHQ_256"

        # Honour useGPU instead of always mapping tensors onto CUDA.
        map_location = 'cuda' if use_gpu else 'cpu'
        state_dict = torch.load(checkpoint[model_name], map_location=map_location)
        model.load_state_dict(state_dict)
    return model, state_dict
Пример #5
0
# -*- coding: utf-8 -*-
import torch

from models.progressive_gan import ProgressiveGAN as PGAN

# Default ProgressiveGAN configuration; extra options must be passed as
# keyword arguments (e.g. **config), not as a positional dict.
model = PGAN(useGPU=True, storeAVG=True)

# Load the cropped-CelebA checkpoint onto the first CUDA device.
state_dict = torch.load('celebaCropped_s5_i83000-2b0acc76.pth',
                        map_location='cuda:0')
model.load_state_dict(state_dict)