def get_BigGAN(version="biggan-deep-256", *, cache_path="/scratch/binxu/torch/"):
    """Build a BigGAN model from a config/weights pair stored in a local cache.

    Args:
        version: Model identifier used as the filename prefix, e.g.
            ``"biggan-deep-256"``.
        cache_path: Directory expected to contain ``<version>-config.json``
            and ``<version>-pytorch_model.bin``. Defaults to the path that was
            previously hard-coded, so existing callers behave the same.

    Returns:
        A ``BigGAN`` instance with the cached weights loaded. No device move,
        ``eval()`` call, or gradient freezing is performed here.
    """
    cfg = BigGANConfig.from_json_file(join(cache_path, "%s-config.json" % version))
    model = BigGAN(cfg)
    # Deserialize on the CPU so a checkpoint saved from GPU tensors still loads
    # on a CUDA-less machine; load_state_dict copies the values into the
    # model's own parameters, so the resulting model is unchanged.
    # NOTE: torch.load unpickles the file -- only point cache_path at trusted data.
    state_dict = torch.load(
        join(cache_path, "%s-pytorch_model.bin" % version), map_location="cpu"
    )
    model.load_state_dict(state_dict)
    return model
def loadBigGAN(version="biggan-deep-256", *, cache_path="/scratch/binxu/torch/", device="cuda"):
    """Load a BigGAN generator with all parameters frozen, optionally moving it to a device.

    When the module-level ``platform`` equals ``"linux"`` (presumably
    ``sys.platform`` -- confirm against the file's imports), the model is built
    from ``<cache_path>/<version>-config.json`` and
    ``<cache_path>/<version>-pytorch_model.bin``. Otherwise it is obtained via
    ``BigGAN.from_pretrained(version)``.

    Args:
        version: Model identifier, e.g. ``"biggan-deep-256"``.
        cache_path: Local cache directory; only used on the ``"linux"`` branch.
            Defaults to the previously hard-coded path.
        device: Target passed to ``Module.to``. Defaults to ``"cuda"``, which
            matches the former unconditional ``.cuda()`` call. Pass ``None`` to
            skip the move (e.g. on machines without CUDA).

    Returns:
        The ``BigGAN`` model with ``requires_grad`` set to ``False`` on every
        parameter.
    """
    from pytorch_pretrained_biggan import BigGAN, BigGANConfig

    if platform == "linux":
        cfg = BigGANConfig.from_json_file(join(cache_path, "%s-config.json" % version))
        model = BigGAN(cfg)
        # Deserialize on the CPU so GPU-saved weights load on any host;
        # load_state_dict copies them into the model's parameters.
        # NOTE: torch.load unpickles the file -- only use trusted cache files.
        state_dict = torch.load(
            join(cache_path, "%s-pytorch_model.bin" % version), map_location="cpu"
        )
        model.load_state_dict(state_dict)
    else:
        model = BigGAN.from_pretrained(version)

    # The generator is used as a fixed function: no gradients for its weights.
    for param in model.parameters():
        param.requires_grad_(False)

    if device is not None:
        model.to(device)
    return model