def PGAN(pretrained=False, *args, **kwargs):
    """Build a Progressive-Growing GAN, optionally loading released weights.

    Args:
        pretrained (bool): if True, download one of the published
            checkpoints through ``model_zoo`` and load it into the model.
        **kwargs:
            model_name (str): which pretrained model to load; one of
                "celebAHQ-256", "celebAHQ-512", "DTD", "celeba".
                Defaults to "celebAHQ-256".
            useGPU (bool): run the model on GPU. Defaults to True.
            config (dict or None): extra keyword arguments forwarded to
                the ``ProgressiveGAN`` constructor.

    Returns:
        ProgressiveGAN: the constructed (and optionally pretrained) model.

    Raises:
        ValueError: if ``model_name`` is not one of the known checkpoints.
    """
    # Local import: the project package is only needed when this entry
    # point is actually called (hub-style loader convention).
    from models.progressive_gan import ProgressiveGAN as PGAN

    # Normalize the optional constructor-config dict (handles both a
    # missing key and an explicit None).
    if kwargs.get('config') is None:
        kwargs['config'] = {}

    model = PGAN(useGPU=kwargs.get('useGPU', True),
                 storeAVG=True,
                 **kwargs['config'])

    checkpoint = {
        "celebAHQ-256": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ_s6_i80000-6196db68.pth',
        "celebAHQ-512": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ16_december_s7_i96000-9c72988c.pth',
        "DTD": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/testDTD_s5_i96000-04efa39f.pth',
        "celeba": "https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaCropped_s5_i83000-2b0acc76.pth",
    }

    if pretrained:
        if "model_name" in kwargs:
            if kwargs["model_name"] not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            print("Loading default model : celebaHQ-256")
            kwargs["model_name"] = "celebAHQ-256"
        # Weights are mapped to CPU here; the model itself decides the
        # final device via its useGPU flag.
        state_dict = model_zoo.load_url(checkpoint[kwargs["model_name"]],
                                        map_location='cpu')
        model.load_state_dict(state_dict)
    return model
def load_pretrained_PGAN(dataset, project_path):
    """Load a pretrained PGAN from local weights and extract its generator.

    Args:
        dataset (str): checkpoint key understood by ``PGAN`` (e.g.
            "celebAHQ_256", "celeba_cropped").
        project_path (str): project root; weights are resolved relative
            to this path by ``PGAN``.

    Returns:
        tuple: ``(model, netG)`` where ``model`` is the full wrapper and
        ``netG`` is the generator with the checkpoint weights loaded.

    Raises:
        ValueError: if no CUDA device is available.
    """
    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        raise ValueError("You should use GPU.")
    model, state_dict = PGAN(model_name=dataset, pretrained=True,
                             useGPU=use_gpu, current_path=project_path)
    # Extract the raw (non-averaged) generator and load the matching
    # weights through the compatibility loader.
    netG = model.getOriginalG()
    utils.loadStateDictCompatible(netG, state_dict['netG'])
    return model, netG
def __init__(self, name, pretrained_netI=False):
    """Set up the frozen generator and the inverter network.

    :param name: name of the model: object, sis or mps
    :param pretrained_netI: if True, also load saved inverter weights
        from ``files/netI_<name>.pt`` and switch the inverter to eval mode.
    """
    use_cuda = torch.cuda.is_available()
    self.device = torch.device('cuda' if use_cuda else 'cpu')
    self.name = name

    # Per-model channel maxima (moved onto the active device).
    self.maximums = {
        "sis": torch.tensor([1., 0.242726, 2.76592]).to(self.device),
        "object": torch.tensor([1., 0.22169, 2.41602]).to(self.device),
        "mps": torch.tensor([1., 0.261072, 3.07107]).to(self.device),
    }

    # Normalization statistics stored alongside the model files.
    with open('files/{}.json'.format(self.name)) as json_file:
        stats = json.load(json_file)
    self.means = torch.tensor(stats['means']).to(self.device)
    self.stds = torch.tensor(stats['stds']).to(self.device)

    # Load the pretrained generator (discriminator skipped) and freeze it.
    gan = ProGAN()
    gan.load(path='files/{}.pt'.format(self.name), loadD=False)
    self.netG = gan.netG
    for param in self.netG.parameters():
        param.requires_grad_(False)
    self.netG.eval()

    # Inverter mapping generator outputs back to latent codes; sizes come
    # from module-level constants (nw, nz, hidden_layer_size, ...).
    self.netI = FC_selu_first(
        input_size=nw,
        output_size=nz,
        hidden_layer_size=hidden_layer_size,
        num_extra_layers=num_extra_layers,
    ).to(self.device)
    if pretrained_netI:
        self.netI.load_state_dict(
            torch.load('files/netI_{}.pt'.format(self.name),
                       map_location=torch.device('cpu')))
        self.netI.eval()
def PGAN(pretrained=False, *args, **kwargs):
    """Build a Progressive-Growing GAN loading weights from local files.

    Unlike the hub variant, checkpoints are read from
    ``<current_path>/weight/`` instead of being downloaded.

    Args:
        pretrained (bool): if True, load a local checkpoint.
        **kwargs:
            current_path (str): project root containing the ``weight/``
                directory (required).
            model_name (str): one of "celebAHQ_256", "celebAHQ_512",
                "DTD", "celeba_cropped". Defaults to "celebAHQ_256".
            useGPU (bool): run the model on GPU. Defaults to True.
            config (dict or None): extra keyword arguments forwarded to
                the ``ProgressiveGAN`` constructor.

    Returns:
        tuple: ``(model, state_dict)``; ``state_dict`` is None when
        ``pretrained`` is False.

    Raises:
        ValueError: if ``model_name`` is not a known checkpoint key.
    """
    current_path = kwargs["current_path"]
    from models.progressive_gan import ProgressiveGAN as PGAN

    # Normalize the optional constructor-config dict.
    if kwargs.get('config') is None:
        kwargs['config'] = {}

    use_gpu = kwargs.get('useGPU', True)
    model = PGAN(useGPU=use_gpu, storeAVG=True, **kwargs['config'])

    checkpoint = {
        "celebAHQ_256": current_path + '/weight/celebaHQ_256.pth',
        "celebAHQ_512": current_path + '/weight/celebaHQ_512.pth',
        "DTD": current_path + '/weight/DTD.pth',
        # NOTE: despite the generic file name, this is the celeba-cropped model.
        "celeba_cropped": current_path + '/weight/generator.pth',
    }

    # Initialized so the return statement is safe when pretrained=False
    # (the original referenced an undefined/unkeyed state_dict in that case).
    state_dict = None
    if pretrained:
        if "model_name" in kwargs:
            if kwargs["model_name"] not in checkpoint:
                raise ValueError("model_name should be in "
                                 + str(checkpoint.keys()))
        else:
            print("Loading default model : celebaHQ-256")
            # BUG FIX: the old default "celebAHQ-256" (hyphen) is not a
            # key of `checkpoint` (underscore keys), so the default path
            # always raised KeyError.
            kwargs["model_name"] = "celebAHQ_256"
        # Map to the device the model will actually run on instead of
        # unconditionally requiring CUDA.
        state_dict = torch.load(checkpoint[kwargs["model_name"]],
                                map_location='cuda' if use_gpu else 'cpu')
        model.load_state_dict(state_dict)
    return model, state_dict
# -*- coding: utf-8 -*-
# Standalone script: build a ProgressiveGAN and load local celeba-cropped
# checkpoint weights onto the first CUDA device.
import torch

from models.progressive_gan import ProgressiveGAN as PGAN

# BUG FIX: the original called PGAN(useGPU=True, storeAVG=True, {}),
# which is a SyntaxError (positional argument after keyword arguments).
# The empty config dict carried no arguments, so it is simply dropped.
model = PGAN(useGPU=True, storeAVG=True)
state_dict = torch.load('celebaCropped_s5_i83000-2b0acc76.pth',
                        map_location='cuda:0')
model.load_state_dict(state_dict)